|
|
|
@ -391,7 +391,8 @@ def controlnet_config(sd):
|
|
|
|
|
else:
|
|
|
|
|
operations = comfy.ops.disable_weight_init
|
|
|
|
|
|
|
|
|
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
|
|
|
|
offload_device = comfy.model_management.unet_offload_device()
|
|
|
|
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
|
|
|
|
|
|
|
|
|
def controlnet_load_state_dict(control_model, sd):
|
|
|
|
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
|
|
|
@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
|
|
|
|
|
|
|
|
|
|
def load_controlnet_mmdit(sd):
|
|
|
|
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
|
|
|
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
|
|
|
|
for k in sd:
|
|
|
|
|
new_sd[k] = sd[k]
|
|
|
|
|
|
|
|
|
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
|
|
|
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
|
|
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
|
|
|
|
|
|
|
|
|
latent_format = comfy.latent_formats.SD3()
|
|
|
|
@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_controlnet_hunyuandit(controlnet_data):
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
|
|
|
|
|
|
|
|
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
|
|
|
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
|
|
|
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
|
|
|
|
|
|
|
|
|
latent_format = comfy.latent_formats.SDXL()
|
|
|
|
@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
|
|
|
|
|
return control
|
|
|
|
|
|
|
|
|
|
def load_controlnet_flux_xlabs(sd):
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
|
|
|
|
|
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
|
|
|
|
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
|
|
|
|
control_model = controlnet_load_state_dict(control_model, sd)
|
|
|
|
|
extra_conds = ['y', 'guidance']
|
|
|
|
|
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
|
|
|
@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
|
|
|
|
|
if manual_cast_dtype is not None:
|
|
|
|
|
controlnet_config["operations"] = comfy.ops.manual_cast
|
|
|
|
|
controlnet_config["dtype"] = unet_dtype
|
|
|
|
|
controlnet_config["device"] = comfy.model_management.unet_offload_device()
|
|
|
|
|
controlnet_config.pop("out_channels")
|
|
|
|
|
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
|
|
|
|
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
|
|
|
|
|