|
|
@ -92,8 +92,8 @@ class DiscreteSchedule(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
|
|
|
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
|
|
|
sigma = self.t_to_sigma(t.round())
|
|
|
|
sigma = self.t_to_sigma(t.round())
|
|
|
|
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
|
|
|
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
|
|
|
return (input - self(input, sigma, **kwargs)) / sigma
|
|
|
|
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
|
|
|
|
|
|
|
|
|
|
|
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
|
|
|
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
|
|
|
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
|
|
|
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
|
|
|