|
|
|
@ -208,6 +208,7 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
use_checkpoint=False,
|
|
|
|
|
up=False,
|
|
|
|
|
down=False,
|
|
|
|
|
dtype=None
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.channels = channels
|
|
|
|
@ -221,7 +222,7 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
self.in_layers = nn.Sequential(
|
|
|
|
|
normalization(channels),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
|
|
|
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.updown = up or down
|
|
|
|
@ -247,7 +248,7 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
nn.Dropout(p=dropout),
|
|
|
|
|
zero_module(
|
|
|
|
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
|
|
|
|
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype)
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -255,10 +256,10 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
self.skip_connection = nn.Identity()
|
|
|
|
|
elif use_conv:
|
|
|
|
|
self.skip_connection = conv_nd(
|
|
|
|
|
dims, channels, self.out_channels, 3, padding=1
|
|
|
|
|
dims, channels, self.out_channels, 3, padding=1, dtype=dtype
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
|
|
|
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
def forward(self, x, emb):
|
|
|
|
|
"""
|
|
|
|
@ -558,9 +559,9 @@ class UNetModel(nn.Module):
|
|
|
|
|
|
|
|
|
|
time_embed_dim = model_channels * 4
|
|
|
|
|
self.time_embed = nn.Sequential(
|
|
|
|
|
linear(model_channels, time_embed_dim),
|
|
|
|
|
linear(model_channels, time_embed_dim, dtype=self.dtype),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
linear(time_embed_dim, time_embed_dim),
|
|
|
|
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if self.num_classes is not None:
|
|
|
|
@ -573,9 +574,9 @@ class UNetModel(nn.Module):
|
|
|
|
|
assert adm_in_channels is not None
|
|
|
|
|
self.label_emb = nn.Sequential(
|
|
|
|
|
nn.Sequential(
|
|
|
|
|
linear(adm_in_channels, time_embed_dim),
|
|
|
|
|
linear(adm_in_channels, time_embed_dim, dtype=self.dtype),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
linear(time_embed_dim, time_embed_dim),
|
|
|
|
|
linear(time_embed_dim, time_embed_dim, dtype=self.dtype),
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
@ -584,7 +585,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
self.input_blocks = nn.ModuleList(
|
|
|
|
|
[
|
|
|
|
|
TimestepEmbedSequential(
|
|
|
|
|
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
|
|
|
|
conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype)
|
|
|
|
|
)
|
|
|
|
|
]
|
|
|
|
|
)
|
|
|
|
|