@ -589,11 +589,14 @@ class VAE:
self . first_stage_model = self . first_stage_model . to ( self . device )
pixel_samples = pixel_samples . movedim ( - 1 , 1 )
try :
batch_number = 1
free_memory = model_management . get_free_memory ( self . device )
batch_number = int ( ( free_memory * 0.7 ) / ( 2078 * pixel_samples . shape [ 2 ] * pixel_samples . shape [ 3 ] ) ) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
batch_number = max ( 1 , batch_number )
samples = torch . empty ( ( pixel_samples . shape [ 0 ] , 4 , round ( pixel_samples . shape [ 2 ] / / 8 ) , round ( pixel_samples . shape [ 3 ] / / 8 ) ) , device = " cpu " )
for x in range ( 0 , pixel_samples . shape [ 0 ] , batch_number ) :
pixels_in = ( 2. * pixel_samples [ x : x + batch_number ] - 1. ) . to ( self . device )
samples [ x : x + batch_number ] = self . first_stage_model . encode ( pixels_in ) . sample ( ) . cpu ( ) * self . scale_factor
except model_management . OOM_EXCEPTION as e :
print ( " Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding. " )
samples = self . encode_tiled_ ( pixel_samples )