From 69df7eba94edd8eee430374a380de15e54ab1f2c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 31 Jan 2023 03:09:38 -0500 Subject: [PATCH] Add KSamplerAdvanced node. This node exposes more sampling options and makes it possible for example to sample the first few steps on the latent image, do some operations on it and then do the rest of the sampling steps. This can be achieved using the start_at_step and end_at_step options. --- comfy/samplers.py | 15 +++++-- nodes.py | 105 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 86 insertions(+), 34 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index f4faa4c..84df795 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -168,15 +168,24 @@ class KSampler: self.sigmas = sigmas[-(steps + 1):] - def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None): + def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False): sigmas = self.sigmas sigma_min = self.sigma_min - if last_step is not None: + if last_step is not None and last_step < (len(sigmas) - 1): sigma_min = sigmas[last_step] sigmas = sigmas[:last_step + 1] + if force_full_denoise: + sigmas[-1] = 0 + if start_step is not None: - sigmas = sigmas[start_step:] + if start_step < (len(sigmas) - 1): + sigmas = sigmas[start_step:] + else: + if latent_image is not None: + return latent_image + else: + return torch.zeros_like(noise) noise *= sigmas[0] if latent_image is not None: diff --git a/nodes.py b/nodes.py index a55e8c9..974ebf5 100644 --- a/nodes.py +++ b/nodes.py @@ -221,13 +221,50 @@ class LatentRotate: s = torch.rot90(samples, k=rotate_by, dims=[3, 2]) return (s,) +def common_ksampler(device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") + + model = model.to(device) + noise = noise.to(device) + latent_image = latent_image.to(device) + + positive_copy = [] + negative_copy = [] + + for p in positive: + t = p[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + positive_copy += [[t] + p[1:]] + for n in negative: + t = n[0] + if t.shape[0] < noise.shape[0]: + t = torch.cat([t] * noise.shape[0]) + t = t.to(device) + negative_copy += [[t] + n[1:]] + + if sampler_name in comfy.samplers.KSampler.SAMPLERS: + sampler = comfy.samplers.KSampler(model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) + else: + #other samplers + pass + + samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise) + samples = samples.cpu() + model = model.cpu() + return (samples, ) + class KSampler: def __init__(self, device="cuda"): self.device = device @classmethod def INPUT_TYPES(s): - return {"required": + return {"required": {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), @@ -246,38 +283,43 @@ class KSampler: CATEGORY = "sampling" def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): - noise = torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=torch.manual_seed(seed), device="cpu") - model = model.to(self.device) - noise = noise.to(self.device) - latent_image = latent_image.to(self.device) - - positive_copy = [] - negative_copy = [] - - for p in positive: - t = p[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(self.device) - positive_copy += [[t] + p[1:]] - for n in negative: - t = n[0] - if t.shape[0] < noise.shape[0]: - t = torch.cat([t] * noise.shape[0]) - t = t.to(self.device) - negative_copy += [[t] + n[1:]] - - if sampler_name in comfy.samplers.KSampler.SAMPLERS: - sampler = comfy.samplers.KSampler(model, steps=steps, device=self.device, sampler=sampler_name, scheduler=scheduler, denoise=denoise) - else: - #other samplers - pass + return common_ksampler(self.device, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) - samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image) - samples = samples.cpu() - model = model.cpu() - return (samples, ) +class KSamplerAdvanced: + def __init__(self, device="cuda"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "add_noise": (["enable", "disable"], ), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), + "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), + "return_with_leftover_noise": (["disable", "enable"], ), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "sample" + + CATEGORY = "sampling" + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): + force_full_denoise = True + if return_with_leftover_noise == "enable": + force_full_denoise = False + disable_noise = False + if add_noise == "disable": + disable_noise = True + return common_ksampler(self.device, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) class SaveImage: def __init__(self): @@ -365,6 +407,7 @@ NODE_CLASS_MAPPINGS = { "LoadImage": LoadImage, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, + "KSamplerAdvanced": KSamplerAdvanced, "LatentRotate": LatentRotate, }