|
|
|
@ -89,27 +89,12 @@ class ControlBase:
|
|
|
|
|
return self.previous_controlnet.inference_memory_requirements(dtype)
|
|
|
|
|
return 0
|
|
|
|
|
|
|
|
|
|
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
|
|
|
|
def control_merge(self, control, control_prev, output_dtype):
|
|
|
|
|
out = {'input':[], 'middle':[], 'output': []}
|
|
|
|
|
|
|
|
|
|
if control_input is not None:
|
|
|
|
|
for i in range(len(control_input)):
|
|
|
|
|
key = 'input'
|
|
|
|
|
x = control_input[i]
|
|
|
|
|
if x is not None:
|
|
|
|
|
x *= self.strength
|
|
|
|
|
if x.dtype != output_dtype:
|
|
|
|
|
x = x.to(output_dtype)
|
|
|
|
|
out[key].insert(0, x)
|
|
|
|
|
|
|
|
|
|
if control_output is not None:
|
|
|
|
|
for key in control:
|
|
|
|
|
control_output = control[key]
|
|
|
|
|
for i in range(len(control_output)):
|
|
|
|
|
if i == (len(control_output) - 1):
|
|
|
|
|
key = 'middle'
|
|
|
|
|
index = 0
|
|
|
|
|
else:
|
|
|
|
|
key = 'output'
|
|
|
|
|
index = i
|
|
|
|
|
x = control_output[i]
|
|
|
|
|
if x is not None:
|
|
|
|
|
if self.global_average_pooling:
|
|
|
|
@ -120,6 +105,7 @@ class ControlBase:
|
|
|
|
|
x = x.to(output_dtype)
|
|
|
|
|
|
|
|
|
|
out[key].append(x)
|
|
|
|
|
|
|
|
|
|
if control_prev is not None:
|
|
|
|
|
for x in ['input', 'middle', 'output']:
|
|
|
|
|
o = out[x]
|
|
|
|
@ -182,7 +168,7 @@ class ControlNet(ControlBase):
|
|
|
|
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
|
|
|
|
|
|
|
|
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
|
|
|
|
return self.control_merge(None, control, control_prev, output_dtype)
|
|
|
|
|
return self.control_merge(control, control_prev, output_dtype)
|
|
|
|
|
|
|
|
|
|
def copy(self):
|
|
|
|
|
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
|
|
@ -490,12 +476,11 @@ class T2IAdapter(ControlBase):
|
|
|
|
|
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
|
|
|
|
self.t2i_model.cpu()
|
|
|
|
|
|
|
|
|
|
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
|
|
|
|
mid = None
|
|
|
|
|
if self.t2i_model.xl == True:
|
|
|
|
|
mid = control_input[-1:]
|
|
|
|
|
control_input = control_input[:-1]
|
|
|
|
|
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
|
|
|
|
control_input = {}
|
|
|
|
|
for k in self.control_input:
|
|
|
|
|
control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
|
|
|
|
|
|
|
|
|
|
return self.control_merge(control_input, control_prev, x_noisy.dtype)
|
|
|
|
|
|
|
|
|
|
def copy(self):
|
|
|
|
|
c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
|
|
|
|