|
|
@ -251,6 +251,12 @@ class Timestep(nn.Module):
|
|
|
|
def forward(self, t):
|
|
|
|
def forward(self, t):
|
|
|
|
return timestep_embedding(t, self.dim)
|
|
|
|
return timestep_embedding(t, self.dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def apply_control(h, control, name):
|
|
|
|
|
|
|
|
if control is not None and name in control and len(control[name]) > 0:
|
|
|
|
|
|
|
|
ctrl = control[name].pop()
|
|
|
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
|
|
|
h += ctrl
|
|
|
|
|
|
|
|
return h
|
|
|
|
|
|
|
|
|
|
|
|
class UNetModel(nn.Module):
|
|
|
|
class UNetModel(nn.Module):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -617,25 +623,17 @@ class UNetModel(nn.Module):
|
|
|
|
for id, module in enumerate(self.input_blocks):
|
|
|
|
for id, module in enumerate(self.input_blocks):
|
|
|
|
transformer_options["block"] = ("input", id)
|
|
|
|
transformer_options["block"] = ("input", id)
|
|
|
|
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
|
|
|
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
|
|
|
if control is not None and 'input' in control and len(control['input']) > 0:
|
|
|
|
h = apply_control(h, control, 'input')
|
|
|
|
ctrl = control['input'].pop()
|
|
|
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
|
|
|
h += ctrl
|
|
|
|
|
|
|
|
hs.append(h)
|
|
|
|
hs.append(h)
|
|
|
|
|
|
|
|
|
|
|
|
transformer_options["block"] = ("middle", 0)
|
|
|
|
transformer_options["block"] = ("middle", 0)
|
|
|
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
|
|
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
|
|
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
|
|
|
h = apply_control(h, control, 'middle')
|
|
|
|
ctrl = control['middle'].pop()
|
|
|
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
|
|
|
h += ctrl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for id, module in enumerate(self.output_blocks):
|
|
|
|
for id, module in enumerate(self.output_blocks):
|
|
|
|
transformer_options["block"] = ("output", id)
|
|
|
|
transformer_options["block"] = ("output", id)
|
|
|
|
hsp = hs.pop()
|
|
|
|
hsp = hs.pop()
|
|
|
|
if control is not None and 'output' in control and len(control['output']) > 0:
|
|
|
|
h = apply_control(h, control, 'output')
|
|
|
|
ctrl = control['output'].pop()
|
|
|
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
|
|
|
hsp += ctrl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "output_block_patch" in transformer_patches:
|
|
|
|
if "output_block_patch" in transformer_patches:
|
|
|
|
patch = transformer_patches["output_block_patch"]
|
|
|
|
patch = transformer_patches["output_block_patch"]
|
|
|
|