Initialize more unet weights as the right dtype.

main
comfyanonymous 2 years ago
parent e21d9ad445
commit 7bf89ba923

@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
use_checkpoint=False,
up=False,
down=False,
dtype=None
):
super().__init__()
self.channels = channels
@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
)
self.updown = up or down
@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
),
)
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
def forward(self, x, emb):
"""
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
linear(model_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
)
if self.num_classes is not None:
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
assert adm_in_channels is not None
self.label_emb = nn.Sequential(
nn.Sequential(
linear(adm_in_channels, time_embed_dim),
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
)
)
else:
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
)
]
)

Loading…
Cancel
Save