@ -8,7 +8,8 @@ import logging
import comfy . sampler_helpers
def get_area_and_mult ( conds , x_in , timestep_in ) :
area = ( x_in . shape [ 2 ] , x_in . shape [ 3 ] , 0 , 0 )
dims = tuple ( x_in . shape [ 2 : ] )
area = None
strength = 1.0
if ' timestep_start ' in conds :
@ -20,11 +21,16 @@ def get_area_and_mult(conds, x_in, timestep_in):
if timestep_in [ 0 ] < timestep_end :
return None
if ' area ' in conds :
area = conds [ ' area ' ]
area = list ( conds [ ' area ' ] )
if ' strength ' in conds :
strength = conds [ ' strength ' ]
input_x = x_in [ : , : , area [ 2 ] : area [ 0 ] + area [ 2 ] , area [ 3 ] : area [ 1 ] + area [ 3 ] ]
input_x = x_in
if area is not None :
for i in range ( len ( dims ) ) :
area [ i ] = min ( input_x . shape [ i + 2 ] - area [ len ( dims ) + i ] , area [ i ] )
input_x = input_x . narrow ( i + 2 , area [ len ( dims ) + i ] , area [ i ] )
if ' mask ' in conds :
# Scale the mask to the size of the input
# The mask should have been resized as we began the sampling process
@ -32,28 +38,30 @@ def get_area_and_mult(conds, x_in, timestep_in):
if " mask_strength " in conds :
mask_strength = conds [ " mask_strength " ]
mask = conds [ ' mask ' ]
assert ( mask . shape [ 1 ] == x_in . shape [ 2 ] )
assert ( mask . shape [ 2 ] == x_in . shape [ 3 ] )
mask = mask [ : input_x . shape [ 0 ] , area [ 2 ] : area [ 0 ] + area [ 2 ] , area [ 3 ] : area [ 1 ] + area [ 3 ] ] * mask_strength
assert ( mask . shape [ 1 : ] == x_in . shape [ 2 : ] )
mask = mask [ : input_x . shape [ 0 ] ]
if area is not None :
for i in range ( len ( dims ) ) :
mask = mask . narrow ( i + 1 , area [ len ( dims ) + i ] , area [ i ] )
mask = mask * mask_strength
mask = mask . unsqueeze ( 1 ) . repeat ( input_x . shape [ 0 ] / / mask . shape [ 0 ] , input_x . shape [ 1 ] , 1 , 1 )
else :
mask = torch . ones_like ( input_x )
mult = mask * strength
if ' mask ' not in conds :
if ' mask ' not in conds and area is not None :
rr = 8
if area [ 2 ] != 0 :
for t in range ( rr ) :
mult [ : , : , t : 1 + t , : ] * = ( ( 1.0 / rr ) * ( t + 1 ) )
if ( area [ 0 ] + area [ 2 ] ) < x_in . shape [ 2 ] :
for t in range ( rr ) :
mult [ : , : , area [ 0 ] - 1 - t : area [ 0 ] - t , : ] * = ( ( 1.0 / rr ) * ( t + 1 ) )
if area [ 3 ] != 0 :
for t in range ( rr ) :
mult [ : , : , : , t : 1 + t ] * = ( ( 1.0 / rr ) * ( t + 1 ) )
if ( area [ 1 ] + area [ 3 ] ) < x_in . shape [ 3 ] :
for t in range ( rr ) :
mult [ : , : , : , area [ 1 ] - 1 - t : area [ 1 ] - t ] * = ( ( 1.0 / rr ) * ( t + 1 ) )
for i in range ( len ( dims ) ) :
if area [ len ( dims ) + i ] != 0 :
for t in range ( rr ) :
m = mult . narrow ( i + 2 , t , 1 )
m * = ( ( 1.0 / rr ) * ( t + 1 ) )
if ( area [ i ] + area [ len ( dims ) + i ] ) < x_in . shape [ i + 2 ] :
for t in range ( rr ) :
m = mult . narrow ( i + 2 , area [ i ] - 1 - t , 1 )
m * = ( ( 1.0 / rr ) * ( t + 1 ) )
conditioning = { }
model_conds = conds [ " model_conds " ]
@ -219,8 +227,19 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
for o in range ( batch_chunks ) :
cond_index = cond_or_uncond [ o ]
out_conds [ cond_index ] [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = output [ o ] * mult [ o ]
out_counts [ cond_index ] [ : , : , area [ o ] [ 2 ] : area [ o ] [ 0 ] + area [ o ] [ 2 ] , area [ o ] [ 3 ] : area [ o ] [ 1 ] + area [ o ] [ 3 ] ] + = mult [ o ]
a = area [ o ]
if a is None :
out_conds [ cond_index ] + = output [ o ] * mult [ o ]
out_counts [ cond_index ] + = mult [ o ]
else :
out_c = out_conds [ cond_index ]
out_cts = out_counts [ cond_index ]
dims = len ( a ) / / 2
for i in range ( dims ) :
out_c = out_c . narrow ( i + 2 , a [ i + dims ] , a [ i ] )
out_cts = out_cts . narrow ( i + 2 , a [ i + dims ] , a [ i ] )
out_c + = output [ o ] * mult [ o ]
out_cts + = mult [ o ]
for i in range ( len ( out_conds ) ) :
out_conds [ i ] / = out_counts [ i ]
@ -335,7 +354,7 @@ def get_mask_aabb(masks):
return bounding_boxes , is_empty
def resolve_areas_and_cond_masks ( conditions , h , w , device ) :
def resolve_areas_and_cond_masks _multidim( conditions , dims , device ) :
# We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
# While we're doing this, we can also resolve the mask device and scaling for performance reasons
for i in range ( len ( conditions ) ) :
@ -344,7 +363,14 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
area = c [ ' area ' ]
if area [ 0 ] == " percentage " :
modified = c . copy ( )
area = ( max ( 1 , round ( area [ 1 ] * h ) ) , max ( 1 , round ( area [ 2 ] * w ) ) , round ( area [ 3 ] * h ) , round ( area [ 4 ] * w ) )
a = area [ 1 : ]
a_len = len ( a ) / / 2
area = ( )
for d in range ( len ( dims ) ) :
area + = ( max ( 1 , round ( a [ d ] * dims [ d ] ) ) , )
for d in range ( len ( dims ) ) :
area + = ( round ( a [ d + a_len ] * dims [ d ] ) , )
modified [ ' area ' ] = area
c = modified
conditions [ i ] = c
@ -353,12 +379,12 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
mask = c [ ' mask ' ]
mask = mask . to ( device = device )
modified = c . copy ( )
if len ( mask . shape ) == 2 :
if len ( mask . shape ) == len ( dims ) :
mask = mask . unsqueeze ( 0 )
if mask . shape [ 1 ] != h or mask . shape [ 2 ] != w :
mask = torch . nn . functional . interpolate ( mask . unsqueeze ( 1 ) , size = ( h , w ) , mode = ' bilinear ' , align_corners = False ) . squeeze ( 1 )
if mask . shape [ 1 :] != dims :
mask = torch . nn . functional . interpolate ( mask . unsqueeze ( 1 ) , size = dims , mode = ' bilinear ' , align_corners = False ) . squeeze ( 1 )
if modified . get ( " set_area_to_bounds " , False ) :
if modified . get ( " set_area_to_bounds " , False ) : #TODO: handle dim != 2
bounds = torch . max ( torch . abs ( mask ) , dim = 0 ) . values . unsqueeze ( 0 )
boxes , is_empty = get_mask_aabb ( bounds )
if is_empty [ 0 ] :
@ -375,7 +401,11 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
modified [ ' mask ' ] = mask
conditions [ i ] = modified
def create_cond_with_same_area_if_none ( conds , c ) :
def resolve_areas_and_cond_masks ( conditions , h , w , device ) :
logging . warning ( " WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead. " )
return resolve_areas_and_cond_masks_multidim ( conditions , [ h , w ] , device )
def create_cond_with_same_area_if_none ( conds , c ) : #TODO: handle dim != 2
if ' area ' not in c :
return
@ -479,7 +509,10 @@ def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwar
params = x . copy ( )
params [ " device " ] = device
params [ " noise " ] = noise
params [ " width " ] = params . get ( " width " , noise . shape [ 3 ] * 8 )
default_width = None
if len ( noise . shape ) > = 4 : #TODO: 8 multiple should be set by the model
default_width = noise . shape [ 3 ] * 8
params [ " width " ] = params . get ( " width " , default_width )
params [ " height " ] = params . get ( " height " , noise . shape [ 2 ] * 8 )
params [ " prompt_type " ] = params . get ( " prompt_type " , prompt_type )
for k in kwargs :
@ -567,7 +600,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
def process_conds ( model , noise , conds , device , latent_image = None , denoise_mask = None , seed = None ) :
for k in conds :
conds [ k ] = conds [ k ] [ : ]
resolve_areas_and_cond_masks ( conds [ k ] , noise . shape [ 2 ], noise . shape [ 3 ] , device )
resolve_areas_and_cond_masks _multidim ( conds [ k ] , noise . shape [ 2 : ] , device )
for k in conds :
calculate_start_end_timesteps ( model , conds [ k ] )