From 1beb348ee2dafea19615cdcfdcfc096501c2f519 Mon Sep 17 00:00:00 2001 From: Ashen <> Date: Sat, 17 Aug 2024 17:32:27 -0700 Subject: [PATCH] dpmpp_2s_ancestral_RF for rectified flow (Flux, SD3 and Auraflow). --- comfy/k_diffusion/sampling.py | 49 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 763d8cc..ca822ff 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -541,6 +541,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, return x +@torch.no_grad() +def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 + lambda_fn = lambda sigma: ((1-sigma)/sigma).log() + + # logged_x = x.unsqueeze(0) + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta + sigma_down = sigmas[i+1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i+1] + alpha_down = 1 - sigma_down + renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++(2S) + if sigmas[i] == 1.0: + sigma_s = 0.9999 + else: + t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down) + r = 1 / 2 + h = t_down - t_i + s = t_i + r * h + sigma_s = sigma_fn(s) + # sigma_s = sigmas[i+1] + sigma_s_i_ratio = sigma_s / sigmas[i] + u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised + D_i = model(u, sigma_s * s_in, **extra_args) + sigma_down_i_ratio = sigma_down / sigmas[i] + x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i + # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) + # Noise addition + if sigmas[i + 1] > 0 and eta > 0: + x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) + return x + @torch.no_grad() def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): """DPM-Solver++ (stochastic).""" diff --git a/comfy/samplers.py b/comfy/samplers.py index ce4371d..b8b3aa8 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -570,7 +570,7 @@ class Sampler: return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", - "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", + "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_RF", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis"]