|
|
@ -644,18 +644,24 @@ class ModelPatcher:
|
|
|
|
def partially_unload(self, device_to, memory_to_free=0):
|
|
|
|
def partially_unload(self, device_to, memory_to_free=0):
|
|
|
|
memory_freed = 0
|
|
|
|
memory_freed = 0
|
|
|
|
patch_counter = 0
|
|
|
|
patch_counter = 0
|
|
|
|
|
|
|
|
unload_list = []
|
|
|
|
|
|
|
|
|
|
|
|
for n, m in list(self.model.named_modules())[::-1]:
|
|
|
|
for n, m in self.model.named_modules():
|
|
|
|
if memory_to_free < memory_freed:
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shift_lowvram = False
|
|
|
|
shift_lowvram = False
|
|
|
|
if hasattr(m, "comfy_cast_weights"):
|
|
|
|
if hasattr(m, "comfy_cast_weights"):
|
|
|
|
module_mem = comfy.model_management.module_size(m)
|
|
|
|
module_mem = comfy.model_management.module_size(m)
|
|
|
|
|
|
|
|
unload_list.append((module_mem, n, m))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unload_list.sort()
|
|
|
|
|
|
|
|
for unload in unload_list:
|
|
|
|
|
|
|
|
if memory_to_free < memory_freed:
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
module_mem = unload[0]
|
|
|
|
|
|
|
|
n = unload[1]
|
|
|
|
|
|
|
|
m = unload[2]
|
|
|
|
weight_key = "{}.weight".format(n)
|
|
|
|
weight_key = "{}.weight".format(n)
|
|
|
|
bias_key = "{}.bias".format(n)
|
|
|
|
bias_key = "{}.bias".format(n)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if m.weight is not None and m.weight.device != device_to:
|
|
|
|
if m.weight is not None and m.weight.device != device_to:
|
|
|
|
for key in [weight_key, bias_key]:
|
|
|
|
for key in [weight_key, bias_key]:
|
|
|
|
bk = self.backup.get(key, None)
|
|
|
|
bk = self.backup.get(key, None)
|
|
|
|