|
|
|
@ -23,8 +23,12 @@ class ControlNetFlux(Flux):
|
|
|
|
|
self.controlnet_blocks = nn.ModuleList([])
|
|
|
|
|
for _ in range(self.params.depth):
|
|
|
|
|
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
|
|
|
|
|
# controlnet_block = zero_module(controlnet_block)
|
|
|
|
|
self.controlnet_blocks.append(controlnet_block)
|
|
|
|
|
|
|
|
|
|
self.controlnet_single_blocks = nn.ModuleList([])
|
|
|
|
|
for _ in range(self.params.depth_single_blocks):
|
|
|
|
|
self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
|
|
|
|
|
|
|
|
|
|
self.gradient_checkpointing = False
|
|
|
|
|
self.latent_input = latent_input
|
|
|
|
|
self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
|
|
|
@ -78,26 +82,39 @@ class ControlNetFlux(Flux):
|
|
|
|
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
|
|
|
|
pe = self.pe_embedder(ids)
|
|
|
|
|
|
|
|
|
|
block_res_samples = ()
|
|
|
|
|
controlnet_double = ()
|
|
|
|
|
|
|
|
|
|
for i in range(len(self.double_blocks)):
|
|
|
|
|
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
|
|
|
|
|
controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
|
|
|
|
|
|
|
|
|
|
for block in self.double_blocks:
|
|
|
|
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
|
|
|
block_res_samples = block_res_samples + (img,)
|
|
|
|
|
img = torch.cat((txt, img), 1)
|
|
|
|
|
|
|
|
|
|
controlnet_block_res_samples = ()
|
|
|
|
|
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
|
|
|
|
|
block_res_sample = controlnet_block(block_res_sample)
|
|
|
|
|
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
|
|
|
|
controlnet_single = ()
|
|
|
|
|
|
|
|
|
|
for i in range(len(self.single_blocks)):
|
|
|
|
|
img = self.single_blocks[i](img, vec=vec, pe=pe)
|
|
|
|
|
controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
|
|
|
|
|
|
|
|
|
|
repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
|
|
|
|
|
repeat = math.ceil(self.main_model_double / len(controlnet_double))
|
|
|
|
|
if self.latent_input:
|
|
|
|
|
out_input = ()
|
|
|
|
|
for x in controlnet_block_res_samples:
|
|
|
|
|
for x in controlnet_double:
|
|
|
|
|
out_input += (x,) * repeat
|
|
|
|
|
else:
|
|
|
|
|
out_input = (controlnet_block_res_samples * repeat)
|
|
|
|
|
return {"input": out_input[:self.main_model_double]}
|
|
|
|
|
out_input = (controlnet_double * repeat)
|
|
|
|
|
|
|
|
|
|
out = {"input": out_input[:self.main_model_double]}
|
|
|
|
|
if len(controlnet_single) > 0:
|
|
|
|
|
repeat = math.ceil(self.main_model_single / len(controlnet_single))
|
|
|
|
|
out_output = ()
|
|
|
|
|
if self.latent_input:
|
|
|
|
|
for x in controlnet_single:
|
|
|
|
|
out_output += (x,) * repeat
|
|
|
|
|
else:
|
|
|
|
|
out_output = (controlnet_single * repeat)
|
|
|
|
|
out["output"] = out_output[:self.main_model_single]
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
|
|
|
|
|
patch_size = 2
|
|
|
|
|