|
|
|
@ -774,17 +774,23 @@ class UNetModel(nn.Module):
|
|
|
|
|
emb = emb + self.label_emb(y)
|
|
|
|
|
|
|
|
|
|
h = x.type(self.dtype)
|
|
|
|
|
for module in self.input_blocks:
|
|
|
|
|
for id, module in enumerate(self.input_blocks):
|
|
|
|
|
h = module(h, emb, context)
|
|
|
|
|
if control is not None and 'input' in control and len(control['input']) > 0:
|
|
|
|
|
ctrl = control['input'].pop()
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
h += ctrl
|
|
|
|
|
hs.append(h)
|
|
|
|
|
h = self.middle_block(h, emb, context)
|
|
|
|
|
if control is not None:
|
|
|
|
|
h += control.pop()
|
|
|
|
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
|
|
|
|
h += control['middle'].pop()
|
|
|
|
|
|
|
|
|
|
for module in self.output_blocks:
|
|
|
|
|
hsp = hs.pop()
|
|
|
|
|
if control is not None:
|
|
|
|
|
hsp += control.pop()
|
|
|
|
|
if control is not None and 'output' in control and len(control['output']) > 0:
|
|
|
|
|
ctrl = control['output'].pop()
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
hsp += ctrl
|
|
|
|
|
h = th.cat([h, hsp], dim=1)
|
|
|
|
|
del hsp
|
|
|
|
|
h = module(h, emb, context)
|
|
|
|
|