Fix some performance issues with weight loading and unloading.

Lower peak memory usage when changing models.

Fix a case where model weights would be needlessly unloaded and then reloaded.
Branch: main
Author: comfyanonymous, 11 months ago
Parent: 327ca1313d
Commit: 5d8898c056

comfy/model_management.py

@@ -274,6 +274,7 @@ class LoadedModel:
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()
@@ -312,6 +313,7 @@ class LoadedModel:
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model
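Taken together, the two hunks above give LoadedModel a real_model attribute that mirrors its load state: it points at the materialized module while the model is resident on its device and is None otherwise, which is what the reference-count check in cleanup_models further down relies on. A minimal sketch of that lifecycle, assuming a simplified patcher API (the model_load body is illustrative, not the project's actual implementation):

    class LoadedModel:
        def __init__(self, model):
            self.model = model
            self.device = model.load_device
            self.weights_loaded = False
            self.real_model = None  # set only while the model is resident on its device

        def model_load(self):
            # Illustrative: the real method also handles patches and low-VRAM paths.
            self.real_model = self.model.patch_model(device_to=self.device)
            self.weights_loaded = True
            return self.real_model

        def model_unload(self, unpatch_weights=True):
            self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
            self.weights_loaded = self.weights_loaded and not unpatch_weights
            self.real_model = None  # drop the reference so refcount-based cleanup sees orphans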
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:
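The switch from return None to return True turns the function into a signal for the caller: True means the caller must still budget and load this model's weights (no resident clone covers them), while None, the fall-through when an identical-weight clone is kept, means the weights are already accounted for. A hedged reconstruction of the full function under that contract (clone_has_same_weights and the unload loop are assumptions based on the visible skeleton):

    def unload_model_clones(model, unload_weights_only=True, force_unload=True):
        # Indices are collected in descending order so popping stays valid.
        to_unload = []
        for i in range(len(current_loaded_models)):
            if model.is_clone(current_loaded_models[i].model):
                to_unload = [i] + to_unload

        if len(to_unload) == 0:
            return True  # no clones resident: caller must load (and budget) the weights

        same_weights = sum(1 for i in to_unload
                           if model.clone_has_same_weights(current_loaded_models[i].model))
        unload_weight = same_weights != len(to_unload)

        if not force_unload and unload_weights_only and not unload_weight:
            return None  # identical weights stay resident; no extra memory needed

        for i in to_unload:
            current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
        return unload_weight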
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True: #unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
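This call-site change is the peak-memory fix from the commit message: a model whose weights already live in an identical-weight resident clone is no longer charged a second time when totalling per-device requirements. A toy illustration of the accounting (the device name and size are made up):

    # Hypothetical: a 4 GiB checkpoint for which a same-weight clone is already on the GPU.
    clone_resident_with_same_weights = True
    model_size = 4 * 1024**3
    device = "cuda:0"

    total_memory_required = {}
    needs_load = not clone_resident_with_same_weights  # stands in for unload_model_clones(...) == True
    if needs_load:
        total_memory_required[device] = total_memory_required.get(device, 0) + model_size

    print(total_memory_required.get(device, 0))  # 0 bytes budgeted instead of 4 GiB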
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:
         x = current_loaded_models.pop(i)
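The pruning logic leans on CPython reference counting: sys.getrefcount reports one extra reference for its own argument, so a count of 2 on .model means only the current_loaded_models entry still holds the patcher, and a count of 3 on .real_model allows for the .real_model attribute plus the patcher's own pointer to the underlying module (hence the TODO about fragility; this is CPython-specific bookkeeping). A standalone demonstration of the arithmetic with stand-in classes:

    import sys

    class Weights:                      # stand-in for the underlying torch module
        pass

    class Patcher:                      # stand-in for ModelPatcher
        def __init__(self, weights):
            self.model = weights        # the patcher keeps a reference to the module

    class Loaded:                       # stand-in for LoadedModel
        def __init__(self, patcher):
            self.model = patcher
            self.real_model = patcher.model

    entry = Loaded(Patcher(Weights()))

    # 2 = the entry's .model attribute + getrefcount's own argument reference:
    print(sys.getrefcount(entry.model))       # 2 -> nothing else uses this patcher
    # 3 = .real_model + the patcher's .model + getrefcount's argument reference:
    print(sys.getrefcount(entry.real_model))  # 3 -> the weights are otherwise orphaned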

execution.py

@@ -368,6 +368,7 @@ class PromptExecutor:
                 d = self.outputs_ui.pop(x)
                 del d
 
+        comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
         self.add_message("execution_cached",
                         { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                         broadcast=False)
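On the executor side, the new keyword argument is what removes the needless unload/reload round trip: when cached outputs are reused between prompts, stale bookkeeping entries are pruned without forcing still-referenced weights off the device. A usage sketch, with the surrounding executor plumbing omitted:

    import comfy.model_management

    # Between prompts: prune LoadedModel entries whose patcher is gone, but keep
    # device weights alive when a clone of the same model may be reused.
    comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

    # The previous, stricter behavior remains the default:
    comfy.model_management.cleanup_models()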
