|
|
|
@ -14,7 +14,7 @@ import comfy.ldm.common_dit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNetFlux(Flux):
|
|
|
|
|
def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
|
|
|
|
def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
|
|
|
|
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
|
|
|
|
|
|
|
|
|
|
self.main_model_double = 19
|
|
|
|
@ -29,6 +29,11 @@ class ControlNetFlux(Flux):
|
|
|
|
|
for _ in range(self.params.depth_single_blocks):
|
|
|
|
|
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
|
|
|
|
|
|
|
|
|
self.num_union_modes = num_union_modes
|
|
|
|
|
self.controlnet_mode_embedder = None
|
|
|
|
|
if self.num_union_modes > 0:
|
|
|
|
|
self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
|
|
|
|
|
|
|
|
|
|
self.gradient_checkpointing = False
|
|
|
|
|
self.latent_input = latent_input
|
|
|
|
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
|
|
|
@ -61,6 +66,7 @@ class ControlNetFlux(Flux):
|
|
|
|
|
timesteps: Tensor,
|
|
|
|
|
y: Tensor,
|
|
|
|
|
guidance: Tensor = None,
|
|
|
|
|
control_type: Tensor = None,
|
|
|
|
|
) -> Tensor:
|
|
|
|
|
if img.ndim != 3 or txt.ndim != 3:
|
|
|
|
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
|
|
|
@ -79,6 +85,11 @@ class ControlNetFlux(Flux):
|
|
|
|
|
vec = vec + self.vector_in(y)
|
|
|
|
|
txt = self.txt_in(txt)
|
|
|
|
|
|
|
|
|
|
if self.controlnet_mode_embedder is not None and len(control_type) > 0:
|
|
|
|
|
control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
|
|
|
|
|
txt = torch.cat([control_cond, txt], dim=1)
|
|
|
|
|
txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
|
|
|
|
|
|
|
|
|
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
|
|
|
pe = self.pe_embedder(ids)
|
|
|
|
|
|
|
|
|
@ -137,4 +148,4 @@ class ControlNetFlux(Flux):
|
|
|
|
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
|
|
|
|
|
|
|
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
|
|
|
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
|
|
|
|
|
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
|
|
|
|
|