From 50dc39d6ec5420f35b81f965c106b6710ff48e6e Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sun, 26 Nov 2023 03:13:56 -0500
Subject: [PATCH] Clean up the extra_options dict for the transformer patches.

Now everything in transformer_options gets put in extra_options.
---
 comfy/ldm/modules/attention.py                  | 31 ++++++-------------
 .../modules/diffusionmodules/openaimodel.py     | 11 ++++---
 2 files changed, 16 insertions(+), 26 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index d511dda..7dc1a1b 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -430,31 +430,20 @@ class BasicTransformerBlock(nn.Module):
         extra_options = {}
         block = None
         block_index = 0
-        if "current_index" in transformer_options:
-            extra_options["transformer_index"] = transformer_options["current_index"]
-        if "block_index" in transformer_options:
-            block_index = transformer_options["block_index"]
-            extra_options["block_index"] = block_index
-        if "original_shape" in transformer_options:
-            extra_options["original_shape"] = transformer_options["original_shape"]
-        if "block" in transformer_options:
-            block = transformer_options["block"]
-            extra_options["block"] = block
-        if "cond_or_uncond" in transformer_options:
-            extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"]
-        if "patches" in transformer_options:
-            transformer_patches = transformer_options["patches"]
-        else:
-            transformer_patches = {}
+        transformer_patches = {}
+        transformer_patches_replace = {}
+
+        for k in transformer_options:
+            if k == "patches":
+                transformer_patches = transformer_options[k]
+            elif k == "patches_replace":
+                transformer_patches_replace = transformer_options[k]
+            else:
+                extra_options[k] = transformer_options[k]
 
         extra_options["n_heads"] = self.n_heads
         extra_options["dim_head"] = self.d_head
 
-        if "patches_replace" in transformer_options:
-            transformer_patches_replace = transformer_options["patches_replace"]
-        else:
-            transformer_patches_replace = {}
-
         if self.ff_in:
             x_skip = x
             x = self.ff_in(self.norm_in(x))
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index a497ed3..4826489 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -31,7 +31,7 @@ class TimestepBlock(nn.Module):
         Apply the module to `x` given `emb` timestep embeddings.
""" -#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" +#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: if isinstance(layer, VideoResBlock): @@ -40,11 +40,12 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out x = layer(x, emb) elif isinstance(layer, SpatialVideoTransformer): x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) - transformer_options["current_index"] += 1 + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) - if "current_index" in transformer_options: - transformer_options["current_index"] += 1 + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: @@ -830,7 +831,7 @@ class UNetModel(nn.Module): :return: an [N x C x ...] Tensor of outputs. """ transformer_options["original_shape"] = list(x.shape) - transformer_options["current_index"] = 0 + transformer_options["transformer_index"] = 0 transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)