|
|
|
@ -31,7 +31,7 @@ class TimestepBlock(nn.Module):
|
|
|
|
|
Apply the module to `x` given `emb` timestep embeddings.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
|
|
|
|
|
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
|
|
|
|
|
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
|
|
|
|
|
for layer in ts:
|
|
|
|
|
if isinstance(layer, VideoResBlock):
|
|
|
|
@ -40,11 +40,12 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
|
|
|
|
|
x = layer(x, emb)
|
|
|
|
|
elif isinstance(layer, SpatialVideoTransformer):
|
|
|
|
|
x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
|
|
|
|
|
transformer_options["current_index"] += 1
|
|
|
|
|
if "transformer_index" in transformer_options:
|
|
|
|
|
transformer_options["transformer_index"] += 1
|
|
|
|
|
elif isinstance(layer, SpatialTransformer):
|
|
|
|
|
x = layer(x, context, transformer_options)
|
|
|
|
|
if "current_index" in transformer_options:
|
|
|
|
|
transformer_options["current_index"] += 1
|
|
|
|
|
if "transformer_index" in transformer_options:
|
|
|
|
|
transformer_options["transformer_index"] += 1
|
|
|
|
|
elif isinstance(layer, Upsample):
|
|
|
|
|
x = layer(x, output_shape=output_shape)
|
|
|
|
|
else:
|
|
|
|
@ -830,7 +831,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
:return: an [N x C x ...] Tensor of outputs.
|
|
|
|
|
"""
|
|
|
|
|
transformer_options["original_shape"] = list(x.shape)
|
|
|
|
|
transformer_options["current_index"] = 0
|
|
|
|
|
transformer_options["transformer_index"] = 0
|
|
|
|
|
transformer_patches = transformer_options.get("patches", {})
|
|
|
|
|
|
|
|
|
|
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
|
|
|
|
|