@@ -4,6 +4,7 @@ from enum import Enum
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
 from .ldm.cascade.stage_a import StageA
+from .ldm.cascade.stage_c_coder import StageC_coder
 import yaml
@@ -158,6 +159,7 @@ class VAE:
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
         self.downscale_ratio = 8
+        self.upscale_ratio = 8
         self.latent_channels = 4
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
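
A note on the two defaults above: `process_input` maps [0, 1] pixels into the [-1, 1] range the autoencoder expects, and `process_output` is its clamped inverse. A minimal round-trip check of just those two lambdas:

```python
import torch

image = torch.rand(1, 512, 512, 3)                          # pixels in [0, 1]
x = image * 2.0 - 1.0                                       # process_input
restored = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)   # process_output
assert torch.allclose(image, restored, atol=1e-6)
```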
@@ -176,11 +178,31 @@ class VAE:
         elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
             self.first_stage_model = StageA()
             self.downscale_ratio = 4
+            self.upscale_ratio = 4
             #TODO
             #self.memory_used_encode
             #self.memory_used_decode
             self.process_input = lambda image: image
             self.process_output = lambda image: image
+        elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade
+            self.first_stage_model = StageC_coder()
+            self.downscale_ratio = 32
+            self.latent_channels = 16
+            new_sd = {}
+            for k in sd:
+                new_sd["encoder.{}".format(k)] = sd[k]
+            sd = new_sd
+        elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade
+            self.first_stage_model = StageC_coder()
+            self.latent_channels = 16
+            new_sd = {}
+            for k in sd:
+                new_sd["previewer.{}".format(k)] = sd[k]
+            sd = new_sd
+        elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade
+            self.first_stage_model = StageC_coder()
+            self.downscale_ratio = 32
+            self.latent_channels = 16
         else:
             #default SD1.x/SD2.x VAE parameters
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
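
Since the effnet and previewer checkpoints are flat single-purpose state dicts but both load into the combined `StageC_coder` module, each key gets rewritten under an `encoder.` or `previewer.` prefix first, which is what the two loops above do. A standalone sketch of that remapping (the helper name is ours, not ComfyUI's):

```python
def prefix_state_dict(sd, prefix):
    # Rewrite flat checkpoint keys onto the matching StageC_coder submodule,
    # e.g. "blocks.0.weight" -> "previewer.blocks.0.weight".
    return {"{}.{}".format(prefix, k): v for k, v in sd.items()}

effnet_sd = {"backbone.1.0.block.0.1.weight": None}  # toy stand-in for a tensor
remapped = prefix_state_dict(effnet_sd, "encoder")
assert "encoder.backbone.1.0.block.0.1.weight" in remapped
```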
@@ -188,6 +210,7 @@ class VAE:
             if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE
                 ddconfig['ch_mult'] = [1, 2, 4]
                 self.downscale_ratio = 4
+                self.upscale_ratio = 4
 
             self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
         else:
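
The `ch_mult` tweak and the ratio of 4 go together: an AutoencoderKL gains one 2x resampling stage per channel-multiplier entry after the first, so the default `[1, 2, 4, 4]` gives 2^3 = 8 while the x4 upscaler's `[1, 2, 4]` gives 2^2 = 4. A quick sanity check (helper name is ours):

```python
def scale_from_ch_mult(ch_mult):
    # One 2x down/upsample stage per entry after the first.
    return 2 ** (len(ch_mult) - 1)

assert scale_from_ch_mult([1, 2, 4, 4]) == 8  # default SD1.x/SD2.x VAE
assert scale_from_ch_mult([1, 2, 4]) == 4     # SD x4 upscaler VAE
```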
@@ -213,6 +236,15 @@ class VAE:
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
 
+    def vae_encode_crop_pixels(self, pixels):
+        x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio
+        y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio
+        if pixels.shape[1] != x or pixels.shape[2] != y:
+            x_offset = (pixels.shape[1] % self.downscale_ratio) // 2
+            y_offset = (pixels.shape[2] % self.downscale_ratio) // 2
+            pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
+        return pixels
+
     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
         steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
         steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
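
The new `vae_encode_crop_pixels` centers a crop so both spatial dimensions of an NHWC batch become multiples of `downscale_ratio`. Worked through for a 517x513 input with a ratio of 8:

```python
import torch

pixels = torch.rand(1, 517, 513, 3)   # NHWC, neither spatial dim divisible by 8
x = (517 // 8) * 8                    # 512
y = (513 // 8) * 8                    # 512
x_offset = (517 % 8) // 2             # 5 // 2 = 2 -> trim 2 rows on top, 3 below
y_offset = (513 % 8) // 2             # 1 // 2 = 0 -> trim the last column only
cropped = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
assert cropped.shape == (1, 512, 512, 3)
```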
@@ -221,9 +253,9 @@ class VAE:
         pbar = comfy.utils.ProgressBar(steps)
 
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
         output = self.process_output(
-            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar) +
-            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.downscale_ratio, output_device=self.output_device, pbar = pbar))
+            (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
+            comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar))
             / 3.0)
         return output
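
This hunk is the reason `upscale_ratio` now exists as a separate attribute: `tiled_scale` sizes its stitched output as the input times `upscale_amount`, and with stage C the two ratios genuinely differ (the branches above set `downscale_ratio` to 32 but leave `upscale_ratio` at its default of 8), so the decode path can no longer reuse `downscale_ratio`. Averaging three differently-shaped tilings also hides seams, since tile boundaries land in different places on each pass. A shape-only sketch with a stand-in decoder:

```python
import torch

def fake_decode(latent, upscale_ratio=8):
    # Stand-in for a tiled decode pass: shape-only, 8x spatial upscale.
    n, c, h, w = latent.shape
    return torch.rand(n, 3, h * upscale_ratio, w * upscale_ratio)

latent = torch.rand(1, 16, 24, 24)                # stage C latent, 16 channels
passes = [fake_decode(latent) for _ in range(3)]  # three tile layouts in the real code
output = sum(passes) / 3.0
assert output.shape == (1, 3, 192, 192)           # 24 * 8 (upscale), not 24 * 32
```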
@@ -248,7 +280,7 @@ class VAE:
             batch_number = int(free_memory / memory_used)
             batch_number = max(1, batch_number)
 
-            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device)
             for x in range(0, samples_in.shape[0], batch_number):
                 samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                 pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
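
For context on the surrounding batching logic: `batch_number` is however many latents fit in the device's free memory according to `memory_used_decode`, floored and clamped to at least 1. Rough numbers for a 64x64 SD latent decoded in fp16 (the 8 GiB figure is an assumption):

```python
memory_used = (2178 * 64 * 64 * 64) * 2   # memory_used_decode, fp16: ~1.06 GiB per latent
free_memory = 8 * 1024**3                 # assume 8 GiB free on the device
batch_number = max(1, int(free_memory / memory_used))
assert batch_number == 7                  # decode 7 latents per forward pass
```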
@@ -265,6 +297,7 @@ class VAE:
         return output.movedim(1,-1)
 
     def encode(self, pixel_samples):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1,1)
         try:
             memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
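
The encode path makes the symmetric estimate, but on pixel dimensions: after `movedim(-1, 1)` the batch is NCHW, so `shape[2] * shape[3]` is height times width. With the same assumed 8 GiB free, for 512x512 fp16 encodes:

```python
memory_used = (1767 * 512 * 512) * 2      # memory_used_encode, fp16: ~0.86 GiB per image
free_memory = 8 * 1024**3
batch_number = max(1, int(free_memory / memory_used))
assert batch_number == 9                  # encode 9 images per forward pass
```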
@@ -284,6 +317,7 @@ class VAE:
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
+        pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         model_management.load_model_gpu(self.patcher)
         pixel_samples = pixel_samples.movedim(-1,1)
         samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
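
With both encode entry points cropping first, the model only ever sees dimensions that divide evenly by `downscale_ratio`, so pixel and latent sizes stay consistent across `encode()` and `encode_tiled()`. The arithmetic for the earlier 517x513 example:

```python
# Neither dim is a multiple of the downscale ratio, so the uncropped shape
# would not map onto a whole latent grid:
assert 517 % 8 != 0 and 513 % 8 != 0
# After vae_encode_crop_pixels, both paths see 512x512, i.e. a 64x64 latent
# for an SD1.x/SD2.x VAE:
assert (512 // 8, 512 // 8) == (64, 64)
```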