diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 7b202b7..d2d2cef 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -444,7 +444,12 @@ def load_controlnet_flux_instantx(sd): for k in sd: new_sd[k] = sd[k] - control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + num_union_modes = 0 + union_cnet = "controlnet_mode_embedder.weight" + if union_cnet in new_sd: + num_union_modes = new_sd[union_cnet].shape[0] + + control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = comfy.latent_formats.Flux() diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index 2c658a4..2598e71 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -14,7 +14,7 @@ import comfy.ldm.common_dit class ControlNetFlux(Flux): - def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs): + def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs): super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) self.main_model_double = 19 @@ -29,6 +29,11 @@ class ControlNetFlux(Flux): for _ in range(self.params.depth_single_blocks): self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)) + self.num_union_modes = num_union_modes + self.controlnet_mode_embedder = None + if self.num_union_modes > 0: + self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device) + self.gradient_checkpointing = False self.latent_input = latent_input self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) @@ -61,6 +66,7 @@ class ControlNetFlux(Flux): timesteps: Tensor, y: Tensor, guidance: Tensor = None, + control_type: Tensor = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -79,6 +85,11 @@ class ControlNetFlux(Flux): vec = vec + self.vector_in(y) txt = self.txt_in(txt) + if self.controlnet_mode_embedder is not None and len(control_type) > 0: + control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1)) + txt = torch.cat([control_cond, txt], dim=1) + txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) @@ -137,4 +148,4 @@ class ControlNetFlux(Flux): img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))