diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 9054a1c..9a652c2 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -774,17 +774,23 @@ class UNetModel(nn.Module): emb = emb + self.label_emb(y) h = x.type(self.dtype) - for module in self.input_blocks: + for id, module in enumerate(self.input_blocks): h = module(h, emb, context) + if control is not None and 'input' in control and len(control['input']) > 0: + ctrl = control['input'].pop() + if ctrl is not None: + h += ctrl hs.append(h) h = self.middle_block(h, emb, context) - if control is not None: - h += control.pop() + if control is not None and 'middle' in control and len(control['middle']) > 0: + h += control['middle'].pop() for module in self.output_blocks: hsp = hs.pop() - if control is not None: - hsp += control.pop() + if control is not None and 'output' in control and len(control['output']) > 0: + ctrl = control['output'].pop() + if ctrl is not None: + hsp += ctrl h = th.cat([h, hsp], dim=1) del hsp h = module(h, emb, context) diff --git a/comfy/sd.py b/comfy/sd.py index 39f88fa..1113677 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -388,18 +388,28 @@ class ControlNet: self.control_model = model_management.load_if_low_vram(self.control_model) control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) self.control_model = model_management.unload_if_low_vram(self.control_model) - out = [] + out = {'input':[], 'middle':[], 'output': []} autocast_enabled = torch.is_autocast_enabled() for i in range(len(control)): + if i == (len(control) - 1): + key = 'middle' + index = 0 + else: + key = 'output' + index = i x = control[i] x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) - if control_prev is not None: - x += control_prev[i] - out.append(x) + if control_prev is not None and key in control_prev: + prev = control_prev[key][index] + if prev is not None: + x += prev + out[key].append(x) + if control_prev is not None and 'input' in control_prev: + out['input'] = control_prev['input'] return out def set_cond_hint(self, cond_hint, strength=1.0):