@@ -387,15 +387,14 @@ class ModelPatcher:
            m = x[2]
            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)
            param = list(m.parameters())
            if len(param) > 0:
                weight = param[0]
                if weight.device == device_to:
                    if hasattr(m, "comfy_patched_weights"):
                        if m.comfy_patched_weights == True:
                            continue

            self.patch_weight_to_device(weight_key, device_to=device_to)
            self.patch_weight_to_device(bias_key, device_to=device_to)
            logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
            m.comfy_patched_weights = True

        for x in load_completely:
            x[2].to(device_to)
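In this hunk, `x[2]` is the module itself (and `n`, presumably `x[1]`, its name). The early `continue` skips modules whose first parameter already sits on `device_to` and which carry the `comfy_patched_weights` marker from a previous load. A minimal sketch of that skip check, with the hypothetical helper name `already_loaded` (not part of the diff):

    # Hypothetical helper mirroring the early-exit logic above: a module can
    # be skipped when its first parameter is already on device_to and it was
    # flagged as patched by a previous load.
    import torch

    def already_loaded(m: torch.nn.Module, device_to: torch.device) -> bool:
        params = list(m.parameters())
        if len(params) > 0 and params[0].device == device_to:
            # getattr covers both "attribute missing" and "flag is False"
            return bool(getattr(m, "comfy_patched_weights", False))
        return False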
@@ -622,6 +621,10 @@ class ModelPatcher:
            self.model.device = device_to
            self.model.model_loaded_weight_memory = 0

            for m in self.model.modules():
                if hasattr(m, "comfy_patched_weights"):
                    del m.comfy_patched_weights

        keys = list(self.object_patches_backup.keys())
        for k in keys:
            comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
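Deleting `comfy_patched_weights` from every module ensures the next load takes the patch path in the first hunk instead of skipping modules. Object patches are then restored by dotted key via `comfy.utils.set_attr`; its body is not shown in this diff, but a dotted-path setter of this kind typically looks like the following sketch (an assumption, not the real helper):

    # Assumed shape of a dotted-path attribute setter; illustrative only.
    def set_attr(obj, attr, value):
        *path, last = attr.split(".")
        for name in path:
            obj = getattr(obj, name)   # walk e.g. "diffusion_model.out.0"
        prev = getattr(obj, last, None)
        setattr(obj, last, value)
        return prev                    # callers may keep this as a backup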
@@ -649,7 +652,7 @@ class ModelPatcher:
                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)

                if m.weight is not None and m.weight.device != device_to:
                    if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
                        for key in [weight_key, bias_key]:
                            bk = self.backup.get(key, None)
                            if bk is not None:
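Here the restore of backed-up weights is gated on `comfy_patched_weights`, so only modules whose weights were actually patched get their originals put back; the hunk is cut off at `if bk is not None:`, so the restore body itself is not shown. A toy backup/restore roundtrip in the same spirit (the names `save_backup` and `restore_backup` are hypothetical, not ComfyUI code):

    # Toy illustration: stash a copy of a tensor under a string key, then
    # put it back on the module and drop the backup entry.
    import torch

    backup = {}

    def save_backup(key: str, t: torch.Tensor):
        backup[key] = t.detach().clone()

    def restore_backup(key: str, module: torch.nn.Module, attr: str):
        bk = backup.get(key, None)
        if bk is not None:
            setattr(module, attr, torch.nn.Parameter(bk, requires_grad=False))
            backup.pop(key)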
@@ -669,6 +672,7 @@ class ModelPatcher:

                    m.prev_comfy_cast_weights = m.comfy_cast_weights
                    m.comfy_cast_weights = True
                    m.comfy_patched_weights = False
                    memory_freed += module_mem
                    logging.debug("freed {}".format(n))
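Setting `comfy_patched_weights = False` ties back to the first hunk: on the next load the flag check fails and the module is re-patched. `module_mem`, added to `memory_freed`, presumably comes from a module-size helper; one plausible way to compute it is sketched below (an assumption, not necessarily the implementation in comfy.model_management):

    # Assumed helper: bytes held by one module's state-dict tensors.
    import torch

    def module_size(module: torch.nn.Module) -> int:
        sd = module.state_dict()
        return sum(t.nelement() * t.element_size() for t in sd.values())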