From f04dc2c2f4cba93ee180a337347df6f9f567aa70 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Feb 2023 21:06:43 -0500 Subject: [PATCH] Implement DDIM sampler. --- comfy/ldm/models/diffusion/ddim.py | 99 ++++++++++++++++++++++++++---- comfy/samplers.py | 43 ++++++++++++- 2 files changed, 127 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/models/diffusion/ddim.py b/comfy/ldm/models/diffusion/ddim.py index c6cfd57..fe39c76 100644 --- a/comfy/ldm/models/diffusion/ddim.py +++ b/comfy/ldm/models/diffusion/ddim.py @@ -22,11 +22,15 @@ class DDIMSampler(object): setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose) + + def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True): + self.ddim_timesteps = torch.tensor(ddim_timesteps) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) @@ -52,6 +56,58 @@ class DDIMSampler(object): 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + @torch.no_grad() + def sample_custom(self, + ddim_timesteps, + conditioning, + callback=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + denoise_function=None, + cond_concat=None, + to_zero=True, + end_step=None, + **kwargs + ): + self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose) + samples, intermediates = self.ddim_sampling(conditioning, x_T.shape, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule, + denoise_function=denoise_function, + cond_concat=cond_concat, + to_zero=to_zero, + end_step=end_step + ) + return samples, intermediates + + @torch.no_grad() def sample(self, S, @@ -116,7 +172,9 @@ class DDIMSampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule + ucg_schedule=ucg_schedule, + denoise_function=None, + cond_concat=None ) return samples, intermediates @@ -127,7 +185,7 @@ class DDIMSampler(object): mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, - ucg_schedule=None): + ucg_schedule=None, denoise_function=None, cond_concat=None, to_zero=True, end_step=None): device = self.model.betas.device b = shape[0] if x_T is None: @@ -142,11 +200,11 @@ class DDIMSampler(object): timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") + # print(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -167,7 +225,7 @@ class DDIMSampler(object): corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) + dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, cond_concat=cond_concat) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) @@ -176,16 +234,27 @@ class DDIMSampler(object): intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) + if to_zero: + img = pred_x0 + else: + if ddim_use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + img /= sqrt_alphas_cumprod[index - 1] + return img, intermediates @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None, - dynamic_threshold=None): + dynamic_threshold=None, denoise_function=None, cond_concat=None): b, *_, device = *x.shape, x.device - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + if denoise_function is not None: + model_output = denoise_function(self.model.apply_model, x, t, unconditional_conditioning, c, unconditional_guidance_scale, cond_concat) + elif unconditional_conditioning is None or unconditional_guidance_scale == 1.: model_output = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) @@ -299,7 +368,7 @@ class DDIMSampler(object): return x_next, out @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False): # fast, but does not allow for exact reconstruction # t serves as an index to gather the correct alphas if use_original_steps: @@ -311,8 +380,12 @@ class DDIMSampler(object): if noise is None: noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + if max_denoise: + noise_multiplier = 1.0 + else: + noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) + + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise) @torch.no_grad() def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, diff --git a/comfy/samplers.py b/comfy/samplers.py index 2dc5a53..437d164 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -4,6 +4,8 @@ from .extra_samplers import uni_pc import torch import contextlib import model_management +from .ldm.models.diffusion.ddim import DDIMSampler +from .ldm.modules.diffusionmodules.util import make_ddim_timesteps class CFGDenoiser(torch.nn.Module): def __init__(self, model): @@ -234,6 +236,14 @@ def simple_scheduler(model, steps): sigs += [0.0] return torch.FloatTensor(sigs) +def ddim_scheduler(model, steps): + sigs = [] + ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False) + for x in range(len(ddim_timesteps) - 1, -1, -1): + sigs.append(model.t_to_sigma(torch.tensor(ddim_timesteps[x]))) + sigs += [0.0] + return torch.FloatTensor(sigs) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -310,10 +320,10 @@ def apply_control_net_to_equal_area(conds, uncond): uncond[temp[1]] = [o[0], n] class KSampler: - SCHEDULERS = ["karras", "normal", "simple"] + SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"] SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral", "sample_lms", "sample_dpm_fast", "sample_dpm_adaptive", "sample_dpmpp_2s_ancestral", "sample_dpmpp_sde", - "sample_dpmpp_2m", "uni_pc", "uni_pc_bh2"] + "sample_dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None): self.model = model @@ -350,6 +360,8 @@ class KSampler: sigmas = self.model_wrap.get_sigmas(steps).to(self.device) elif self.scheduler == "simple": sigmas = simple_scheduler(self.model_wrap, steps).to(self.device) + elif self.scheduler == "ddim_uniform": + sigmas = ddim_scheduler(self.model_wrap, steps).to(self.device) else: print("error invalid scheduler", self.scheduler) @@ -403,6 +415,7 @@ class KSampler: extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg} + cond_concat = None if hasattr(self.model, 'concat_keys'): cond_concat = [] for ck in self.model.concat_keys: @@ -428,6 +441,32 @@ class KSampler: samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask) elif self.sampler == "uni_pc_bh2": samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, variant='bh2') + elif self.sampler == "ddim": + timesteps = [] + for s in range(sigmas.shape[0]): + timesteps.insert(0, self.model_wrap.sigma_to_t(sigmas[s])) + noise_mask = None + if denoise_mask is not None: + noise_mask = 1.0 - denoise_mask + sampler = DDIMSampler(self.model) + sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False) + z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise) + samples, _ = sampler.sample_custom(ddim_timesteps=timesteps, + conditioning=positive, + batch_size=noise.shape[0], + shape=noise.shape[1:], + verbose=False, + unconditional_guidance_scale=cfg, + unconditional_conditioning=negative, + eta=0.0, + x_T=z_enc, + x0=latent_image, + denoise_function=sampling_function, + cond_concat=cond_concat, + mask=noise_mask, + to_zero=sigmas[-1]==0, + end_step=sigmas.shape[0] - 1) + else: extra_args["denoise_mask"] = denoise_mask self.model_k.latent_image = latent_image