|
|
@ -202,6 +202,14 @@ def model_lora_keys_unet(model, key_map={}):
|
|
|
|
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
|
|
|
key_map["lora_unet_{}".format(key_lora)] = "diffusion_model.{}".format(diffusers_keys[k])
|
|
|
|
return key_map
|
|
|
|
return key_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_attr(obj, attr, value):
|
|
|
|
|
|
|
|
attrs = attr.split(".")
|
|
|
|
|
|
|
|
for name in attrs[:-1]:
|
|
|
|
|
|
|
|
obj = getattr(obj, name)
|
|
|
|
|
|
|
|
prev = getattr(obj, attrs[-1])
|
|
|
|
|
|
|
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
|
|
|
|
|
|
|
del prev
|
|
|
|
|
|
|
|
|
|
|
|
class ModelPatcher:
|
|
|
|
class ModelPatcher:
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0):
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0):
|
|
|
|
self.size = size
|
|
|
|
self.size = size
|
|
|
@ -340,10 +348,11 @@ class ModelPatcher:
|
|
|
|
weight = model_sd[key]
|
|
|
|
weight = model_sd[key]
|
|
|
|
|
|
|
|
|
|
|
|
if key not in self.backup:
|
|
|
|
if key not in self.backup:
|
|
|
|
self.backup[key] = weight.to(self.offload_device, copy=True)
|
|
|
|
self.backup[key] = weight.to(self.offload_device)
|
|
|
|
|
|
|
|
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True)
|
|
|
|
temp_weight = weight.to(torch.float32, copy=True)
|
|
|
|
weight[:] = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
|
|
|
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
|
|
|
|
|
|
|
set_attr(self.model, key, out_weight)
|
|
|
|
del temp_weight
|
|
|
|
del temp_weight
|
|
|
|
return self.model
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
|
|
@ -439,13 +448,6 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
|
|
def unpatch_model(self):
|
|
|
|
def unpatch_model(self):
|
|
|
|
keys = list(self.backup.keys())
|
|
|
|
keys = list(self.backup.keys())
|
|
|
|
def set_attr(obj, attr, value):
|
|
|
|
|
|
|
|
attrs = attr.split(".")
|
|
|
|
|
|
|
|
for name in attrs[:-1]:
|
|
|
|
|
|
|
|
obj = getattr(obj, name)
|
|
|
|
|
|
|
|
prev = getattr(obj, attrs[-1])
|
|
|
|
|
|
|
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
|
|
|
|
|
|
|
del prev
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for k in keys:
|
|
|
|
for k in keys:
|
|
|
|
set_attr(self.model, k, self.backup[k])
|
|
|
|
set_attr(self.model, k, self.backup[k])
|
|
|
|