Add Flux model support for InstantX style controlnet residuals (#4444)

* Add Flux model support for InstantX style controlnet residuals

* Refactor Flux controlnet residual step to a separate method

* Rollback minor change

* New format for applying controlnet residuals: input->double_blocks, output->single_blocks

* Adjust XLabs Flux controlnet to fit new syntax of applying Flux controlnet residuals

* Remove unnecessary import and minor style change
main
Xrvk 6 months ago committed by GitHub
parent 310ad09258
commit e68763f40c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -82,7 +82,7 @@ class ControlNetFlux(Flux):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
return {"output": (controlnet_block_res_samples * 10)[:19]}
return {"input": (controlnet_block_res_samples * 10)[:19]}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
hint = hint * 2.0 - 1.0

@ -114,19 +114,28 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for i in range(len(self.double_blocks)):
img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
for i, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if control is not None: #Controlnet
if control is not None: # Controlnet
control_i = control.get("input")
if i < len(control_i):
add = control_i[i]
if add is not None:
img += add
img = torch.cat((txt, img), 1)
for i, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if control is not None: # Controlnet
control_o = control.get("output")
if i < len(control_o):
add = control_o[i]
if add is not None:
img += add
img[:, txt.shape[1] :, ...] += add
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

Loading…
Cancel
Save