diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index fbb5832..0307831 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -208,6 +208,7 @@ class ResBlock(TimestepBlock): use_checkpoint=False, up=False, down=False, + dtype=None ): super().__init__() self.channels = channels @@ -221,7 +222,7 @@ class ResBlock(TimestepBlock): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), + conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype), ) self.updown = up or down @@ -247,7 +248,7 @@ class ResBlock(TimestepBlock): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype) ), ) @@ -255,10 +256,10 @@ class ResBlock(TimestepBlock): self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 + dims, channels, self.out_channels, 3, padding=1, dtype=dtype ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype) def forward(self, x, emb): """ @@ -558,9 +559,9 @@ class UNetModel(nn.Module): time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), + linear(model_channels, time_embed_dim, dtype=self.dtype), nn.SiLU(), - linear(time_embed_dim, time_embed_dim), + linear(time_embed_dim, time_embed_dim, dtype=self.dtype), ) if self.num_classes is not None: @@ -573,9 +574,9 @@ class UNetModel(nn.Module): assert adm_in_channels is not None self.label_emb = nn.Sequential( nn.Sequential( - linear(adm_in_channels, time_embed_dim), + linear(adm_in_channels, time_embed_dim, dtype=self.dtype), nn.SiLU(), - linear(time_embed_dim, time_embed_dim), + linear(time_embed_dim, time_embed_dim, dtype=self.dtype), ) ) else: @@ -584,7 +585,7 @@ class UNetModel(nn.Module): self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) + conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype) ) ] )