@@ -3,7 +3,6 @@ import math
 from scipy import integrate
 import torch
 from torch import nn
-from torchdiffeq import odeint
 import torchsde
 from tqdm.auto import trange, tqdm
 
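For reference, and not part of the diff itself: the `to_d` helper that the deleted `log_likelihood` (second hunk, below) calls is defined earlier in `sampling.py`. It converts a denoiser prediction into the derivative of the probability-flow ODE, dx/dsigma = (x - denoised(x, sigma)) / sigma. A minimal self-contained sketch, where the broadcasting loop stands in for the repo's dimension-padding utility:

```python
import torch

def to_d(x, sigma, denoised):
    """Convert a denoiser output into the Karras ODE derivative
    dx/dsigma = (x - denoised) / sigma."""
    while sigma.ndim < x.ndim:
        sigma = sigma[..., None]  # right-pad sigma's shape so it broadcasts over x
    return (x - denoised) / sigma
```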
@@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
     return x
 
 
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
-
-
 class PIDStepSizeController:
     """A PID controller for ODE adaptive step size control."""
     def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
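For context on what the deleted function computed: `log_likelihood` integrated the probability-flow ODE from `sigma_min` up to `sigma_max` with an adaptive solver (`dopri5`), accumulating the divergence of the ODE derivative along the trajectory (the instantaneous change-of-variables term), then added the log-density of the terminal latent under the Gaussian prior N(0, sigma_max^2). The divergence is never formed explicitly: the Skilling-Hutchinson estimator draws a Rademacher vector `v`, for which E[v^T (dd/dx) v] equals the trace of the Jacobian, so each solver step costs only a single vector-Jacobian product. A standalone sketch of that estimator (illustrative, not code from this diff):

```python
import torch

def hutchinson_divergence(f, x):
    """One-sample Skilling-Hutchinson estimate of div f(x) = tr(df/dx),
    computed per batch element with a single vector-Jacobian product."""
    v = torch.randint_like(x, 2) * 2 - 1  # Rademacher noise: entries in {-1, +1}
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        y = f(x)
        vjp = torch.autograd.grad((y * v).sum(), x)[0]  # v^T (df/dx), shaped like x
    return (v * vjp).flatten(1).sum(1)  # v^T (df/dx) v: one scalar per sample
```

In the deleted code this estimate is the `d_ll` component that `odeint` integrates alongside `x` itself, which is why `ode_fn` takes and returns a state/log-likelihood pair.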