|
|
|
@ -206,7 +206,7 @@ class ModelPatcher:
|
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0):
|
|
|
|
|
self.size = size
|
|
|
|
|
self.model = model
|
|
|
|
|
self.patches = []
|
|
|
|
|
self.patches = {}
|
|
|
|
|
self.backup = {}
|
|
|
|
|
self.model_options = {"transformer_options":{}}
|
|
|
|
|
self.model_size()
|
|
|
|
@ -227,7 +227,10 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
def clone(self):
|
|
|
|
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size)
|
|
|
|
|
n.patches = self.patches[:]
|
|
|
|
|
n.patches = {}
|
|
|
|
|
for k in self.patches:
|
|
|
|
|
n.patches[k] = self.patches[k][:]
|
|
|
|
|
|
|
|
|
|
n.model_options = copy.deepcopy(self.model_options)
|
|
|
|
|
n.model_keys = self.model_keys
|
|
|
|
|
return n
|
|
|
|
@ -295,12 +298,28 @@ class ModelPatcher:
|
|
|
|
|
return self.model.get_dtype()
|
|
|
|
|
|
|
|
|
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
|
|
|
|
p = {}
|
|
|
|
|
p = set()
|
|
|
|
|
for k in patches:
|
|
|
|
|
if k in self.model_keys:
|
|
|
|
|
p[k] = patches[k]
|
|
|
|
|
self.patches += [(strength_patch, p, strength_model)]
|
|
|
|
|
return p.keys()
|
|
|
|
|
p.add(k)
|
|
|
|
|
current_patches = self.patches.get(k, [])
|
|
|
|
|
current_patches.append((strength_patch, patches[k], strength_model))
|
|
|
|
|
self.patches[k] = current_patches
|
|
|
|
|
|
|
|
|
|
return list(p)
|
|
|
|
|
|
|
|
|
|
def get_key_patches(self, filter_prefix=None):
|
|
|
|
|
model_sd = self.model_state_dict()
|
|
|
|
|
p = {}
|
|
|
|
|
for k in model_sd:
|
|
|
|
|
if filter_prefix is not None:
|
|
|
|
|
if not k.startswith(filter_prefix):
|
|
|
|
|
continue
|
|
|
|
|
if k in self.patches:
|
|
|
|
|
p[k] = [model_sd[k]] + self.patches[k]
|
|
|
|
|
else:
|
|
|
|
|
p[k] = (model_sd[k],)
|
|
|
|
|
return p
|
|
|
|
|
|
|
|
|
|
def model_state_dict(self, filter_prefix=None):
|
|
|
|
|
sd = self.model.state_dict()
|
|
|
|
@ -313,85 +332,93 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
def patch_model(self):
|
|
|
|
|
model_sd = self.model_state_dict()
|
|
|
|
|
for p in self.patches:
|
|
|
|
|
for k in p[1]:
|
|
|
|
|
v = p[1][k]
|
|
|
|
|
key = k
|
|
|
|
|
if key not in model_sd:
|
|
|
|
|
print("could not patch. key doesn't exist in model:", k)
|
|
|
|
|
continue
|
|
|
|
|
for key in self.patches:
|
|
|
|
|
if key not in model_sd:
|
|
|
|
|
print("could not patch. key doesn't exist in model:", k)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
weight = model_sd[key]
|
|
|
|
|
if key not in self.backup:
|
|
|
|
|
self.backup[key] = weight.clone()
|
|
|
|
|
weight = model_sd[key]
|
|
|
|
|
|
|
|
|
|
alpha = p[0]
|
|
|
|
|
strength_model = p[2]
|
|
|
|
|
if key not in self.backup:
|
|
|
|
|
self.backup[key] = weight.clone()
|
|
|
|
|
|
|
|
|
|
if strength_model != 1.0:
|
|
|
|
|
weight *= strength_model
|
|
|
|
|
weight[:] = self.calculate_weight(self.patches[key], weight.clone(), key)
|
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
|
if len(v) == 1:
|
|
|
|
|
w1 = v[0]
|
|
|
|
|
if w1.shape != weight.shape:
|
|
|
|
|
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
|
|
|
|
else:
|
|
|
|
|
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
|
|
|
|
elif len(v) == 4: #lora/locon
|
|
|
|
|
mat1 = v[0]
|
|
|
|
|
mat2 = v[1]
|
|
|
|
|
if v[2] is not None:
|
|
|
|
|
alpha *= v[2] / mat2.shape[0]
|
|
|
|
|
if v[3] is not None:
|
|
|
|
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
|
|
|
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
|
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
elif len(v) == 8: #lokr
|
|
|
|
|
w1 = v[0]
|
|
|
|
|
w2 = v[1]
|
|
|
|
|
w1_a = v[3]
|
|
|
|
|
w1_b = v[4]
|
|
|
|
|
w2_a = v[5]
|
|
|
|
|
w2_b = v[6]
|
|
|
|
|
t2 = v[7]
|
|
|
|
|
dim = None
|
|
|
|
|
|
|
|
|
|
if w1 is None:
|
|
|
|
|
dim = w1_b.shape[0]
|
|
|
|
|
w1 = torch.mm(w1_a.float(), w1_b.float())
|
|
|
|
|
|
|
|
|
|
if w2 is None:
|
|
|
|
|
dim = w2_b.shape[0]
|
|
|
|
|
if t2 is None:
|
|
|
|
|
w2 = torch.mm(w2_a.float(), w2_b.float())
|
|
|
|
|
else:
|
|
|
|
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
|
|
|
|
|
|
|
|
|
|
if len(w2.shape) == 4:
|
|
|
|
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
|
|
|
if v[2] is not None and dim is not None:
|
|
|
|
|
alpha *= v[2] / dim
|
|
|
|
|
|
|
|
|
|
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
else: #loha
|
|
|
|
|
w1a = v[0]
|
|
|
|
|
w1b = v[1]
|
|
|
|
|
if v[2] is not None:
|
|
|
|
|
alpha *= v[2] / w1b.shape[0]
|
|
|
|
|
w2a = v[3]
|
|
|
|
|
w2b = v[4]
|
|
|
|
|
if v[5] is not None: #cp decomposition
|
|
|
|
|
t1 = v[5]
|
|
|
|
|
t2 = v[6]
|
|
|
|
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
|
|
|
|
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
|
|
|
|
|
def calculate_weight(self, patches, weight, key):
|
|
|
|
|
for p in patches:
|
|
|
|
|
alpha = p[0]
|
|
|
|
|
v = p[1]
|
|
|
|
|
strength_model = p[2]
|
|
|
|
|
|
|
|
|
|
if strength_model != 1.0:
|
|
|
|
|
weight *= strength_model
|
|
|
|
|
|
|
|
|
|
if isinstance(v, list):
|
|
|
|
|
v = (self.calculate_weight(v[1:], v[0].clone(), key), )
|
|
|
|
|
|
|
|
|
|
if len(v) == 1:
|
|
|
|
|
w1 = v[0]
|
|
|
|
|
if w1.shape != weight.shape:
|
|
|
|
|
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
|
|
|
|
else:
|
|
|
|
|
weight += alpha * w1.type(weight.dtype).to(weight.device)
|
|
|
|
|
elif len(v) == 4: #lora/locon
|
|
|
|
|
mat1 = v[0]
|
|
|
|
|
mat2 = v[1]
|
|
|
|
|
if v[2] is not None:
|
|
|
|
|
alpha *= v[2] / mat2.shape[0]
|
|
|
|
|
if v[3] is not None:
|
|
|
|
|
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
|
|
|
|
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
|
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
elif len(v) == 8: #lokr
|
|
|
|
|
w1 = v[0]
|
|
|
|
|
w2 = v[1]
|
|
|
|
|
w1_a = v[3]
|
|
|
|
|
w1_b = v[4]
|
|
|
|
|
w2_a = v[5]
|
|
|
|
|
w2_b = v[6]
|
|
|
|
|
t2 = v[7]
|
|
|
|
|
dim = None
|
|
|
|
|
|
|
|
|
|
if w1 is None:
|
|
|
|
|
dim = w1_b.shape[0]
|
|
|
|
|
w1 = torch.mm(w1_a.float(), w1_b.float())
|
|
|
|
|
|
|
|
|
|
if w2 is None:
|
|
|
|
|
dim = w2_b.shape[0]
|
|
|
|
|
if t2 is None:
|
|
|
|
|
w2 = torch.mm(w2_a.float(), w2_b.float())
|
|
|
|
|
else:
|
|
|
|
|
m1 = torch.mm(w1a.float(), w1b.float())
|
|
|
|
|
m2 = torch.mm(w2a.float(), w2b.float())
|
|
|
|
|
w2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2_b.float(), w2_a.float())
|
|
|
|
|
|
|
|
|
|
if len(w2.shape) == 4:
|
|
|
|
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
|
|
|
|
if v[2] is not None and dim is not None:
|
|
|
|
|
alpha *= v[2] / dim
|
|
|
|
|
|
|
|
|
|
weight += alpha * torch.kron(w1.float(), w2.float()).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
else: #loha
|
|
|
|
|
w1a = v[0]
|
|
|
|
|
w1b = v[1]
|
|
|
|
|
if v[2] is not None:
|
|
|
|
|
alpha *= v[2] / w1b.shape[0]
|
|
|
|
|
w2a = v[3]
|
|
|
|
|
w2b = v[4]
|
|
|
|
|
if v[5] is not None: #cp decomposition
|
|
|
|
|
t1 = v[5]
|
|
|
|
|
t2 = v[6]
|
|
|
|
|
m1 = torch.einsum('i j k l, j r, i p -> p r k l', t1.float(), w1b.float(), w1a.float())
|
|
|
|
|
m2 = torch.einsum('i j k l, j r, i p -> p r k l', t2.float(), w2b.float(), w2a.float())
|
|
|
|
|
else:
|
|
|
|
|
m1 = torch.mm(w1a.float(), w1b.float())
|
|
|
|
|
m2 = torch.mm(w2a.float(), w2b.float())
|
|
|
|
|
|
|
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
return weight
|
|
|
|
|
|
|
|
|
|
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
return self.model
|
|
|
|
|
def unpatch_model(self):
|
|
|
|
|
model_sd = self.model_state_dict()
|
|
|
|
|
keys = list(self.backup.keys())
|
|
|
|
|