|
|
|
@ -191,13 +191,16 @@ class ControlNet(ControlBase):
|
|
|
|
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
|
|
|
|
|
|
|
|
|
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
|
|
|
|
y = cond.get('y', None)
|
|
|
|
|
if y is not None:
|
|
|
|
|
y = y.to(dtype)
|
|
|
|
|
extra = self.extra_args.copy()
|
|
|
|
|
for c in ["y", "guidance"]: #TODO
|
|
|
|
|
temp = cond.get(c, None)
|
|
|
|
|
if temp is not None:
|
|
|
|
|
extra[c] = temp.to(dtype)
|
|
|
|
|
|
|
|
|
|
timestep = self.model_sampling_current.timestep(t)
|
|
|
|
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
|
|
|
|
|
|
|
|
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
|
|
|
|
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
|
|
|
|
return self.control_merge(control, control_prev, output_dtype)
|
|
|
|
|
|
|
|
|
|
def copy(self):
|
|
|
|
@ -338,12 +341,8 @@ class ControlLora(ControlNet):
|
|
|
|
|
def inference_memory_requirements(self, dtype):
|
|
|
|
|
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
|
|
|
|
|
|
|
|
|
|
def load_controlnet_mmdit(sd):
|
|
|
|
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
|
|
|
|
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
|
|
|
|
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
|
|
|
|
for k in sd:
|
|
|
|
|
new_sd[k] = sd[k]
|
|
|
|
|
def controlnet_config(sd):
|
|
|
|
|
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
|
|
|
|
|
|
|
|
|
|
supported_inference_dtypes = model_config.supported_inference_dtypes
|
|
|
|
|
|
|
|
|
@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd):
|
|
|
|
|
else:
|
|
|
|
|
operations = comfy.ops.disable_weight_init
|
|
|
|
|
|
|
|
|
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
|
|
|
|
|
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
|
|
|
|
|
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
|
|
|
|
|
|
|
|
|
def controlnet_load_state_dict(control_model, sd):
|
|
|
|
|
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
|
|
|
|
|
|
|
|
|
if len(missing) > 0:
|
|
|
|
|
logging.warning("missing controlnet keys: {}".format(missing))
|
|
|
|
|
|
|
|
|
|
if len(unexpected) > 0:
|
|
|
|
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
|
|
|
|
return control_model
|
|
|
|
|
|
|
|
|
|
def load_controlnet_mmdit(sd):
|
|
|
|
|
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
|
|
|
|
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
|
|
|
|
for k in sd:
|
|
|
|
|
new_sd[k] = sd[k]
|
|
|
|
|
|
|
|
|
|
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
|
|
|
|
control_model = controlnet_load_state_dict(control_model, new_sd)
|
|
|
|
|
|
|
|
|
|
latent_format = comfy.latent_formats.SD3()
|
|
|
|
|
latent_format.shift_factor = 0 #SD3 controlnet weirdness
|
|
|
|
|