Bug fixes.

main
comfyanonymous 6 months ago
parent 6138f92084
commit 4d341b78e8

@ -387,15 +387,14 @@ class ModelPatcher:
m = x[2] m = x[2]
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
param = list(m.parameters()) if hasattr(m, "comfy_patched_weights"):
if len(param) > 0: if m.comfy_patched_weights == True:
weight = param[0]
if weight.device == device_to:
continue continue
self.patch_weight_to_device(weight_key, device_to=device_to) self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to) self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely: for x in load_completely:
x[2].to(device_to) x[2].to(device_to)
@ -622,6 +621,10 @@ class ModelPatcher:
self.model.device = device_to self.model.device = device_to
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys()) keys = list(self.object_patches_backup.keys())
for k in keys: for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
@ -649,7 +652,7 @@ class ModelPatcher:
weight_key = "{}.weight".format(n) weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n) bias_key = "{}.bias".format(n)
if m.weight is not None and m.weight.device != device_to: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]: for key in [weight_key, bias_key]:
bk = self.backup.get(key, None) bk = self.backup.get(key, None)
if bk is not None: if bk is not None:
@ -669,6 +672,7 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem memory_freed += module_mem
logging.debug("freed {}".format(n)) logging.debug("freed {}".format(n))

Loading…
Cancel
Save