@@ -352,6 +352,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
 def free_memory(memory_required, device, keep_loaded=[]):
     unloaded_model = []
     can_unload = []
+    unloaded_models = []
 
     for i in range(len(current_loaded_models) -1, -1, -1):
         shift_model = current_loaded_models[i]
@@ -369,7 +370,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         unloaded_model.append(i)
 
     for i in sorted(unloaded_model, reverse=True):
-        current_loaded_models.pop(i)
+        unloaded_models.append(current_loaded_models.pop(i))
 
     if len(unloaded_model) > 0:
         soft_empty_cache()
@@ -378,6 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()
+    return unloaded_models
 
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
     global vram_state
@@ -421,7 +423,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
         for d in devs:
             if d != torch.device("cpu"):
                 free_memory(extra_mem, d, models_already_loaded)
-        return
+                free_mem = get_free_memory(d)
+                if free_mem < minimum_memory_required:
+                    logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
+                    models_to_load = free_memory(minimum_memory_required, d)
+                    logging.info("{} models unloaded.".format(len(models_to_load)))
+        if len(models_to_load) == 0:
+            return
 
     logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")