|
|
|
@ -138,11 +138,13 @@ class ControlBase:
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
class ControlNet(ControlBase):
|
|
|
|
|
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
|
|
|
|
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
|
|
|
|
super().__init__(device)
|
|
|
|
|
self.control_model = control_model
|
|
|
|
|
self.load_device = load_device
|
|
|
|
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
|
|
|
|
if control_model is not None:
|
|
|
|
|
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
|
|
|
|
|
|
|
|
|
self.global_average_pooling = global_average_pooling
|
|
|
|
|
self.model_sampling_current = None
|
|
|
|
|
self.manual_cast_dtype = manual_cast_dtype
|
|
|
|
@ -183,7 +185,9 @@ class ControlNet(ControlBase):
|
|
|
|
|
return self.control_merge(None, control, control_prev, output_dtype)
|
|
|
|
|
|
|
|
|
|
def copy(self):
|
|
|
|
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
|
|
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
|
|
|
c.control_model = self.control_model
|
|
|
|
|
c.control_model_wrapped = self.control_model_wrapped
|
|
|
|
|
self.copy_to(c)
|
|
|
|
|
return c
|
|
|
|
|
|
|
|
|
|