|
|
|
@ -129,12 +129,17 @@ def load_lora(path, to_load):
|
|
|
|
|
A_name = "{}.lora_up.weight".format(x)
|
|
|
|
|
B_name = "{}.lora_down.weight".format(x)
|
|
|
|
|
alpha_name = "{}.alpha".format(x)
|
|
|
|
|
mid_name = "{}.lora_mid.weight".format(x)
|
|
|
|
|
if A_name in lora.keys():
|
|
|
|
|
alpha = None
|
|
|
|
|
if alpha_name in lora.keys():
|
|
|
|
|
alpha = lora[alpha_name].item()
|
|
|
|
|
loaded_keys.add(alpha_name)
|
|
|
|
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha)
|
|
|
|
|
mid = None
|
|
|
|
|
if mid_name in lora.keys():
|
|
|
|
|
mid = lora[mid_name]
|
|
|
|
|
loaded_keys.add(mid_name)
|
|
|
|
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
|
|
|
|
loaded_keys.add(A_name)
|
|
|
|
|
loaded_keys.add(B_name)
|
|
|
|
|
for x in lora.keys():
|
|
|
|
@ -279,6 +284,10 @@ class ModelPatcher:
|
|
|
|
|
mat2 = v[1]
|
|
|
|
|
if v[2] is not None:
|
|
|
|
|
alpha *= v[2] / mat2.shape[0]
|
|
|
|
|
if v[3] is not None:
|
|
|
|
|
# LoCon mid weights: fold v[3] into mat2 via a matrix product before the final update. NOTE: this math has not been properly tested.
|
|
|
|
|
final_shape = [mat2.shape[1], mat2.shape[0], v[3].shape[2], v[3].shape[3]]
|
|
|
|
|
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1).float(), v[3].transpose(0, 1).flatten(start_dim=1).float()).reshape(final_shape).transpose(0, 1)
|
|
|
|
|
weight += (alpha * torch.mm(mat1.flatten(start_dim=1).float(), mat2.flatten(start_dim=1).float())).reshape(weight.shape).type(weight.dtype).to(weight.device)
|
|
|
|
|
return self.model
|
|
|
|
|
def unpatch_model(self):
|
|
|
|
|