|
|
|
@ -650,4 +650,78 @@ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disab
|
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
|
|
|
|
|
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_dpmpp_3m(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
|
|
|
|
"""DPM-Solver++(3M) without SDE-specific parts."""
|
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable):
|
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
|
|
|
|
if callback is not None:
|
|
|
|
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
|
|
|
|
|
|
|
|
|
# Update x using the DPM-Solver++(3M) update rule
|
|
|
|
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
|
|
|
|
h = s - t
|
|
|
|
|
h_eta = h * (eta + 1)
|
|
|
|
|
|
|
|
|
|
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
|
|
|
|
|
|
|
|
|
if eta:
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
|
|
|
|
"""DPM-Solver++(3M) SDE."""
|
|
|
|
|
|
|
|
|
|
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
|
|
|
|
|
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
|
|
|
|
|
denoised_1, denoised_2 = None, None
|
|
|
|
|
h_1, h_2 = None, None
|
|
|
|
|
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable):
|
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
|
|
|
|
if callback is not None:
|
|
|
|
|
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
|
|
|
|
if sigmas[i + 1] == 0:
|
|
|
|
|
# Denoising step
|
|
|
|
|
x = denoised
|
|
|
|
|
else:
|
|
|
|
|
t, s = -sigmas[i].log(), -sigmas[i + 1].log()
|
|
|
|
|
h = s - t
|
|
|
|
|
h_eta = h * (eta + 1)
|
|
|
|
|
|
|
|
|
|
x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
|
|
|
|
|
|
|
|
|
|
if h_2 is not None:
|
|
|
|
|
r0 = h_1 / h
|
|
|
|
|
r1 = h_2 / h
|
|
|
|
|
d1_0 = (denoised - denoised_1) / r0
|
|
|
|
|
d1_1 = (denoised_1 - denoised_2) / r1
|
|
|
|
|
d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
|
|
|
|
|
d2 = (d1_0 - d1_1) / (r0 + r1)
|
|
|
|
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
|
|
|
|
phi_3 = phi_2 / h_eta - 0.5
|
|
|
|
|
x = x + phi_2 * d1 - phi_3 * d2
|
|
|
|
|
elif h_1 is not None:
|
|
|
|
|
r = h_1 / h
|
|
|
|
|
d = (denoised - denoised_1) / r
|
|
|
|
|
phi_2 = h_eta.neg().expm1() / h_eta + 1
|
|
|
|
|
x = x + phi_2 * d
|
|
|
|
|
|
|
|
|
|
if eta:
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
|
|
|
|
|
|
|
|
|
|
denoised_1, denoised_2 = denoised, denoised_1
|
|
|
|
|
h_1, h_2 = h, h_1
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|