From c19dcd362f5e32ce4800e600b91d09c89b19ab4f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 7 Aug 2024 12:59:28 -0400 Subject: [PATCH] Controlnet code refactor. --- comfy/controlnet.py | 36 ++++++++++++++++++++++++------------ comfy/model_detection.py | 4 ++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 12e5f16..97e4f4d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -191,13 +191,16 @@ class ControlNet(ControlBase): self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) context = cond.get('crossattn_controlnet', cond['c_crossattn']) - y = cond.get('y', None) - if y is not None: - y = y.to(dtype) + extra = self.extra_args.copy() + for c in ["y", "guidance"]: #TODO + temp = cond.get(c, None) + if temp is not None: + extra[c] = temp.to(dtype) + timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) return self.control_merge(control, control_prev, output_dtype) def copy(self): @@ -338,12 +341,8 @@ class ControlLora(ControlNet): def inference_memory_requirements(self, dtype): return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype) -def load_controlnet_mmdit(sd): - new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") - model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True) - num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') - for k in sd: - new_sd[k] = sd[k] +def controlnet_config(sd): + model_config = comfy.model_detection.model_config_from_unet(sd, "", True) supported_inference_dtypes = model_config.supported_inference_dtypes @@ -356,14 +355,27 @@ def load_controlnet_mmdit(sd): else: operations = comfy.ops.disable_weight_init - control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config) - missing, unexpected = control_model.load_state_dict(new_sd, strict=False) + return model_config, operations, load_device, unet_dtype, manual_cast_dtype + +def controlnet_load_state_dict(control_model, sd): + missing, unexpected = control_model.load_state_dict(sd, strict=False) if len(missing) > 0: logging.warning("missing controlnet keys: {}".format(missing)) if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) + return control_model + +def load_controlnet_mmdit(sd): + new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "") + model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd) + num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.') + for k in sd: + new_sd[k] = sd[k] + + control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.SD3() latent_format.shift_factor = 0 #SD3 controlnet weirdness diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c471196..15e6b73 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix): dit_config["hidden_size"] = 3072 dit_config["mlp_ratio"] = 4.0 dit_config["num_heads"] = 24 - dit_config["depth"] = 19 - dit_config["depth_single_blocks"] = 38 + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') + dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 10000 dit_config["qkv_bias"] = True