From ae43f09ef72683c044a635cdedbe67329583e2bc Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 15 Jun 2023 18:42:30 -0400
Subject: [PATCH] All the unet weights should now be initialized with the right dtype.

---
 comfy/ldm/modules/attention.py               | 12 +++----
 .../modules/diffusionmodules/openaimodel.py  | 36 +++++++++++--------
 comfy/ldm/modules/diffusionmodules/util.py   |  4 +--
 3 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 62248f7..62707df 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -51,9 +51,9 @@ def init_(tensor):
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    def __init__(self, dim_in, dim_out, dtype=None):
         super().__init__()
-        self.proj = comfy.ops.Linear(dim_in, dim_out * 2)
+        self.proj = comfy.ops.Linear(dim_in, dim_out * 2, dtype=dtype)
 
     def forward(self, x):
         x, gate = self.proj(x).chunk(2, dim=-1)
@@ -68,7 +68,7 @@ class FeedForward(nn.Module):
         project_in = nn.Sequential(
             comfy.ops.Linear(dim, inner_dim, dtype=dtype),
             nn.GELU()
-        ) if not glu else GEGLU(dim, inner_dim)
+        ) if not glu else GEGLU(dim, inner_dim, dtype=dtype)
 
         self.net = nn.Sequential(
             project_in,
@@ -89,8 +89,8 @@ def zero_module(module):
     return module
 
 
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+def Normalize(in_channels, dtype=None):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype)
 
 
 class SpatialSelfAttention(nn.Module):
@@ -594,7 +594,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = Normalize(in_channels, dtype=dtype)
         if not use_linear:
             self.proj_in = nn.Conv2d(in_channels,
                                      inner_dim,
diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py
index 0307831..e170f67 100644
--- a/comfy/ldm/modules/diffusionmodules/openaimodel.py
+++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -111,14 +111,14 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
         self.dims = dims
         if use_conv:
-            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
 
     def forward(self, x, output_shape=None):
         assert x.shape[1] == self.channels
@@ -160,7 +160,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -169,7 +169,7 @@ class Downsample(nn.Module):
         stride = 2 if dims != 3 else (1, 2, 2)
         if use_conv:
             self.op = conv_nd(
-                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
             )
         else:
             assert self.channels == self.out_channels
@@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            normalization(channels),
+            normalization(channels, dtype=dtype),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
         )
@@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
         self.updown = up or down
 
         if up:
-            self.h_upd = Upsample(channels, False, dims)
-            self.x_upd = Upsample(channels, False, dims)
+            self.h_upd = Upsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Upsample(channels, False, dims, dtype=dtype)
         elif down:
-            self.h_upd = Downsample(channels, False, dims)
-            self.x_upd = Downsample(channels, False, dims)
+            self.h_upd = Downsample(channels, False, dims, dtype=dtype)
+            self.x_upd = Downsample(channels, False, dims, dtype=dtype)
         else:
             self.h_upd = self.x_upd = nn.Identity()
 
@@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
             nn.SiLU(),
             linear(
                 emb_channels,
-                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            normalization(self.out_channels, dtype=dtype),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
@@ -604,6 +604,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = mult * model_channels
@@ -651,10 +652,11 @@ class UNetModel(nn.Module):
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
                         else Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch
+                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
                         )
                     )
                 )
@@ -679,6 +681,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
             AttentionBlock(
                 ch,
@@ -698,6 +701,7 @@ class UNetModel(nn.Module):
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
+                dtype=self.dtype
             ),
         )
         self._feature_size += ch
@@ -715,6 +719,7 @@ class UNetModel(nn.Module):
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
+                        dtype=self.dtype
                     )
                 ]
                 ch = model_channels * mult
@@ -758,18 +763,19 @@ class UNetModel(nn.Module):
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
+                            dtype=self.dtype
                         )
                         if resblock_updown
-                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
                 self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch),
+            normalization(ch, dtype=self.dtype),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
        )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index d6a4778..d890c80 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -206,13 +206,13 @@ def mean_flat(tensor):
     return tensor.mean(dim=list(range(1, len(tensor.shape))))
 
 
-def normalization(channels):
+def normalization(channels, dtype=None):
     """
     Make a standard normalization layer.
     :param channels: number of input channels.
     :return: an nn.Module for normalization.
     """
-    return GroupNorm32(32, channels)
+    return GroupNorm32(32, channels, dtype=dtype)
 
 
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
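
Note (not part of the patch): the commit follows one pattern throughout. Each module constructor gains an optional dtype argument that defaults to None, so existing call sites keep PyTorch's default behavior, and the argument is forwarded to every parameter-creating layer so weights are allocated in the target dtype at construction time. The sketch below is a minimal, self-contained illustration of that pattern; ToyBlock and its layer names are hypothetical and do not appear in ComfyUI.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyBlock(nn.Module):
    """Hypothetical module illustrating the dtype-threading pattern."""
    def __init__(self, channels, dtype=None):
        super().__init__()
        # dtype=None keeps PyTorch's default (float32), so callers that
        # don't pass a dtype behave exactly as before the change.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels,
                                 eps=1e-6, affine=True, dtype=dtype)
        self.proj = nn.Linear(channels, channels * 2, dtype=dtype)

    def forward(self, x):
        # x: (batch, channels)
        x, gate = self.proj(self.norm(x)).chunk(2, dim=-1)
        return x * F.gelu(gate)

# Weights are created in the requested dtype up front, so no separate
# .half() / .to(dtype) pass over the finished module is required.
block = ToyBlock(64, dtype=torch.float16)
assert block.proj.weight.dtype == torch.float16
assert block.norm.weight.dtype == torch.float16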