|
|
|
@ -163,11 +163,9 @@ class ResBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StageA(nn.Module):
|
|
|
|
|
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192,
|
|
|
|
|
scale_factor=0.43): # 0.3764
|
|
|
|
|
def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.c_latent = c_latent
|
|
|
|
|
self.scale_factor = scale_factor
|
|
|
|
|
c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
|
|
|
|
|
|
|
|
|
|
# Encoder blocks
|
|
|
|
@ -214,12 +212,11 @@ class StageA(nn.Module):
|
|
|
|
|
x = self.down_blocks(x)
|
|
|
|
|
if quantize:
|
|
|
|
|
qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
|
|
|
|
|
return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25
|
|
|
|
|
return qe, x, indices, vq_loss + commit_loss * 0.25
|
|
|
|
|
else:
|
|
|
|
|
return x / self.scale_factor
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def decode(self, x):
|
|
|
|
|
x = x * self.scale_factor
|
|
|
|
|
x = self.up_blocks(x)
|
|
|
|
|
x = self.out_block(x)
|
|
|
|
|
return x
|
|
|
|
|