@ -296,7 +296,7 @@ class LoadedModel:
def model_memory_required(self, device):
    """Return how much memory must still be made available on ``device``.

    Args:
        device: The device the model is about to be loaded onto.

    Returns:
        ``self.model_offloaded_memory()`` when ``device`` is the device the
        model currently reports via ``self.model.current_loaded_device()``,
        otherwise ``self.model_memory()``.
    """
    if device == self.model.current_loaded_device():
        # The model already sits on this device, so only the offloaded
        # portion is reported. The superseded variant of this line
        # returned 0 here instead.
        return self.model_offloaded_memory()
    # Different device: report the full model size.
    return self.model_memory()
@ -308,15 +308,21 @@ class LoadedModel:
load_weights = not self . weights_loaded
load_weights = not self . weights_loaded
try :
if self . model . loaded_size ( ) > 0 :
if lowvram_model_memory > 0 and load_weights :
use_more_vram = lowvram_model_memory
self . real_model = self . model . patch_model_lowvram ( device_to = patch_model_to , lowvram_model_memory = lowvram_model_memory , force_patch_weights = force_patch_weights )
if use_more_vram == 0 :
else :
use_more_vram = 1e32
self . real_model = self . model . patch_model ( device_to = patch_model_to , patch_weights = load_weights )
self . model_use_more_vram ( use_more_vram )
except Exception as e :
else :
self . model . unpatch_model ( self . model . offload_device )
try :
self . model_unload ( )
if lowvram_model_memory > 0 and load_weights :
raise e
self . real_model = self . model . patch_model_lowvram ( device_to = patch_model_to , lowvram_model_memory = lowvram_model_memory , force_patch_weights = force_patch_weights )
else :
self . real_model = self . model . patch_model ( device_to = patch_model_to , patch_weights = load_weights )
except Exception as e :
self . model . unpatch_model ( self . model . offload_device )
self . model_unload ( )
raise e
if is_intel_xpu ( ) and not args . disable_ipex_optimize :
if is_intel_xpu ( ) and not args . disable_ipex_optimize :
self . real_model = ipex . optimize ( self . real_model . eval ( ) , graph_mode = True , concat_linear = True )
self . real_model = ipex . optimize ( self . real_model . eval ( ) , graph_mode = True , concat_linear = True )
@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = { }
total_memory_required = { }
for loaded_model in models_to_load :
for loaded_model in models_to_load :
if unload_model_clones ( loaded_model . model , unload_weights_only = True , force_unload = False ) == True : #unload clones where the weights are different
unload_model_clones ( loaded_model . model , unload_weights_only = True , force_unload = False ) #unload clones where the weights are different
total_memory_required [ loaded_model . device ] = total_memory_required . get ( loaded_model . device , 0 ) + loaded_model . model_memory_required ( loaded_model . device )
total_memory_required [ loaded_model . device ] = total_memory_required . get ( loaded_model . device , 0 ) + loaded_model . model_memory_required ( loaded_model . device )
for device in total_memory_required :
for loaded_model in models_already_loaded :
if device != torch . device ( " cpu " ) :
total_memory_required [ loaded_model . device ] = total_memory_required . get ( loaded_model . device , 0 ) + loaded_model . model_memory_required ( loaded_model . device )
free_memory ( total_memory_required [ device ] * 1.3 + extra_mem , device , models_already_loaded )
for loaded_model in models_to_load :
for loaded_model in models_to_load :
weights_unloaded = unload_model_clones ( loaded_model . model , unload_weights_only = False , force_unload = False ) #unload the rest of the clones where the weights can stay loaded
weights_unloaded = unload_model_clones ( loaded_model . model , unload_weights_only = False , force_unload = False ) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None :
if weights_unloaded is not None :
loaded_model . weights_loaded = not weights_unloaded
loaded_model . weights_loaded = not weights_unloaded
for device in total_memory_required :
if device != torch . device ( " cpu " ) :
free_memory ( total_memory_required [ device ] * 1.1 + extra_mem , device , models_already_loaded )
for loaded_model in models_to_load :
for loaded_model in models_to_load :
model = loaded_model . model
model = loaded_model . model
torch_dev = model . load_device
torch_dev = model . load_device