Add a type of model patch useful for model merging.

main
comfyanonymous 2 years ago
parent 186f92042b
commit 45beebd33c

@ -347,15 +347,23 @@ class ModelPatcher:
def model_dtype(self):
return self.model.get_dtype()
def add_patches(self, patches, strength=1.0):
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
p = {}
model_sd = self.model.state_dict()
for k in patches:
if k in model_sd:
p[k] = patches[k]
self.patches += [(strength, p)]
self.patches += [(strength_patch, p, strength_model)]
return p.keys()
def model_state_dict(self):
sd = self.model.state_dict()
keys = list(sd.keys())
for k in keys:
if not k.startswith("diffusion_model."):
sd.pop(k)
return sd
def patch_model(self):
model_sd = self.model.state_dict()
for p in self.patches:
@ -371,8 +379,14 @@ class ModelPatcher:
self.backup[key] = weight.clone()
alpha = p[0]
strength_model = p[2]
if strength_model != 1.0:
weight *= strength_model
if len(v) == 4: #lora/locon
if len(v) == 1:
weight += alpha * (v[0]).type(weight.dtype).to(weight.device)
elif len(v) == 4: #lora/locon
mat1 = v[0]
mat2 = v[1]
if v[2] is not None:

Loading…
Cancel
Save