From 45beebd33cd086f1b46e7e7054ba065d3a999cfe Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 20 Jun 2023 17:34:11 -0400 Subject: [PATCH] Add a type of model patch useful for model merging. --- comfy/sd.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index e6cda51..0ff918c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -347,15 +347,23 @@ class ModelPatcher: def model_dtype(self): return self.model.get_dtype() - def add_patches(self, patches, strength=1.0): + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): p = {} model_sd = self.model.state_dict() for k in patches: if k in model_sd: p[k] = patches[k] - self.patches += [(strength, p)] + self.patches += [(strength_patch, p, strength_model)] return p.keys() + def model_state_dict(self): + sd = self.model.state_dict() + keys = list(sd.keys()) + for k in keys: + if not k.startswith("diffusion_model."): + sd.pop(k) + return sd + def patch_model(self): model_sd = self.model.state_dict() for p in self.patches: @@ -371,8 +379,14 @@ class ModelPatcher: self.backup[key] = weight.clone() alpha = p[0] + strength_model = p[2] + + if strength_model != 1.0: + weight *= strength_model - if len(v) == 4: #lora/locon + if len(v) == 1: + weight += alpha * (v[0]).type(weight.dtype).to(weight.device) + elif len(v) == 4: #lora/locon mat1 = v[0] mat2 = v[1] if v[2] is not None: