diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py new file mode 100644 index 0000000..c75830a --- /dev/null +++ b/comfy/cldm/cldm.py @@ -0,0 +1,286 @@ +#taken from: https://github.com/lllyasviel/ControlNet +#and modified + +import einops +import torch +import torch as th +import torch.nn as nn + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) + +from einops import rearrange, repeat +from torchvision.utils import make_grid +from ldm.modules.attention import SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import log_txt_as_img, exists, instantiate_from_config + + +class ControlledUnetModel(UNetModel): + #implemented in the ldm unet + pass + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.input_hint_block = TimestepEmbedSequential( + conv_nd(dims, hint_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)) + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + 
self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + guided_hint = self.input_hint_block(hint, emb, context) + + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index ae3a544..bcc7c0f 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -358,7 +358,10 @@ class UniPC: predict_x0=True, thresholding=False, max_val=1., - variant='bh1' + variant='bh1', + noise_mask=None, + masked_image=None, + noise=None, ): """Construct a UniPC. @@ -370,7 +373,10 @@ class UniPC: self.predict_x0 = predict_x0 self.thresholding = thresholding self.max_val = max_val - + self.noise_mask = noise_mask + self.masked_image = masked_image + self.noise = noise + def dynamic_thresholding_fn(self, x0, t=None): """ The dynamic thresholding method. @@ -386,7 +392,10 @@ class UniPC: """ Return the noise prediction model. """ - return self.model(x, t) + if self.noise_mask is not None: + return self.model(x, t) * self.noise_mask + else: + return self.model(x, t) def data_prediction_fn(self, x, t): """ @@ -401,6 +410,8 @@ class UniPC: s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s + if self.noise_mask is not None: + x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image return x0 def model_fn(self, x, t): @@ -713,6 +724,8 @@ class UniPC: assert timesteps.shape[0] - 1 == steps # with torch.no_grad(): for step_index in trange(steps): + if self.noise_mask is not None: + x = x * self.noise_mask + (1. 
- self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index])) if step_index == 0: vec_t = timesteps[0].expand((x.shape[0])) model_prev_list = [self.model_fn(x, vec_t)] @@ -820,7 +833,7 @@ def expand_dims(v, dims): -def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None): +def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None, callback=None, disable=None, noise_mask=None): to_zero = False if sigmas[-1] == 0: timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] @@ -843,13 +856,13 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None device = noise.device - if model.inner_model.parameterization == "v": + if model.parameterization == "v": model_type = "v" else: model_type = "noise" model_fn = model_wrapper( - model.inner_model.apply_model, + model.inner_model.inner_model.apply_model, sampling_function, ns, model_type=model_type, @@ -857,7 +870,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, extra_args=None model_kwargs=extra_args, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) if not to_zero: x /= ns.marginal_alpha(timesteps[-1]) diff --git a/comfy/ldm/models/diffusion/ddpm.py b/comfy/ldm/models/diffusion/ddpm.py index 074919d..efe20a3 100644 --- a/comfy/ldm/models/diffusion/ddpm.py +++ b/comfy/ldm/models/diffusion/ddpm.py @@ -1320,12 +1320,12 @@ class DiffusionWrapper(torch.nn.Module): self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] - def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, control=None): if self.conditioning_key is None: - out = self.diffusion_model(x, t) + out = self.diffusion_model(x, t, control=control) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) - out = self.diffusion_model(xc, t) + out = self.diffusion_model(xc, t, control=control) elif self.conditioning_key == 'crossattn': if not self.sequential_cross_attn: cc = torch.cat(c_crossattn, 1) @@ -1335,25 +1335,25 @@ class DiffusionWrapper(torch.nn.Module): # TorchScript changes names of the arguments # with argument cc defined as context=cc scripted model will produce # an error: RuntimeError: forward() is missing value for argument 'argument_3'. 
- out = self.scripted_diffusion_model(x, t, cc) + out = self.scripted_diffusion_model(x, t, cc, control=control) else: - out = self.diffusion_model(x, t, context=cc) + out = self.diffusion_model(x, t, context=cc, control=control) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc) + out = self.diffusion_model(xc, t, context=cc, control=control) elif self.conditioning_key == 'hybrid-adm': assert c_adm is not None xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(xc, t, context=cc, y=c_adm) + out = self.diffusion_model(xc, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'crossattn-adm': assert c_adm is not None cc = torch.cat(c_crossattn, 1) - out = self.diffusion_model(x, t, context=cc, y=c_adm) + out = self.diffusion_model(x, t, context=cc, y=c_adm, control=control) elif self.conditioning_key == 'adm': cc = c_crossattn[0] - out = self.diffusion_model(x, t, y=cc) + out = self.diffusion_model(x, t, y=cc, control=control) else: raise NotImplementedError() diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 764a34b..1769cc0 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -753,7 +753,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, control=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. 
@@ -778,8 +778,14 @@ class UNetModel(nn.Module): h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) + if control is not None: + h += control.pop() + for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) + hsp = hs.pop() + if control is not None: + hsp += control.pop() + h = th.cat([h, hsp], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: diff --git a/comfy/model_management.py b/comfy/model_management.py index ff7cbeb..b8fd879 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s current_loaded_model = None - +current_gpu_controlnets = [] model_accelerated = False @@ -56,6 +56,7 @@ model_accelerated = False def unload_model(): global current_loaded_model global model_accelerated + global current_gpu_controlnets if current_loaded_model is not None: if model_accelerated: accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) @@ -64,6 +65,10 @@ def unload_model(): current_loaded_model.model.cpu() current_loaded_model.unpatch_model() current_loaded_model = None + if len(current_gpu_controlnets) > 0: + for n in current_gpu_controlnets: + n.cpu() + current_gpu_controlnets = [] def load_model_gpu(model): @@ -95,6 +100,16 @@ def load_model_gpu(model): model_accelerated = True return current_loaded_model +def load_controlnet_gpu(models): + global current_gpu_controlnets + for m in current_gpu_controlnets: + if m not in models: + m.cpu() + + current_gpu_controlnets = [] + for m in models: + current_gpu_controlnets.append(m.cuda()) + def get_free_memory(): dev = torch.cuda.current_device() diff --git a/comfy/samplers.py b/comfy/samplers.py index 7f6dc97..a5a3181 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module): uncond = self.inner_model(x, sigma, cond=uncond) return uncond + (cond - uncond) * cond_scale -def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): - def get_area_and_mult(cond, x_in): + +#The main sampling function shared by all the samplers +#Returns predicted noise +def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None): + def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 - min_sigma = 0.0 - max_sigma = 999.0 if 'area' in cond[1]: area = cond[1]['area'] if 'strength' in cond[1]: @@ -48,9 +49,60 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): if (area[1] + area[3]) < x_in.shape[3]: for t in range(rr): mult[:,:,:,area[1] + area[3] - 1 - t:area[1] + area[3] - t] *= ((1.0/rr) * (t + 1)) - return (input_x, mult, cond[0], area) - - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area): + conditionning = {} + conditionning['c_crossattn'] = cond[0] + if cond_concat_in is not None and len(cond_concat_in) > 0: + cropped = [] + for x in cond_concat_in: + cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + cropped.append(cr) + conditionning['c_concat'] = torch.cat(cropped, dim=1) + + control = None + if 'control' in cond[1]: + control = cond[1]['control'] + return (input_x, mult, conditionning, area, control) + + def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + if 'c_crossattn' in c1: + if c1['c_crossattn'].shape != c2['c_crossattn'].shape: + return False + if 'c_concat' 
in c1: + if c1['c_concat'].shape != c2['c_concat'].shape: + return False + return True + + def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + return cond_equal_size(c1[2], c2[2]) + + def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + for x in c_list: + if 'c_crossattn' in x: + c_crossattn.append(x['c_crossattn']) + if 'c_concat' in x: + c_concat.append(x['c_concat']) + out = {} + if len(c_crossattn) > 0: + out['c_crossattn'] = [torch.cat(c_crossattn)] + if len(c_concat) > 0: + out['c_concat'] = [torch.cat(c_concat)] + return out + + def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in)/100000.0 @@ -62,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): to_run = [] for x in cond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue to_run += [(p, COND)] for x in uncond: - p = get_area_and_mult(x, x_in) + p = get_area_and_mult(x, x_in, cond_concat_in, timestep) if p is None: continue @@ -79,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): first_shape = first[0][0].shape to_batch_temp = [] for x in range(len(to_run)): - if to_run[x][0][0].shape == first_shape: - if to_run[x][0][2].shape == first[0][2].shape: - to_batch_temp += [x] + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] to_batch_temp.reverse() to_batch = to_batch_temp[:1] @@ -97,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): c = [] cond_or_uncond = [] area = [] + control = None for x in to_batch: o = to_run.pop(x) p = o[0] @@ -105,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): c += [p[2]] area += [p[3]] cond_or_uncond += [o[1]] + control = p[4] batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x) - c = torch.cat(c) - sigma_ = torch.cat([sigma] * batch_chunks) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) - output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks) + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c['c_crossattn']) + + output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) del input_x for o in range(batch_chunks): @@ -132,15 +188,43 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale): max_total_area = model_management.maximum_batch_area() - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area) + cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat) return uncond + (cond - uncond) * cond_scale -class CFGDenoiserComplex(torch.nn.Module): + +class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser): + def __init__(self, model, quantize=False, device='cpu'): + super().__init__(model, model.alphas_cumprod, quantize=quantize) + + def get_v(self, x, t, cond, **kwargs): + return self.inner_model.apply_model(x, t, cond, **kwargs) + + +class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): super().__init__() self.inner_model = model - def forward(self, x, sigma, uncond, cond, cond_scale): - return sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale) + 
self.alphas_cumprod = model.alphas_cumprod + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None): + out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat) + return out + + +class KSamplerX0Inpaint(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None): + if denoise_mask is not None: + latent_mask = 1. - denoise_mask + x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask + out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat) + if denoise_mask is not None: + out *= denoise_mask + + if denoise_mask is not None: + out += self.latent_image * latent_mask + return out def simple_scheduler(model, steps): sigs = [] @@ -150,6 +234,15 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) +def blank_inpaint_image_like(latent_image): + blank_image = torch.ones_like(latent_image) + # these are the values for "zero" in pixel space translated to latent space + blank_image[:,0] *= 0.8223 + blank_image[:,1] *= -0.6876 + blank_image[:,2] *= 0.6364 + blank_image[:,3] *= 0.1380 + return blank_image + def create_cond_with_same_area_if_none(conds, c): if 'area' not in c[1]: return @@ -180,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c): n = c[1].copy() conds += [[smallest[0], n]] + +def apply_control_net_to_equal_area(conds, uncond): + cond_cnets = [] + cond_other = [] + uncond_cnets = [] + uncond_other = [] + for t in range(len(conds)): + x = conds[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + cond_cnets.append(x[1]['control']) + else: + cond_other.append((x, t)) + for t in range(len(uncond)): + x = uncond[t] + if 'area' not in x[1]: + if 'control' in x[1] and x[1]['control'] is not None: + uncond_cnets.append(x[1]['control']) + else: + uncond_other.append((x, t)) + + if len(uncond_cnets) > 0: + return + + for x in range(len(cond_cnets)): + temp = uncond_other[x % len(uncond_other)] + o = temp[0] + if 'control' in o[1] and o[1]['control'] is not None: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond += [[o[0], n]] + else: + n = o[1].copy() + n['control'] = cond_cnets[x] + uncond[temp[1]] = [o[0], n] + class KSampler: SCHEDULERS = ["karras", "normal", "simple"] SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", @@ -188,11 +317,13 @@ class KSampler: def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): self.model = model + self.model_denoise = CFGNoisePredictor(self.model) if self.model.parameterization == "v": - self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True) + self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) else: - self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True) - self.model_k = CFGDenoiserComplex(self.model_wrap) + self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True) + self.model_wrap.parameterization = self.model.parameterization + self.model_k = KSamplerX0Inpaint(self.model_wrap) self.device = device if scheduler not in self.SCHEDULERS: scheduler = self.SCHEDULERS[0] @@ -200,8 +331,8 @@ class KSampler: sampler = self.SAMPLERS[0] self.scheduler = scheduler self.sampler = sampler - 
self.sigma_min=float(self.model_wrap.sigmas[0]) - self.sigma_max=float(self.model_wrap.sigmas[-1]) + self.sigma_min=float(self.model_wrap.sigma_min) + self.sigma_max=float(self.model_wrap.sigma_max) self.set_steps(steps, denoise) def _calculate_sigmas(self, steps): @@ -235,7 +366,7 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None): sigmas = self.sigmas sigma_min = self.sigma_min @@ -262,22 +393,47 @@ class KSampler: for c in negative: create_cond_with_same_area_if_none(positive, c) + apply_control_net_to_equal_area(positive, negative) + if self.model.model.diffusion_model.dtype == torch.float16: precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext + extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + + if hasattr(self.model, 'concat_keys'): + cond_concat = [] + for ck in self.model.concat_keys: + if denoise_mask is not None: + if ck == "mask": + cond_concat.append(denoise_mask[:,:1]) + elif ck == "masked_image": + cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space + else: + if ck == "mask": + cond_concat.append(torch.ones_like(noise)[:,:1]) + elif ck == "masked_image": + cond_concat.append(blank_inpaint_image_like(noise)) + extra_args["cond_concat"] = cond_concat + with precision_scope(self.device): if self.sampler == "uni_pc": - samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, extra_args=extra_args, noise_mask=denoise_mask) else: - noise *= sigmas[0] + extra_args["denoise_mask"] = denoise_mask + self.model_k.latent_image = latent_image + self.model_k.noise = noise + + noise = noise * sigmas[0] + if latent_image is not None: noise += latent_image if self.sampler == "sample_dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], self.steps, extra_args=extra_args) elif self.sampler == "sample_dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args) else: - samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args={"cond":positive, "uncond":negative, "cond_scale": cfg}) + samples = getattr(k_diffusion_sampling, self.sampler)(self.model_k, noise, sigmas, extra_args=extra_args) + return samples.to(torch.float32) diff --git a/comfy/sd.py b/comfy/sd.py index a3c0066..61a01de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,6 +6,9 @@ import model_management from ldm.util import instantiate_from_config from ldm.models.autoencoder import AutoencoderKL from omegaconf import OmegaConf +from .cldm import cldm + +from . 
import utils def load_torch_file(ckpt): if ckpt.lower().endswith(".safetensors"): @@ -323,6 +326,84 @@ class VAE: samples = samples.cpu() return samples +class ControlNet: + def __init__(self, control_model): + self.control_model = control_model + self.cond_hint_original = None + self.cond_hint = None + self.strength = 1.0 + + def get_control(self, x_noisy, t, cond_txt): + if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(x_noisy.device) + print("set cond_hint", self.cond_hint.shape) + control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=cond_txt) + for x in control: + x *= self.strength + return control + + def set_cond_hint(self, cond_hint, strength=1.0): + self.cond_hint_original = cond_hint + self.strength = strength + return self + + def cleanup(self): + if self.cond_hint is not None: + del self.cond_hint + self.cond_hint = None + + def copy(self): + c = ControlNet(self.control_model) + c.cond_hint_original = self.cond_hint_original + c.strength = self.strength + return c + +def load_controlnet(ckpt_path): + controlnet_data = load_torch_file(ckpt_path) + pth_key = 'control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + pth = False + sd2 = False + key = 'input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight' + if pth_key in controlnet_data: + pth = True + key = pth_key + elif key in controlnet_data: + pass + else: + print("error checkpoint does not contain controlnet data", ckpt_path) + return None + + context_dim = controlnet_data[key].shape[1] + control_model = cldm.ControlNet(image_size=32, + in_channels=4, + hint_channels=3, + model_channels=320, + attention_resolutions=[ 4, 2, 1 ], + num_res_blocks=2, + channel_mult=[ 1, 2, 4, 4 ], + num_heads=8, + use_spatial_transformer=True, + transformer_depth=1, + context_dim=context_dim, + use_checkpoint=True, + legacy=False) + + if pth: + class WeightsLoader(torch.nn.Module): + pass + w = WeightsLoader() + w.control_model = control_model + w.load_state_dict(controlnet_data, strict=False) + else: + control_model.load_state_dict(controlnet_data, strict=False) + + control = ControlNet(control_model) + return control + + def load_clip(ckpt_path, embedding_directory=None): clip_data = load_torch_file(ckpt_path) config = {} diff --git a/comfy/utils.py b/comfy/utils.py new file mode 100644 index 0000000..815e899 --- /dev/null +++ b/comfy/utils.py @@ -0,0 +1,18 @@ +import torch + +def common_upscale(samples, width, height, upscale_method, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) diff --git a/main.py b/main.py index c5f04aa..f5aec44 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,10 @@ import heapq import traceback import asyncio +if os.name == "nt": + import logging + logging.getLogger("xformers").addFilter(lambda record: 'A 
matching Triton is not available' not in record.getMessage()) + try: import aiohttp from aiohttp import web diff --git a/models/configs/v2-inpainting-inference.yaml b/models/configs/v2-inpainting-inference.yaml new file mode 100644 index 0000000..32a9471 --- /dev/null +++ b/models/configs/v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + 
disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/models/controlnet/put_controlnets_here b/models/controlnet/put_controlnets_here new file mode 100644 index 0000000..e69de29 diff --git a/nodes.py b/nodes.py index 65f64e8..d2fd3ff 100644 --- a/nodes.py +++ b/nodes.py @@ -15,11 +15,13 @@ sys.path.insert(0, os.path.join(sys.path[0], "comfy")) import comfy.samplers import comfy.sd +import comfy.utils + import model_management import importlib -supported_ckpt_extensions = ['.ckpt'] -supported_pt_extensions = ['.ckpt', '.pt', '.bin'] +supported_ckpt_extensions = ['.ckpt', '.pth'] +supported_pt_extensions = ['.ckpt', '.pt', '.bin', '.pth'] try: import safetensors.torch supported_ckpt_extensions += ['.safetensors'] @@ -78,12 +80,14 @@ class ConditioningSetArea: CATEGORY = "conditioning" def append(self, conditioning, width, height, x, y, strength, min_sigma=0.0, max_sigma=99.0): - c = copy.deepcopy(conditioning) - for t in c: - t[1]['area'] = (height // 8, width // 8, y // 8, x // 8) - t[1]['strength'] = strength - t[1]['min_sigma'] = min_sigma - t[1]['max_sigma'] = max_sigma + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['area'] = (height // 8, width // 8, y // 8, x // 8) + n[1]['strength'] = strength + n[1]['min_sigma'] = min_sigma + n[1]['max_sigma'] = max_sigma + c.append(n) return (c, ) class VAEDecode: @@ -99,7 +103,7 @@ class VAEDecode: CATEGORY = "latent" def decode(self, vae, samples): - return (vae.decode(samples), ) + return (vae.decode(samples["samples"]), ) class VAEEncode: def __init__(self, device="cpu"): @@ -118,7 +122,39 @@ class VAEEncode: y = (pixels.shape[2] // 64) * 64 if pixels.shape[1] != x or pixels.shape[2] != y: pixels = pixels[:,:x,:y,:] - return (vae.encode(pixels), ) + t = vae.encode(pixels[:,:,:,:3]) + + return ({"samples":t}, ) + +class VAEEncodeForInpaint: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "latent/inpaint" + + def encode(self, vae, pixels, mask): + x = (pixels.shape[1] // 64) * 64 + y = (pixels.shape[2] // 64) * 64 + if pixels.shape[1] != x or pixels.shape[2] != y: + pixels = pixels[:,:x,:y,:] + mask = mask[:x,:y] + + #shave off a few pixels to keep things seamless + kernel_tensor = torch.ones((1, 1, 6, 6)) + mask_erosion = torch.clamp(torch.nn.functional.conv2d((1.0 - mask.round())[None], kernel_tensor, padding=3), 0, 1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= mask_erosion[0][:x,:y].round() + pixels[:,:,:,i] += 0.5 + t = vae.encode(pixels) + + return ({"samples":t, "noise_mask": mask}, ) class CheckpointLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") @@ -178,6 +214,48 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) +class ControlNetLoader: + models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") + controlnet_dir = 
os.path.join(models_dir, "controlnet") + @classmethod + def INPUT_TYPES(s): + return {"required": { "control_net_name": (filter_files_extensions(recursive_search(s.controlnet_dir), supported_pt_extensions), )}} + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_controlnet" + + CATEGORY = "loaders" + + def load_controlnet(self, control_net_name): + controlnet_path = os.path.join(self.controlnet_dir, control_net_name) + controlnet = comfy.sd.load_controlnet(controlnet_path) + return (controlnet,) + + +class ControlNetApply: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_controlnet" + + CATEGORY = "conditioning" + + def apply_controlnet(self, conditioning, control_net, image, strength): + c = [] + control_hint = image.movedim(-1,1) + print(control_hint.shape) + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['control'] = control_net.copy().set_cond_hint(control_hint, strength) + c.append(n) + return (c, ) + + class CLIPLoader: models_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "models") clip_dir = os.path.join(models_dir, "clip") @@ -213,24 +291,9 @@ class EmptyLatentImage: def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8]) - return (latent, ) - -def common_upscale(samples, width, height, upscale_method, crop): - if crop == "center": - old_width = samples.shape[3] - old_height = samples.shape[2] - old_aspect = old_width / old_height - new_aspect = width / height - x = 0 - y = 0 - if old_aspect > new_aspect: - x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) - elif old_aspect < new_aspect: - y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples[:,:,y:old_height-y,x:old_width-x] - else: - s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + return ({"samples":latent}, ) + + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] @@ -248,7 +311,8 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): - s = common_upscale(samples, width // 8, height // 8, upscale_method, crop) + s = samples.copy() + s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -263,6 +327,7 @@ class LatentRotate: CATEGORY = "latent" def rotate(self, samples, rotation): + s = samples.copy() rotate_by = 0 if rotation.startswith("90"): rotate_by = 1 @@ -271,7 +336,7 @@ class LatentRotate: elif rotation.startswith("270"): rotate_by = 3 - s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) + s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2]) return (s,) class LatentFlip: @@ -286,12 +351,11 @@ class LatentFlip: CATEGORY = "latent" def flip(self, samples, flip_method): + s = samples.copy() if flip_method.startswith("x"): - s = torch.flip(samples, dims=[2]) + s["samples"] = torch.flip(samples["samples"], dims=[2]) elif flip_method.startswith("y"): - s = torch.flip(samples, dims=[3]) - else: - s = samples + s["samples"] = torch.flip(samples["samples"], dims=[3]) return (s,) @@ -313,12 +377,15 @@ class LatentComposite: x = x // 8 y = y // 8 feather = feather // 8 - s = samples_to.clone() + samples_out = 
samples_to.copy() + s = samples_to["samples"].clone() + samples_to = samples_to["samples"] + samples_from = samples_from["samples"] if feather == 0: s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] else: - s_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] - mask = torch.ones_like(s_from) + samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] + mask = torch.ones_like(samples_from) for t in range(feather): if y != 0: mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1)) @@ -331,7 +398,8 @@ class LatentComposite: mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1)) rev_mask = torch.ones_like(mask) - mask s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask - return (s,) + samples_out["samples"] = s + return (samples_out,) class LatentCrop: @classmethod @@ -348,6 +416,8 @@ class LatentCrop: CATEGORY = "latent" def crop(self, samples, width, height, x, y): + s = samples.copy() + samples = samples['samples'] x = x // 8 y = y // 8 @@ -371,15 +441,43 @@ class LatentCrop: #make sure size is always multiple of 64 x, to_x = enforce_image_dim(x, to_x, samples.shape[3]) y, to_y = enforce_image_dim(y, to_y, samples.shape[2]) - s = samples[:,:,y:to_y, x:to_x] + s['samples'] = samples[:,:,y:to_y, x:to_x] return (s,) -def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): +class SetLatentNoiseMask: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "mask": ("MASK",), + }} + RETURN_TYPES = ("LATENT",) + FUNCTION = "set_mask" + + CATEGORY = "latent/inpaint" + + def set_mask(self, samples, mask): + s = samples.copy() + s["noise_mask"] = mask + return (s,) + + +def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + latent_image = latent["samples"] + noise_mask = None + if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") else: noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + if "noise_mask" in latent: + noise_mask = latent['noise_mask'] + noise_mask = torch.nn.functional.interpolate(noise_mask[None,None,], size=(noise.shape[2], noise.shape[3]), mode="bilinear") + noise_mask = noise_mask.round() + noise_mask = torch.cat([noise_mask] * noise.shape[1], dim=1) + noise_mask = torch.cat([noise_mask] * noise.shape[0]) + noise_mask = noise_mask.to(device) + real_model = None if device != "cpu": model_management.load_model_gpu(model) @@ -393,29 +491,40 @@ def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, po positive_copy = [] negative_copy = [] + control_nets = [] for p in positive: t = p[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 'control' in p[1]: + control_nets += [p[1]['control']] positive_copy += [[t] + p[1:]] for n in negative: t = n[0] if t.shape[0] < noise.shape[0]: t = torch.cat([t] * noise.shape[0]) t = t.to(device) + if 
'control' in p[1]: + control_nets += [p[1]['control']] negative_copy += [[t] + n[1:]] + model_management.load_controlnet_gpu(list(map(lambda a: a.control_model, control_nets))) + if sampler_name in comfy.samplers.KSampler.SAMPLERS: sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) else: #other samplers pass - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask) samples = samples.cpu() + for c in control_nets: + c.cleanup() - return (samples, ) + out = latent.copy() + out["samples"] = samples + return (out, ) class KSampler: def __init__(self, device="cuda"): @@ -532,7 +641,7 @@ class LoadImage: @classmethod def INPUT_TYPES(s): return {"required": - {"image": (os.listdir(s.input_dir), )}, + {"image": (sorted(os.listdir(s.input_dir)), )}, } CATEGORY = "image" @@ -541,10 +650,11 @@ class LoadImage: FUNCTION = "load_image" def load_image(self, image): image_path = os.path.join(self.input_dir, image) - image = Image.open(image_path).convert("RGB") + i = Image.open(image_path) + image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image[None])[None,] - return image + image = torch.from_numpy(image)[None,] + return (image,) @classmethod def IS_CHANGED(s, image): @@ -554,6 +664,41 @@ class LoadImage: m.update(f.read()) return m.digest().hex() +class LoadImageMask: + input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + @classmethod + def INPUT_TYPES(s): + return {"required": + {"image": (os.listdir(s.input_dir), ), + "channel": (["alpha", "red", "green", "blue"], ),} + } + + CATEGORY = "image" + + RETURN_TYPES = ("MASK",) + FUNCTION = "load_image" + def load_image(self, image, channel): + image_path = os.path.join(self.input_dir, image) + i = Image.open(image_path) + mask = None + c = channel[0].upper() + if c in i.getbands(): + mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 + mask = torch.from_numpy(mask) + if c == 'A': + mask = 1. 
- mask + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + return (mask,) + + @classmethod + def IS_CHANGED(s, image, channel): + image_path = os.path.join(s.input_dir, image) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -571,7 +716,7 @@ class ImageScale: def upscale(self, image, upscale_method, width, height, crop): samples = image.movedim(-1,1) - s = common_upscale(samples, width, height, upscale_method, crop) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, crop) s = s.movedim(1,-1) return (s,) @@ -581,21 +726,26 @@ NODE_CLASS_MAPPINGS = { "CLIPTextEncode": CLIPTextEncode, "VAEDecode": VAEDecode, "VAEEncode": VAEEncode, + "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "LoadImage": LoadImage, + "LoadImageMask": LoadImageMask, "ImageScale": ImageScale, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "KSamplerAdvanced": KSamplerAdvanced, + "SetLatentNoiseMask": SetLatentNoiseMask, "LatentComposite": LatentComposite, "LatentRotate": LatentRotate, "LatentFlip": LatentFlip, "LatentCrop": LatentCrop, "LoraLoader": LoraLoader, "CLIPLoader": CLIPLoader, + "ControlNetApply": ControlNetApply, + "ControlNetLoader": ControlNetLoader, } CUSTOM_NODE_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "custom_nodes")
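
For context on the zero_module(conv_nd(...)) pattern used by make_zero_conv in the cldm.py hunk above, here is a minimal standalone sketch; the zeroed helper below is illustrative and not the ldm implementation. The point is that a freshly zero-initialized 1x1 convolution outputs all zeros, so the ControlNet branch adds nothing to the UNet features until trained weights are loaded.

import torch
import torch.nn as nn

def zeroed(module):
    # Stand-in for ldm's zero_module (assumption: not the original helper):
    # zero every parameter so the layer outputs exactly zero at initialization.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

# A 1x1 "zero conv" like the one ControlNet.make_zero_conv builds: freshly
# initialized, its output is all zeros, so adding it to the UNet features
# changes nothing until the ControlNet weights are trained/loaded.
zero_conv = zeroed(nn.Conv2d(320, 320, kernel_size=1, padding=0))
features = torch.randn(1, 320, 64, 64)
assert torch.count_nonzero(zero_conv(features)) == 0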
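
The ordering between ControlNet.forward's outs list and the control.pop() calls added to UNetModel.forward is easy to get wrong, so this toy sketch (dummy tensors only, not the real modules) shows how they line up: the list is built front-to-back, one residual per input block with the middle-block residual last, and the UNet drains it back-to-front.

import torch

# Toy tensors only; the real code passes ControlNet's `outs` list as `control`.
control = [torch.full((1, 4, 8, 8), float(i)) for i in range(5)]  # input-block residuals 0..4
control.append(torch.full((1, 4, 8, 8), 99.0))                    # middle_block_out residual

h = torch.zeros(1, 4, 8, 8)
h += control.pop()          # middle block receives the 99.0 residual first
order = []
while control:
    hsp = torch.zeros(1, 4, 8, 8)
    hsp += control.pop()    # each output-block skip gets its residual, deepest level first
    order.append(int(hsp[0, 0, 0, 0]))
print(order)                # [4, 3, 2, 1, 0]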
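
A minimal sketch of the noise_mask handling added to common_ksampler: the user-supplied (H, W) mask is resized to the latent resolution, rounded to a hard 0/1 mask, and broadcast to the shape of the noise tensor. The expand_noise_mask name is illustrative; the body mirrors the diff.

import torch

def expand_noise_mask(mask, noise):
    # Resize a (H, W) mask to the latent resolution, binarize it, then
    # broadcast it to noise's (B, C, h, w) shape, as common_ksampler does.
    m = torch.nn.functional.interpolate(mask[None, None], size=(noise.shape[2], noise.shape[3]), mode="bilinear")
    m = m.round()
    m = torch.cat([m] * noise.shape[1], dim=1)  # repeat over latent channels
    m = torch.cat([m] * noise.shape[0])         # repeat over the batch
    return m

noise = torch.randn(2, 4, 64, 64)
mask = torch.ones(512, 512)
print(expand_noise_mask(mask, noise).shape)     # torch.Size([2, 4, 64, 64])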
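
The inpainting blend performed by KSamplerX0Inpaint.forward (and, with alpha_t/sigma_t in place of sigma, by the UniPC noise_mask branch) pins the unmasked region to the re-noised original latent at every step so that only the masked region is actually denoised. A sketch, with an arbitrary example sigma:

import torch

def blend_known_region(x, denoise_mask, latent_image, noise, sigma):
    # Where denoise_mask == 0, the sampler state is overwritten with the original
    # latent re-noised to the current sigma (k-diffusion parameterization assumed).
    latent_mask = 1.0 - denoise_mask
    return x * denoise_mask + (latent_image + noise * sigma) * latent_mask

x = torch.randn(1, 4, 64, 64)
mask = torch.zeros(1, 4, 64, 64)
mask[..., 16:48, 16:48] = 1.0                   # denoise only the centre square
latent_image = torch.randn(1, 4, 64, 64)
noise = torch.randn(1, 4, 64, 64)
x = blend_known_region(x, mask, latent_image, noise, sigma=14.6)  # 14.6 is just an example value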
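
ControlNet.get_control expects the hint image to be exactly 8x the latent resolution (the SD VAE downscale factor) and rebuilds self.cond_hint whenever the latent shape changes. A rough sketch of that resize check, using plain F.interpolate in place of utils.common_upscale (assumption; "nearest-exact" needs a recent PyTorch, "nearest" works everywhere):

import torch
import torch.nn.functional as F

x_noisy = torch.randn(1, 4, 64, 96)   # latent being denoised
hint = torch.rand(1, 3, 480, 640)     # control image, NCHW in 0..1

target = (x_noisy.shape[2] * 8, x_noisy.shape[3] * 8)   # (512, 768)
if hint.shape[2] != target[0] or hint.shape[3] != target[1]:
    hint = F.interpolate(hint, size=target, mode="nearest-exact")
print(hint.shape)                                        # torch.Size([1, 3, 512, 768])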