|
|
|
@ -608,6 +608,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
transformer_options["original_shape"] = list(x.shape)
|
|
|
|
|
transformer_options["current_index"] = 0
|
|
|
|
|
transformer_patches = transformer_options.get("patches", {})
|
|
|
|
|
|
|
|
|
|
assert (y is not None) == (
|
|
|
|
|
self.num_classes is not None
|
|
|
|
@ -644,6 +645,11 @@ class UNetModel(nn.Module):
|
|
|
|
|
if ctrl is not None:
|
|
|
|
|
hsp += ctrl
|
|
|
|
|
|
|
|
|
|
if "output_block_patch" in transformer_patches:
|
|
|
|
|
patch = transformer_patches["output_block_patch"]
|
|
|
|
|
for p in patch:
|
|
|
|
|
h, hsp = p(h, hsp, transformer_options)
|
|
|
|
|
|
|
|
|
|
h = th.cat([h, hsp], dim=1)
|
|
|
|
|
del hsp
|
|
|
|
|
if len(hs) > 0:
|
|
|
|
|