|
|
|
@ -115,14 +115,14 @@ class Flux(nn.Module):
|
|
|
|
|
pe = self.pe_embedder(ids)
|
|
|
|
|
|
|
|
|
|
for i, block in enumerate(self.double_blocks):
|
|
|
|
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
|
|
|
|
|
|
|
|
if control is not None: # Controlnet
|
|
|
|
|
control_i = control.get("input")
|
|
|
|
|
if i < len(control_i):
|
|
|
|
|
add = control_i[i]
|
|
|
|
|
if add is not None:
|
|
|
|
|
img += add
|
|
|
|
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
|
|
|
|
|
|
|
|
|
if control is not None: # Controlnet
|
|
|
|
|
control_i = control.get("input")
|
|
|
|
|
if i < len(control_i):
|
|
|
|
|
add = control_i[i]
|
|
|
|
|
if add is not None:
|
|
|
|
|
img += add
|
|
|
|
|
|
|
|
|
|
img = torch.cat((txt, img), 1)
|
|
|
|
|
|
|
|
|
|