diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 51259b5..ae33687 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -387,15 +387,14 @@ class ModelPatcher: m = x[2] weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) - param = list(m.parameters()) - if len(param) > 0: - weight = param[0] - if weight.device == device_to: + if hasattr(m, "comfy_patched_weights"): + if m.comfy_patched_weights == True: continue self.patch_weight_to_device(weight_key, device_to=device_to) self.patch_weight_to_device(bias_key, device_to=device_to) logging.debug("lowvram: loaded module regularly {} {}".format(n, m)) + m.comfy_patched_weights = True for x in load_completely: x[2].to(device_to) @@ -622,6 +621,10 @@ class ModelPatcher: self.model.device = device_to self.model.model_loaded_weight_memory = 0 + for m in self.model.modules(): + if hasattr(m, "comfy_patched_weights"): + del m.comfy_patched_weights + keys = list(self.object_patches_backup.keys()) for k in keys: comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) @@ -649,7 +652,7 @@ class ModelPatcher: weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) - if m.weight is not None and m.weight.device != device_to: + if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: for key in [weight_key, bias_key]: bk = self.backup.get(key, None) if bk is not None: @@ -669,6 +672,7 @@ class ModelPatcher: m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True + m.comfy_patched_weights = False memory_freed += module_mem logging.debug("freed {}".format(n))