From 4efa67fa1239b49bdbdb944ac1980a6a4730b5e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Feb 2023 10:38:08 -0500 Subject: [PATCH] Add ControlNet support. --- comfy/cldm/cldm.py | 286 ++++++++++++++++++ comfy/extra_samplers/uni_pc.py | 4 +- comfy/ldm/models/diffusion/ddpm.py | 18 +- .../modules/diffusionmodules/openaimodel.py | 10 +- comfy/model_management.py | 17 +- comfy/samplers.py | 121 ++++++-- comfy/sd.py | 76 +++++ comfy/utils.py | 18 ++ nodes.py | 93 ++++-- 9 files changed, 580 insertions(+), 63 deletions(-) create mode 100644 comfy/cldm/cldm.py create mode 100644 comfy/utils.py diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py new file mode 100644 index 0000000..c75830a --- /dev/null +++ b/comfy/cldm/cldm.py @@ -0,0 +1,286 @@ +#taken from: https://github.com/lllyasviel/ControlNet +#and modified + +import einops +import torch +import torch as th +import torch.nn as nn + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) + +from einops import rearrange, repeat +from torchvision.utils import make_grid +from ldm.modules.attention import SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import log_txt_as_img, exists, instantiate_from_config + + +class ControlledUnetModel(UNetModel): + #implemented in the ldm unet + pass + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.input_hint_block = TimestepEmbedSequential( + conv_nd(dims, hint_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if 
legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + guided_hint = self.input_hint_block(hint, emb, context) + + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index cfd7225..bcc7c0f 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -856,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, 
sampling_function, extra_args=None device = noise.device - if model.inner_model.parameterization == "v": + if model.parameterization == "v": model_type = "v" else: model_type = "noise" model_fn = model_wrapper( - model.inner_model.apply_model, + model.inner_model.inner_model.apply_model, sampling_function, ns, model_type=model_type, diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 074919d..efe20a3 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): if self.conditioning_key is None: - out = self.diffusion_model(x, t) + out = self.diffusion_model(x, t, control=control) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) + out = self.diffusion_model(xc, t, control=control) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. - out = self.scripted_diffusion_model(x, t, cc) + out = self.scripted_diffusion_model(x, t, cc, control=control) else: - out = self.diffusion_model(x, t, context=cc) + out = self.diffusion_model(x, t, context=cc, control=control) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) + out = self.diffusion_model(xc, t, context=cc, control=control) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) + out = self.diffusion_model(x, t, y=cc, control=control) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 764a34b..1769cc0 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. 
@@ -778,8 +778,14 @@ class UNetModel(nn.Module): h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) + if control is not None: + h += control.pop() + for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) + hsp = hs.pop() + if control is not None: + hsp += control.pop() + h = th.cat([h, hsp], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: diff --git a/comfy/model_management.py b/comfy/model_management.py index ff7cbeb..b8fd879 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s current_loaded_model = None - +current_gpu_controlnets = [] model_accelerated = False @@ -56,6 +56,7 @@ model_accelerated = False def unload_model(): global current_loaded_model global model_accelerated + global current_gpu_controlnets if current_loaded_model is not None: if model_accelerated: accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) @@ -64,6 +65,10 @@ def unload_model(): current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None + if len(current_gpu_controlnets) > 0: + for n in current_gpu_controlnets: + n.cpu() + current_gpu_controlnets = [] def load_model_gpu(model): @@ -95,6 +100,16 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model +def load_controlnet_gpu(models): + global current_gpu_controlnets + for m in current_gpu_controlnets: + if m not in models: + m.cpu() + + current_gpu_controlnets = [] + for m in models: + current_gpu_controlnets.append(m.cuda()) + def get_free_memory(): dev = torch.cuda.current_device() diff --git a/comfy/samplers.py b/comfy/samplers.py index 7e2e667..a5a3181 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module): uncond = self.inner_model(x, sigma, cond=uncond) return uncond + (cond - uncond) * cond_scale -def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None): - def get_area_and_mult(cond, x_in, cond_concat_in): + +#The main sampling function shared by all the samplers +#Returns predicted noise +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): + def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - min_sigma = 0.0 - max_sigma = 999.0 if 'area' in cond[1]: area = cond[1]['area'] if 'strength' in cond[1]: @@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] cropped.append(cr) conditionning['c_concat'] = torch.cat(cropped, dim=1) - return (input_x, mult, conditionning, area) + + control = None + if 'control' in cond[1]: + control = cond[1]['control'] + return (input_x, mult, conditionning, area, control) def cond_equal_size(c1, c2): + if c1 is c2: + return True if c1.keys() != c2.keys(): return False if 'c_crossattn' in c1: @@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c return False return True + def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + return cond_equal_size(c1[2], c2[2]) + def cond_cat(c_list): c_crossattn = [] c_concat = [] @@ -84,7 +102,7 @@ def 
sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c out['c_concat'] = [torch.cat(c_concat)] return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in): + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c to_run = [] for x in cond: - p = get_area_and_mult(x, x_in, cond_concat_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue to_run += [(p, COND)] for x in uncond: - p = get_area_and_mult(x, x_in, cond_concat_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue @@ -113,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c first_shape = first[0][0].shape to_batch_temp = [] for x in range(len(to_run)): - if to_run[x][0][0].shape == first_shape: - if cond_equal_size(to_run[x][0][2], first[0][2]): - to_batch_temp += [x] + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] @@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c c = [] cond_or_uncond = [] area = [] + control = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c c += [p[2]] area += [p[3]] cond_or_uncond += [o[1]] + control = p[4] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) c = cond_cat(c) - sigma_ = torch.cat([sigma] * batch_chunks) + timestep_ = torch.cat([timestep] * batch_chunks) - output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c['c_crossattn']) + + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x for o in range(batch_chunks): @@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) return uncond + (cond - uncond) * cond_scale -class CFGDenoiserComplex(torch.nn.Module): + +class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_v(self, x, t, cond, **kwargs): + return self.inner_model.apply_model(x, t, cond, **kwargs) + + +class CFGNoisePredictor(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + self.alphas_cumprod = model.alphas_cumprod + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + return out + + +class KSamplerX0Inpaint(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model @@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module): if denoise_mask is not None: latent_mask = 1. 
- denoise_mask x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask - out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat) + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) if denoise_mask is not None: out *= denoise_mask @@ -196,8 +237,6 @@ def simple_scheduler(model, steps): def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space - # the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE - # unfortunately that gives zero flexibility so I did things like this instead which hopefully works blank_image[:,0] *= 0.8223 blank_image[:,1] *= -0.6876 blank_image[:,2] *= 0.6364 @@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] + +def apply_control_net_to_equal_area(conds, uncond): + cond_cnets = [] + cond_other = [] + uncond_cnets = [] + uncond_other = [] + for t in range(len(conds)): + x = conds[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + cond_cnets.append(x[1]['control']) + else: + cond_other.append((x, t)) + for t in range(len(uncond)): + x = uncond[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + uncond_cnets.append(x[1]['control']) + else: + uncond_other.append((x, t)) + + if len(uncond_cnets) > 0: + return + + for x in range(len(cond_cnets)): + temp = uncond_other[x % len(uncond_other)] + o = temp[0] + if 'control' in o[1] and o[1]['control'] is not None: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond += [[o[0], n]] + else: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond[temp[1]] = [o[0], n] + class KSampler: SCHEDULERS = ["karras", "normal", "simple"] SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", @@ -242,11 +317,13 @@ class KSampler: def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): self.model = model + self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": - self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True) + self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) else: - self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True) - self.model_k = CFGDenoiserComplex(self.model_wrap) + self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) + self.model_wrap.parameterization = self.model.parameterization + self.model_k = KSamplerX0Inpaint(self.model_wrap) self.device = device if scheduler not in self.SCHEDULERS: scheduler = self.SCHEDULERS[0] @@ -316,6 +393,8 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) + apply_control_net_to_equal_area(positive, negative) + if self.model.model.diffusion_model.dtype == torch.float16: precision_scope = torch.autocast else: diff --git a/comfy/sd.py b/comfy/sd.py index a3c0066..d37e531 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,6 +6,9 @@ import model_management from ldm.util import instantiate_from_config from ldm.models.autoencoder import AutoencoderKL from omegaconf import OmegaConf +from .cldm import cldm + +from . 
import utils def load_torch_file(ckpt): if ckpt.lower().endswith(".safetensors"): @@ -323,6 +326,79 @@ class VAE: samples = samples.cpu() return samples +class ControlNet: + def __init__(self, control_model): + self.control_model = control_model + self.cond_hint_original = None + self.cond_hint = None + + def get_control(self, x_noisy, t, cond_txt): + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) + print("set cond_hint", self.cond_hint.shape) + control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + return control + + def set_cond_hint(self, cond_hint): + self.cond_hint_original = cond_hint + return self + + def cleanup(self): + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + + def copy(self): + c = ControlNet(self.control_model) + c.cond_hint_original = self.cond_hint_original + return c + +def load_controlnet(ckpt_path): + controlnet_data = load_torch_file(ckpt_path) + pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + pth = False + sd2 = False + key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + elif key in controlnet_data: + pass + else: + print("error checkpoint does not contain controlnet data", ckpt_path) + return None + + context_dim = controlnet_data[key].shape[1] + control_model = cldm.ControlNet(image_size=32, + in_channels=4, + hint_channels=3, + model_channels=320, + attention_resolutions=[ 4, 2, 1 ], + num_res_blocks=2, + channel_mult=[ 1, 2, 4, 4 ], + num_heads=8, + use_spatial_transformer=True, + transformer_depth=1, + context_dim=context_dim, + use_checkpoint=True, + legacy=False) + + if pth: + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + w.load_state_dict(controlnet_data, strict=False) + else: + control_model.load_state_dict(controlnet_data, strict=False) + + control = ControlNet(control_model) + return control + + def load_clip(ckpt_path, embedding_directory=None): clip_data = load_torch_file(ckpt_path) config = {} diff --git a/comfy/utils.py b/comfy/utils.py new file mode 100644 index 0000000..815e899 --- /dev/null +++ b/comfy/utils.py @@ -0,0 +1,18 @@ +import torch + +def common_upscale(samples, width, height, upscale_method, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/nodes.py b/nodes.py index 5784ba3..9aec923 100644 --- a/nodes.py +++ b/nodes.py @@ -15,10 +15,12 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy")) import comfy.samplers import comfy.sd +import comfy.utils + import model_management -supported_ckpt_extensions = ['.ckpt'] -supported_pt_extensions = ['.ckpt', '.pt', '.bin'] +supported_ckpt_extensions = ['.ckpt', 
'.pth'] +supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: import safetensors.torch supported_ckpt_extensions += ['.safetensors'] @@ -77,12 +79,14 @@ class ConditioningSetArea: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): - c = copy.deepcopy(conditioning) - for t in c: - t[1]['area'] = (height // 8, width // 8, y // 8, x // 8) - t[1]['strength'] = strength - t[1]['min_sigma'] = min_sigma - t[1]['max_sigma'] = max_sigma + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) return (c, ) class VAEDecode: @@ -134,7 +138,6 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask): - print(pixels.shape, mask.shape) x = (pixels.shape[1] // 64) * 64 y = (pixels.shape[2] // 64) * 64 if pixels.shape[1] != x or pixels.shape[2] != y: @@ -144,7 +147,6 @@ class VAEEncodeForInpaint: #shave off a few pixels to keep things seamless kernel_tensor = torch.ones((1, 1, 6, 6)) mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1) - print(mask_erosion.shape, pixels.shape) for i in range(3): pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round() @@ -211,6 +213,44 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) +class ControlNetLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + controlnet_dir = os.path.join(models_dir, "controlnet") + @classmethod + def INPUT_TYPES(s): + return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_controlnet" + + CATEGORY = "loaders" + + def load_controlnet(self, control_net_name): + controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet = comfy.sd.load_controlnet(controlnet_path) + return (controlnet,) + + +class ControlNetApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), "control_net": ("CONTROL_NET", ), "image": ("IMAGE", )}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning" + + def apply_controlnet(self, conditioning, control_net, image): + c = [] + control_hint = image.movedim(-1,1) + print(control_hint.shape) + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['control'] = control_net.copy().set_cond_hint(control_hint) + c.append(n) + return (c, ) + + class CLIPLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") clip_dir = os.path.join(models_dir, "clip") @@ -248,22 +288,7 @@ class EmptyLatentImage: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return ({"samples":latent}, ) -def common_upscale(samples, width, height, upscale_method, crop): - if crop == "center": - old_width = samples.shape[3] - old_height = samples.shape[2] - old_aspect = old_width / old_height - new_aspect = width / height - x = 0 - y = 0 - if old_aspect > new_aspect: - x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) - elif old_aspect < new_aspect: - y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples[:,:,y:old_height-y,x:old_width-x] - else: - s = samples - return torch.nn.functional.interpolate(s, size=(height, 
width), mode=upscale_method) + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -282,7 +307,7 @@ class LatentUpscale: def upscale(self, samples, upscale_method, width, height, crop): s = samples.copy() - s["samples"] = common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -461,19 +486,26 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po positive_copy = [] negative_copy = [] + control_nets = [] for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 'control' in p[1]: + control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 'control' in p[1]: + control_nets += [p[1]['control']] negative_copy += [[t] + n[1:]] + model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) + if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) else: @@ -482,6 +514,9 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() + for c in control_nets: + c.cleanup() + out = latent.copy() out["samples"] = samples return (out, ) @@ -676,7 +711,7 @@ class ImageScale: def upscale(self, image, upscale_method, width, height, crop): samples = image.movedim(-1,1) - s = common_upscale(samples, width, height, upscale_method, crop) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop) s = s.movedim(1,-1) return (s,) @@ -704,6 +739,8 @@ NODE_CLASS_MAPPINGS = { "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "ControlNetApply": ControlNetApply, + "ControlNetLoader": ControlNetLoader, }
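
A minimal sketch of how the pieces added by this patch fit together, assuming a hypothetical checkpoint path and placeholder tensors (in practice the ControlNetLoader and ControlNetApply nodes do this wiring inside the node graph):

import torch
import comfy.sd

# Load the ControlNet wrapper; the directory and filename here are illustrative only.
control_net = comfy.sd.load_controlnet("models/controlnet/control_canny.pth")

# IMAGE tensors in the node graph are NHWC in [0, 1]; move to NCHW as ControlNetApply does.
image = torch.rand(1, 512, 512, 3)          # placeholder hint image (e.g. an edge map)
control_hint = image.movedim(-1, 1)

# CONDITIONING is a list of [cond_tensor, options] pairs, e.g. from a CLIP text encode node.
positive = [[torch.zeros(1, 77, 768), {}]]  # placeholder conditioning

conditioned = []
for cond, opts in positive:
    n = [cond, opts.copy()]
    # Each entry gets its own ControlNet copy with the hint attached. During sampling,
    # sampling_function calls n[1]['control'].get_control(x, t, context), which upscales
    # the hint to 8x the latent resolution, runs the control model, and the UNet then adds
    # the returned residuals to its middle block output and to each skip connection.
    n[1]['control'] = control_net.copy().set_cond_hint(control_hint)
    conditioned.append(n)

The residual list returned by ControlNet.forward is ordered input blocks first and middle block last, which matches the pop() order in UNetModel.forward (middle block residual consumed first, then the skip connections in reverse). common_ksampler moves the underlying control models to the GPU with model_management.load_controlnet_gpu() before sampling and calls cleanup() on each ControlNet afterwards to drop the cached, upscaled hint.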