|
|
|
@@ -274,6 +274,7 @@ class LoadedModel:
|
|
|
|
|
self.model = model
|
|
|
|
|
self.device = model.load_device
|
|
|
|
|
self.weights_loaded = False
|
|
|
|
|
self.real_model = None
|
|
|
|
|
|
|
|
|
|
def model_memory(self):
    """Return the memory footprint of the wrapped model.

    Delegates to the patcher object's own ``model_size()`` accounting.
    """
    size = self.model.model_size()
    return size
|
|
|
|
@@ -312,6 +313,7 @@ class LoadedModel:
|
|
|
|
|
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
|
|
|
|
self.model.model_patches_to(self.model.offload_device)
|
|
|
|
|
self.weights_loaded = self.weights_loaded and not unpatch_weights
|
|
|
|
|
self.real_model = None
|
|
|
|
|
|
|
|
|
|
def __eq__(self, other):
|
|
|
|
|
return self.model is other.model
|
|
|
|
@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
|
|
|
|
|
to_unload = [i] + to_unload
|
|
|
|
|
|
|
|
|
|
if len(to_unload) == 0:
|
|
|
|
|
return None
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
same_weights = 0
|
|
|
|
|
for i in to_unload:
|
|
|
|
@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
|
|
|
|
|
total_memory_required = {}
|
|
|
|
|
for loaded_model in models_to_load:
|
|
|
|
|
unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
|
|
|
|
|
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
|
|
|
|
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
|
|
|
|
|
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
|
|
|
|
|
|
|
|
|
for device in total_memory_required:
|
|
|
|
|
if device != torch.device("cpu"):
|
|
|
|
@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
|
|
|
|
|
def load_model_gpu(model):
    """Convenience wrapper: load one model via the batch loader.

    Equivalent to calling ``load_models_gpu`` with a single-element list.
    """
    batch = [model]
    return load_models_gpu(batch)
|
|
|
|
|
|
|
|
|
|
def cleanup_models():
|
|
|
|
|
def cleanup_models(keep_clone_weights_loaded=False):
|
|
|
|
|
to_delete = []
|
|
|
|
|
for i in range(len(current_loaded_models)):
|
|
|
|
|
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
|
|
|
|
to_delete = [i] + to_delete
|
|
|
|
|
if not keep_clone_weights_loaded:
|
|
|
|
|
to_delete = [i] + to_delete
|
|
|
|
|
#TODO: find a less fragile way to do this.
|
|
|
|
|
elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
|
|
|
|
|
to_delete = [i] + to_delete
|
|
|
|
|
|
|
|
|
|
for i in to_delete:
|
|
|
|
|
x = current_loaded_models.pop(i)
|
|
|
|
|