|
|
|
@ -25,18 +25,19 @@ class Block(nn.Module):
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
return self.fuse(self.conv(x) + self.skip(x))
|
|
|
|
|
|
|
|
|
|
def Encoder():
|
|
|
|
|
def Encoder(latent_channels=4):
|
|
|
|
|
return nn.Sequential(
|
|
|
|
|
conv(3, 64), Block(64, 64),
|
|
|
|
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
|
|
|
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
|
|
|
|
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
|
|
|
|
conv(64, 4),
|
|
|
|
|
conv(64, latent_channels),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def Decoder():
|
|
|
|
|
|
|
|
|
|
def Decoder(latent_channels=4):
|
|
|
|
|
return nn.Sequential(
|
|
|
|
|
Clamp(), conv(4, 64), nn.ReLU(),
|
|
|
|
|
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
|
|
|
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
|
|
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
|
|
|
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
|
|
|
@ -47,11 +48,11 @@ class TAESD(nn.Module):
|
|
|
|
|
latent_magnitude = 3
|
|
|
|
|
latent_shift = 0.5
|
|
|
|
|
|
|
|
|
|
def __init__(self, encoder_path=None, decoder_path=None):
|
|
|
|
|
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
|
|
|
|
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.taesd_encoder = Encoder()
|
|
|
|
|
self.taesd_decoder = Decoder()
|
|
|
|
|
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
|
|
|
|
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
|
|
|
|
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
|
|
|
|
if encoder_path is not None:
|
|
|
|
|
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
|
|
|
|
|