|
|
@ -809,7 +809,7 @@ class UNetModel(nn.Module):
|
|
|
|
self.out = nn.Sequential(
|
|
|
|
self.out = nn.Sequential(
|
|
|
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
|
|
|
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
|
|
|
|
nn.SiLU(),
|
|
|
|
nn.SiLU(),
|
|
|
|
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
|
|
|
operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if self.predict_codebook_ids:
|
|
|
|
if self.predict_codebook_ids:
|
|
|
|
self.id_predictor = nn.Sequential(
|
|
|
|
self.id_predictor = nn.Sequential(
|
|
|
|