|
|
|
@ -111,14 +111,14 @@ class Upsample(nn.Module):
|
|
|
|
|
upsampling occurs in the inner-two dimensions.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.channels = channels
|
|
|
|
|
self.out_channels = out_channels or channels
|
|
|
|
|
self.use_conv = use_conv
|
|
|
|
|
self.dims = dims
|
|
|
|
|
if use_conv:
|
|
|
|
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
|
|
|
|
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
def forward(self, x, output_shape=None):
|
|
|
|
|
assert x.shape[1] == self.channels
|
|
|
|
@ -160,7 +160,7 @@ class Downsample(nn.Module):
|
|
|
|
|
downsampling occurs in the inner-two dimensions.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
|
|
|
|
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.channels = channels
|
|
|
|
|
self.out_channels = out_channels or channels
|
|
|
|
@ -169,7 +169,7 @@ class Downsample(nn.Module):
|
|
|
|
|
stride = 2 if dims != 3 else (1, 2, 2)
|
|
|
|
|
if use_conv:
|
|
|
|
|
self.op = conv_nd(
|
|
|
|
|
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
|
|
|
|
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
assert self.channels == self.out_channels
|
|
|
|
@ -220,7 +220,7 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
self.use_scale_shift_norm = use_scale_shift_norm
|
|
|
|
|
|
|
|
|
|
self.in_layers = nn.Sequential(
|
|
|
|
|
normalization(channels),
|
|
|
|
|
normalization(channels, dtype=dtype),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype),
|
|
|
|
|
)
|
|
|
|
@ -228,11 +228,11 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
self.updown = up or down
|
|
|
|
|
|
|
|
|
|
if up:
|
|
|
|
|
self.h_upd = Upsample(channels, False, dims)
|
|
|
|
|
self.x_upd = Upsample(channels, False, dims)
|
|
|
|
|
self.h_upd = Upsample(channels, False, dims, dtype=dtype)
|
|
|
|
|
self.x_upd = Upsample(channels, False, dims, dtype=dtype)
|
|
|
|
|
elif down:
|
|
|
|
|
self.h_upd = Downsample(channels, False, dims)
|
|
|
|
|
self.x_upd = Downsample(channels, False, dims)
|
|
|
|
|
self.h_upd = Downsample(channels, False, dims, dtype=dtype)
|
|
|
|
|
self.x_upd = Downsample(channels, False, dims, dtype=dtype)
|
|
|
|
|
else:
|
|
|
|
|
self.h_upd = self.x_upd = nn.Identity()
|
|
|
|
|
|
|
|
|
@ -240,11 +240,11 @@ class ResBlock(TimestepBlock):
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
linear(
|
|
|
|
|
emb_channels,
|
|
|
|
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
|
|
|
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
self.out_layers = nn.Sequential(
|
|
|
|
|
normalization(self.out_channels),
|
|
|
|
|
normalization(self.out_channels, dtype=dtype),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
nn.Dropout(p=dropout),
|
|
|
|
|
zero_module(
|
|
|
|
@ -604,6 +604,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
dims=dims,
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
)
|
|
|
|
|
]
|
|
|
|
|
ch = mult * model_channels
|
|
|
|
@ -651,10 +652,11 @@ class UNetModel(nn.Module):
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
down=True,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
)
|
|
|
|
|
if resblock_updown
|
|
|
|
|
else Downsample(
|
|
|
|
|
ch, conv_resample, dims=dims, out_channels=out_ch
|
|
|
|
|
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
@ -679,6 +681,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
dims=dims,
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
),
|
|
|
|
|
AttentionBlock(
|
|
|
|
|
ch,
|
|
|
|
@ -698,6 +701,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
dims=dims,
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
self._feature_size += ch
|
|
|
|
@ -715,6 +719,7 @@ class UNetModel(nn.Module):
|
|
|
|
|
dims=dims,
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
)
|
|
|
|
|
]
|
|
|
|
|
ch = model_channels * mult
|
|
|
|
@ -758,18 +763,19 @@ class UNetModel(nn.Module):
|
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
|
use_scale_shift_norm=use_scale_shift_norm,
|
|
|
|
|
up=True,
|
|
|
|
|
dtype=self.dtype
|
|
|
|
|
)
|
|
|
|
|
if resblock_updown
|
|
|
|
|
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
|
|
|
|
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype)
|
|
|
|
|
)
|
|
|
|
|
ds //= 2
|
|
|
|
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
|
|
|
|
self._feature_size += ch
|
|
|
|
|
|
|
|
|
|
self.out = nn.Sequential(
|
|
|
|
|
normalization(ch),
|
|
|
|
|
normalization(ch, dtype=self.dtype),
|
|
|
|
|
nn.SiLU(),
|
|
|
|
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
|
|
|
|
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype)),
|
|
|
|
|
)
|
|
|
|
|
if self.predict_codebook_ids:
|
|
|
|
|
self.id_predictor = nn.Sequential(
|
|
|
|
|