diff --git a/comfy/model_management.py b/comfy/model_management.py
index 33401b2..44e3541 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -562,7 +562,7 @@ def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         #TODO: very fragile function needs improvement
-        num_refs = sys.getrefcount(current_loaded_models[i].model) - current_loaded_models[i].model.lowvram_patch_counter()
+        num_refs = sys.getrefcount(current_loaded_models[i].model)
         if num_refs <= 2:
             if not keep_clone_weights_loaded:
                 to_delete = [i] + to_delete
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index ae33687..e25568f 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -91,13 +91,180 @@ def wipe_lowvram_weight(m):
     m.weight_function = None
     m.bias_function = None

+def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+    for p in patches:
+        strength = p[0]
+        v = p[1]
+        strength_model = p[2]
+        offset = p[3]
+        function = p[4]
+        if function is None:
+            function = lambda a: a
+
+        old_weight = None
+        if offset is not None:
+            old_weight = weight
+            weight = weight.narrow(offset[0], offset[1], offset[2])
+
+        if strength_model != 1.0:
+            weight *= strength_model
+
+        if isinstance(v, list):
+            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
+
+        if len(v) == 1:
+            patch_type = "diff"
+        elif len(v) == 2:
+            patch_type = v[0]
+            v = v[1]
+
+        if patch_type == "diff":
+            w1 = v[0]
+            if strength != 0.0:
+                if w1.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
+        elif patch_type == "lora": #lora/locon
+            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
+            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
+            dora_scale = v[4]
+            if v[2] is not None:
+                alpha = v[2] / mat2.shape[0]
+            else:
+                alpha = 1.0
+
+            if v[3] is not None:
+                #locon mid weights, hopefully the math is fine because I didn't properly test it
+                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
+                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+            try:
+                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "lokr":
+            w1 = v[0]
+            w2 = v[1]
+            w1_a = v[3]
+            w1_b = v[4]
+            w2_a = v[5]
+            w2_b = v[6]
+            t2 = v[7]
+            dora_scale = v[8]
+            dim = None
+
+            if w1 is None:
+                dim = w1_b.shape[0]
+                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
+            else:
+                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
+
+            if w2 is None:
+                dim = w2_b.shape[0]
+                if t2 is None:
+                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
+                else:
+                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
+            else:
+                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
+
+            if len(w2.shape) == 4:
+                w1 = w1.unsqueeze(2).unsqueeze(2)
+            if v[2] is not None and dim is not None:
+                alpha = v[2] / dim
+            else:
+                alpha = 1.0
+
+            try:
+                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "loha":
+            w1a = v[0]
+            w1b = v[1]
+            if v[2] is not None:
+                alpha = v[2] / w1b.shape[0]
+            else:
+                alpha = 1.0
+
+            w2a = v[3]
+            w2b = v[4]
+            dora_scale = v[7]
+            if v[5] is not None: #cp decomposition
+                t1 = v[5]
+                t2 = v[6]
+                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
+
+                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
+            else:
+                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
+
+            try:
+                lora_diff = (m1 * m2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "glora":
+            if v[4] is not None:
+                alpha = v[4] / v[0].shape[0]
+            else:
+                alpha = 1.0
+
+            dora_scale = v[5]
+
+            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
+
+            try:
+                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        else:
logging.warning("patch type not recognized {} {}".format(patch_type, key)) + + if old_weight is not None: + weight = old_weight + + return weight + class LowVramPatch: - def __init__(self, key, model_patcher): + def __init__(self, key, patches): self.key = key - self.model_patcher = model_patcher + self.patches = patches def __call__(self, weight): - return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) - + return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): @@ -329,7 +496,7 @@ class ModelPatcher: temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) else: temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key) + out_weight = calculate_weight(self.patches[key], temp_weight, key) out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype) if inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) @@ -360,13 +527,13 @@ class ModelPatcher: if force_patch_weights: self.patch_weight_to_device(weight_key) else: - m.weight_function = LowVramPatch(weight_key, self) + m.weight_function = LowVramPatch(weight_key, self.model_patcher.patches) patch_counter += 1 if bias_key in self.patches: if force_patch_weights: self.patch_weight_to_device(bias_key) else: - m.bias_function = LowVramPatch(bias_key, self) + m.bias_function = LowVramPatch(bias_key, self.model_patcher.patches) patch_counter += 1 m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -428,174 +595,6 @@ class ModelPatcher: self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load) return self.model - def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): - for p in patches: - strength = p[0] - v = p[1] - strength_model = p[2] - offset = p[3] - function = p[4] - if function is None: - function = lambda a: a - - old_weight = None - if offset is not None: - old_weight = weight - weight = weight.narrow(offset[0], offset[1], offset[2]) - - if strength_model != 1.0: - weight *= strength_model - - if isinstance(v, list): - v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), ) - - if len(v) == 1: - patch_type = "diff" - elif len(v) == 2: - patch_type = v[0] - v = v[1] - - if patch_type == "diff": - w1 = v[0] - if strength != 0.0: - if w1.shape != weight.shape: - logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape)) - else: - weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)) - elif patch_type == "lora": #lora/locon - mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype) - mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype) - dora_scale = v[4] - if v[2] is not None: - alpha = v[2] / mat2.shape[0] - else: - alpha = 1.0 - - if v[3] is not None: - #locon mid weights, hopefully the math is fine because I didn't properly test it - mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype) - final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]] - mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "lokr":
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dora_scale = v[8]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
-                else:
-                    w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
-                else:
-                    w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha = v[2] / dim
-                else:
-                    alpha = 1.0
-
-                try:
-                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "loha":
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha = v[2] / w1b.shape[0]
-                else:
-                    alpha = 1.0
-
-                w2a = v[3]
-                w2b = v[4]
-                dora_scale = v[7]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
-                else:
-                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
-                try:
-                    lora_diff = (m1 * m2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "glora":
-                if v[4] is not None:
-                    alpha = v[4] / v[0].shape[0]
-                else:
-                    alpha = 1.0
-
-                dora_scale = v[5]
-
-                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
-                try:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            else:
-                logging.warning("patch type not recognized {} {}".format(patch_type, key))
-
-            if old_weight is not None:
-                weight = old_weight
-
-        return weight
-
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
             if self.model.model_lowvram:
@@ -695,3 +694,7 @@ class ModelPatcher:

     def current_loaded_device(self):
         return self.model.device
+
+    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
+        print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.model_patcher.calculate_weight instead")
+        return calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
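
Migration sketch for downstream callers (e.g. custom nodes) that used to call ModelPatcher.calculate_weight directly. Only the module-level signature calculate_weight(patches, weight, key, intermediate_dtype=...) and the 5-tuple patch layout (strength, v, strength_model, offset, function) come from this diff; the tensors, the key string, and the variable names below are illustrative placeholders, not part of the change.

import torch
import comfy.model_patcher

weight = torch.zeros(4, 4)                    # stand-in for a model parameter
diff = torch.ones(4, 4)                       # len(v) == 1 -> treated as a "diff" patch
patches = [(0.5, (diff,), 1.0, None, None)]   # (strength, v, strength_model, offset, function)

# Deprecated path (still routed through the shim added at the end of this diff):
#   out = some_model_patcher.calculate_weight(patches, weight.clone(), "example.weight")
# New module-level call:
out = comfy.model_patcher.calculate_weight(patches, weight.clone(), "example.weight")
# For a "diff" patch with matching shapes this yields weight + 0.5 * diff.

Because LowVramPatch now captures only the patches dict instead of the whole ModelPatcher, the lowvram weight/bias functions no longer hold references to the patcher object, which appears to be why the cleanup_models hunk can drop the lowvram_patch_counter() correction to sys.getrefcount.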