@@ -379,11 +379,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()
 
-def load_models_gpu(models, memory_required=0, force_patch_weights=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
     global vram_state
 
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)
+    if minimum_memory_required is None:
+        minimum_memory_required = extra_mem
+    else:
+        minimum_memory_required = max(inference_memory, minimum_memory_required)
 
     models = set(models)
@@ -446,8 +450,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - extra_mem)))
-            if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
+            lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required)))
+            if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
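
# --- Illustrative sketch (not part of the patch) ------------------------------
# A minimal standalone sketch of the low-VRAM decision after this change.
# The helper name and the byte figures below are assumptions for illustration;
# in the patch the same logic runs inline inside load_models_gpu().

def pick_lowvram_budget(model_size, current_free_mem, minimum_memory_required):
    # Budget for partially loading weights: whatever is free after reserving
    # minimum_memory_required for inference, but never less than 64 MB.
    lowvram_model_memory = int(max(64 * (1024 * 1024),
                                   current_free_mem - minimum_memory_required))
    # If the whole model fits inside that budget, skip low-VRAM mode entirely
    # (0 means "load the full model normally").
    if model_size <= lowvram_model_memory:
        lowvram_model_memory = 0
    return lowvram_model_memory

if __name__ == "__main__":
    GB = 1024 ** 3
    # 6 GB model, 8 GB free, caller reserves 3 GB: budget is 5 GB -> partial load.
    print(pick_lowvram_budget(6 * GB, 8 * GB, minimum_memory_required=3 * GB))
    # Same model with only 1 GB reserved: the model fits, so 0 -> full load.
    print(pick_lowvram_budget(6 * GB, 8 * GB, minimum_memory_required=1 * GB))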