|
|
|
@ -514,11 +514,14 @@ class VAE:
|
|
|
|
|
self.device = device
|
|
|
|
|
|
|
|
|
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
|
|
|
|
steps = samples.shape[0] * utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
|
|
|
|
pbar = utils.ProgressBar(steps)
|
|
|
|
|
|
|
|
|
|
decode_fn = lambda a: (self.first_stage_model.decode(1. / self.scale_factor * a.to(self.device)) + 1.0)
|
|
|
|
|
output = torch.clamp((
|
|
|
|
|
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8) +
|
|
|
|
|
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8) +
|
|
|
|
|
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8))
|
|
|
|
|
(utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
|
|
|
|
|
utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
|
|
|
|
|
utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
|
|
|
|
|
/ 3.0) / 2.0, min=0.0, max=1.0)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
@ -562,9 +565,13 @@ class VAE:
|
|
|
|
|
model_management.unload_model()
|
|
|
|
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
|
|
|
|
pixel_samples = pixel_samples.movedim(-1,1).to(self.device)
|
|
|
|
|
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4)
|
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4)
|
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4)
|
|
|
|
|
|
|
|
|
|
steps = utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
|
|
|
|
pbar = utils.ProgressBar(steps)
|
|
|
|
|
|
|
|
|
|
samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
|
|
|
samples += utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
|
|
|
|
|
samples /= 3.0
|
|
|
|
|
self.first_stage_model = self.first_stage_model.cpu()
|
|
|
|
|
samples = samples.cpu()
|
|
|
|
|