|
|
|
@ -1,4 +1,24 @@
|
|
|
|
|
"""
|
|
|
|
|
This file is part of ComfyUI.
|
|
|
|
|
Copyright (C) 2024 Comfy
|
|
|
|
|
|
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
|
|
|
it under the terms of the GNU General Public License as published by
|
|
|
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
|
|
|
(at your option) any later version.
|
|
|
|
|
|
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
|
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
|
|
GNU General Public License for more details.
|
|
|
|
|
|
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
|
|
|
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from enum import Enum
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
|
|
|
|
|
else:
|
|
|
|
|
return torch.cat([tensor] * batched_number, dim=0)
|
|
|
|
|
|
|
|
|
|
class StrengthType(Enum):
|
|
|
|
|
CONSTANT = 1
|
|
|
|
|
LINEAR_UP = 2
|
|
|
|
|
|
|
|
|
|
class ControlBase:
|
|
|
|
|
def __init__(self, device=None):
|
|
|
|
|
self.cond_hint_original = None
|
|
|
|
@ -51,6 +75,8 @@ class ControlBase:
|
|
|
|
|
device = comfy.model_management.get_torch_device()
|
|
|
|
|
self.device = device
|
|
|
|
|
self.previous_controlnet = None
|
|
|
|
|
self.extra_conds = []
|
|
|
|
|
self.strength_type = StrengthType.CONSTANT
|
|
|
|
|
|
|
|
|
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
|
|
|
|
|
self.cond_hint_original = cond_hint
|
|
|
|
@ -93,6 +119,8 @@ class ControlBase:
|
|
|
|
|
c.latent_format = self.latent_format
|
|
|
|
|
c.extra_args = self.extra_args.copy()
|
|
|
|
|
c.vae = self.vae
|
|
|
|
|
c.extra_conds = self.extra_conds.copy()
|
|
|
|
|
c.strength_type = self.strength_type
|
|
|
|
|
|
|
|
|
|
def inference_memory_requirements(self, dtype):
|
|
|
|
|
if self.previous_controlnet is not None:
|
|
|
|
@ -113,7 +141,10 @@ class ControlBase:
|
|
|
|
|
|
|
|
|
|
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
|
|
|
|
|
applied_to.add(x)
|
|
|
|
|
x *= self.strength
|
|
|
|
|
if self.strength_type == StrengthType.CONSTANT:
|
|
|
|
|
x *= self.strength
|
|
|
|
|
elif self.strength_type == StrengthType.LINEAR_UP:
|
|
|
|
|
x *= (self.strength ** float(len(control_output) - i))
|
|
|
|
|
|
|
|
|
|
if x.dtype != output_dtype:
|
|
|
|
|
x = x.to(output_dtype)
|
|
|
|
@ -142,7 +173,7 @@ class ControlBase:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlNet(ControlBase):
|
|
|
|
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
|
|
|
|
|
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=[], strength_type=StrengthType.CONSTANT):
|
|
|
|
|
super().__init__(device)
|
|
|
|
|
self.control_model = control_model
|
|
|
|
|
self.load_device = load_device
|
|
|
|
@ -154,6 +185,8 @@ class ControlNet(ControlBase):
|
|
|
|
|
self.model_sampling_current = None
|
|
|
|
|
self.manual_cast_dtype = manual_cast_dtype
|
|
|
|
|
self.latent_format = latent_format
|
|
|
|
|
self.extra_conds += extra_conds
|
|
|
|
|
self.strength_type = strength_type
|
|
|
|
|
|
|
|
|
|
def get_control(self, x_noisy, t, cond, batched_number):
|
|
|
|
|
control_prev = None
|
|
|
|
@ -192,7 +225,7 @@ class ControlNet(ControlBase):
|
|
|
|
|
|
|
|
|
|
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
|
|
|
|
extra = self.extra_args.copy()
|
|
|
|
|
for c in ["y", "guidance"]: #TODO
|
|
|
|
|
for c in self.extra_conds:
|
|
|
|
|
temp = cond.get(c, None)
|
|
|
|
|
if temp is not None:
|
|
|
|
|
extra[c] = temp.to(dtype)
|
|
|
|
@ -382,116 +415,22 @@ def load_controlnet_mmdit(sd):
|
|
|
|
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
|
|
|
|
return control
|
|
|
|
|
|
|
|
|
|
class ControlNetWarperHunyuanDiT(ControlNet):
|
|
|
|
|
def get_control(self, x_noisy, t, cond, batched_number):
|
|
|
|
|
control_prev = None
|
|
|
|
|
if self.previous_controlnet is not None:
|
|
|
|
|
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
|
|
|
|
|
|
|
|
|
if self.timestep_range is not None:
|
|
|
|
|
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
|
|
|
|
if control_prev is not None:
|
|
|
|
|
return control_prev
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
def load_controlnet_hunyuandit(controlnet_data):
|
|
|
|
|
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
|
|
|
|
|
|
|
|
|
dtype = self.control_model.dtype
|
|
|
|
|
if self.manual_cast_dtype is not None:
|
|
|
|
|
dtype = self.manual_cast_dtype
|
|
|
|
|
|
|
|
|
|
output_dtype = x_noisy.dtype
|
|
|
|
|
if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
|
|
|
|
|
if self.cond_hint is not None:
|
|
|
|
|
del self.cond_hint
|
|
|
|
|
self.cond_hint = None
|
|
|
|
|
compression_ratio = self.compression_ratio
|
|
|
|
|
if self.vae is not None:
|
|
|
|
|
compression_ratio *= self.vae.downscale_ratio
|
|
|
|
|
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
|
|
|
|
if self.vae is not None:
|
|
|
|
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
|
|
|
|
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
|
|
|
|
|
comfy.model_management.load_models_gpu(loaded_models)
|
|
|
|
|
if self.latent_format is not None:
|
|
|
|
|
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
|
|
|
|
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
|
|
|
|
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
|
|
|
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
|
|
|
|
|
|
|
|
|
def get_tensor(name):
|
|
|
|
|
if name in cond:
|
|
|
|
|
if isinstance(cond[name], torch.Tensor):
|
|
|
|
|
return cond[name].to(dtype)
|
|
|
|
|
else:
|
|
|
|
|
return cond[name]
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
encoder_hidden_states = get_tensor('c_crossattn')
|
|
|
|
|
text_embedding_mask = get_tensor('text_embedding_mask')
|
|
|
|
|
encoder_hidden_states_t5 = get_tensor('encoder_hidden_states_t5')
|
|
|
|
|
text_embedding_mask_t5 = get_tensor('text_embedding_mask_t5')
|
|
|
|
|
image_meta_size = get_tensor('image_meta_size')
|
|
|
|
|
style = get_tensor('style')
|
|
|
|
|
cos_cis_img = get_tensor('cos_cis_img')
|
|
|
|
|
sin_cis_img = get_tensor('sin_cis_img')
|
|
|
|
|
|
|
|
|
|
timestep = self.model_sampling_current.timestep(t)
|
|
|
|
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
|
|
|
|
|
|
|
|
|
control = self.control_model(
|
|
|
|
|
x=x_noisy.to(dtype),
|
|
|
|
|
t=timestep.float(),
|
|
|
|
|
condition=self.cond_hint,
|
|
|
|
|
encoder_hidden_states=encoder_hidden_states,
|
|
|
|
|
text_embedding_mask=text_embedding_mask,
|
|
|
|
|
encoder_hidden_states_t5=encoder_hidden_states_t5,
|
|
|
|
|
text_embedding_mask_t5=text_embedding_mask_t5,
|
|
|
|
|
image_meta_size=image_meta_size,
|
|
|
|
|
style=style,
|
|
|
|
|
cos_cis_img=cos_cis_img,
|
|
|
|
|
sin_cis_img=sin_cis_img,
|
|
|
|
|
**self.extra_args
|
|
|
|
|
)
|
|
|
|
|
return self.control_merge(control, control_prev, output_dtype)
|
|
|
|
|
|
|
|
|
|
def copy(self):
|
|
|
|
|
c = ControlNetWarperHunyuanDiT(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
|
|
|
|
c.control_model = self.control_model
|
|
|
|
|
c.control_model_wrapped = self.control_model_wrapped
|
|
|
|
|
self.copy_to(c)
|
|
|
|
|
return c
|
|
|
|
|
|
|
|
|
|
def load_controlnet_hunyuandit(controlnet_data):
|
|
|
|
|
|
|
|
|
|
supported_inference_dtypes = [torch.float16, torch.float32]
|
|
|
|
|
|
|
|
|
|
unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
|
|
|
|
|
load_device = comfy.model_management.get_torch_device()
|
|
|
|
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
|
|
|
|
if manual_cast_dtype is not None:
|
|
|
|
|
operations = comfy.ops.manual_cast
|
|
|
|
|
else:
|
|
|
|
|
operations = comfy.ops.disable_weight_init
|
|
|
|
|
|
|
|
|
|
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
|
|
|
|
missing, unexpected = control_model.load_state_dict(controlnet_data)
|
|
|
|
|
|
|
|
|
|
if len(missing) > 0:
|
|
|
|
|
logging.warning("missing controlnet keys: {}".format(missing))
|
|
|
|
|
|
|
|
|
|
if len(unexpected) > 0:
|
|
|
|
|
logging.debug("unexpected controlnet keys: {}".format(unexpected))
|
|
|
|
|
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
|
|
|
|
|
|
|
|
|
latent_format = comfy.latent_formats.SDXL()
|
|
|
|
|
control = ControlNetWarperHunyuanDiT(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
|
|
|
|
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
|
|
|
|
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.LINEAR_UP)
|
|
|
|
|
return control
|
|
|
|
|
|
|
|
|
|
def load_controlnet(ckpt_path, model=None):
|
|
|
|
|
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
|
|
|
|
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
|
|
|
|
|
return load_controlnet_hunyuandit(controlnet_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "lora_controlnet" in controlnet_data:
|
|
|
|
|
return ControlLora(controlnet_data)
|
|
|
|
|
|
|
|
|
|