|
|
|
@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
|
|
|
|
|
return model_options
|
|
|
|
|
|
|
|
|
|
class ModelPatcher:
|
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
|
|
|
|
|
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
|
|
|
|
self.size = size
|
|
|
|
|
self.model = model
|
|
|
|
|
if not hasattr(self.model, 'device'):
|
|
|
|
|
logging.info("Model doesn't have a device attribute.")
|
|
|
|
|
self.model.device = offload_device
|
|
|
|
|
elif self.model.device is None:
|
|
|
|
|
self.model.device = offload_device
|
|
|
|
|
|
|
|
|
|
self.patches = {}
|
|
|
|
|
self.backup = {}
|
|
|
|
|
self.object_patches = {}
|
|
|
|
@ -75,11 +81,6 @@ class ModelPatcher:
|
|
|
|
|
self.model_size()
|
|
|
|
|
self.load_device = load_device
|
|
|
|
|
self.offload_device = offload_device
|
|
|
|
|
if current_device is None:
|
|
|
|
|
self.current_device = self.offload_device
|
|
|
|
|
else:
|
|
|
|
|
self.current_device = current_device
|
|
|
|
|
|
|
|
|
|
self.weight_inplace_update = weight_inplace_update
|
|
|
|
|
self.model_lowvram = False
|
|
|
|
|
self.lowvram_patch_counter = 0
|
|
|
|
@ -92,7 +93,7 @@ class ModelPatcher:
|
|
|
|
|
return self.size
|
|
|
|
|
|
|
|
|
|
def clone(self):
|
|
|
|
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
|
|
|
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
|
|
|
|
n.patches = {}
|
|
|
|
|
for k in self.patches:
|
|
|
|
|
n.patches[k] = self.patches[k][:]
|
|
|
|
@ -302,7 +303,7 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
if device_to is not None:
|
|
|
|
|
self.model.to(device_to)
|
|
|
|
|
self.current_device = device_to
|
|
|
|
|
self.model.device = device_to
|
|
|
|
|
|
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
@ -355,6 +356,7 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
self.model_lowvram = True
|
|
|
|
|
self.lowvram_patch_counter = patch_counter
|
|
|
|
|
self.model.device = device_to
|
|
|
|
|
return self.model
|
|
|
|
|
|
|
|
|
|
def calculate_weight(self, patches, weight, key):
|
|
|
|
@ -551,10 +553,13 @@ class ModelPatcher:
|
|
|
|
|
|
|
|
|
|
if device_to is not None:
|
|
|
|
|
self.model.to(device_to)
|
|
|
|
|
self.current_device = device_to
|
|
|
|
|
self.model.device = device_to
|
|
|
|
|
|
|
|
|
|
keys = list(self.object_patches_backup.keys())
|
|
|
|
|
for k in keys:
|
|
|
|
|
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
|
|
|
|
|
|
|
|
|
self.object_patches_backup.clear()
|
|
|
|
|
|
|
|
|
|
def current_loaded_device(self):
|
|
|
|
|
return self.model.device
|
|
|
|
|