@@ -713,8 +713,8 @@ class UniPC:
         method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
         atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
     ):
-        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
-        t_T = self.noise_schedule.T if t_start is None else t_start
+        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        # t_T = self.noise_schedule.T if t_start is None else t_start
         device = x.device
         steps = len(timesteps) - 1
         if method == 'multistep':
@@ -769,8 +769,8 @@ class UniPC:
                 callback(step_index, model_prev_list[-1], x, steps)
         else:
             raise NotImplementedError()
-        if denoise_to_zero:
-            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        # if denoise_to_zero:
+        #     x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
         return x
@@ -833,21 +833,33 @@ def expand_dims(v, dims):
     return v[(...,) + (None,)*(dims - 1)]

+class SigmaConvert:
+    schedule = ""
+    def marginal_log_mean_coeff(self, sigma):
+        return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
+
+    def marginal_alpha(self, t):
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     to_zero = False
-    timesteps = sigmas.clone()
     if sigmas[-1] == 0:
-        timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0]
         to_zero = True
+        timesteps = sigmas[:]
+        timesteps[-1] = 0.001
+    else:
+        timesteps = sigmas.clone()

     alphas_cumprod = model.inner_model.alphas_cumprod

-    for s in range(timesteps.shape[0]):
-        timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod))
-
-    ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+    ns = SigmaConvert()

     if image is not None:
         img = image * ns.marginal_alpha(timesteps[0])
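The SigmaConvert schedule added in this hunk treats the model's sigma directly as the continuous-time label: from marginal_log_mean_coeff it follows that alpha(sigma) = 1/sqrt(sigma^2 + 1), std(sigma) = sigma/sqrt(sigma^2 + 1), and marginal_lambda(sigma) = -log(sigma). A minimal sanity check of those identities (a sketch, assuming torch and the SigmaConvert class above are in scope):

```python
import torch

ns = SigmaConvert()
sigma = torch.tensor([0.5, 1.0, 7.0])

# The VP-style coefficients derived from sigma satisfy alpha^2 + std^2 == 1.
assert torch.allclose(ns.marginal_alpha(sigma) ** 2 + ns.marginal_std(sigma) ** 2,
                      torch.ones_like(sigma))

# lambda = log(alpha) - log(std) collapses to -log(sigma).
assert torch.allclose(ns.marginal_lambda(sigma), -torch.log(sigma), atol=1e-5)
```

This is why the function can scale a starting image by ns.marginal_alpha(timesteps[0]) even though timesteps now holds plain sigmas.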
@@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     else:
         img = noise

-    if to_zero:
-        timesteps[-1] = (1 / len(alphas_cumprod))
-
     device = noise.device

     model_type = "noise"

     model_fn = model_wrapper(
-        model.predict_eps_discrete_timestep,
+        model.predict_eps_sigma,
         ns,
         model_type=model_type,
         guidance_type="uncond",
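Here the wrapped callable changes from a discrete-timestep noise predictor to a sigma-conditioned one: model_wrapper now receives a function mapping (x, sigma) to predicted noise. A hypothetical sketch of what such a function has to do, using the SigmaConvert convention above (illustrative names, not ComfyUI's actual method): the solver's VP-frame latent is rescaled back into the k-diffusion frame before querying a denoiser D(x, sigma), and the noise term is then solved for.

```python
def predict_eps_sigma_sketch(x_vp, sigma, denoiser):
    # Hypothetical stand-in. In the SigmaConvert convention above,
    # x_vp = marginal_alpha(sigma) * x0 + marginal_std(sigma) * eps
    #      = (x0 + sigma * eps) / sqrt(sigma^2 + 1).
    x = x_vp * (sigma * sigma + 1.0) ** 0.5   # back to the k-diffusion frame
    return (x - denoiser(x, sigma)) / sigma   # solve x = x0 + sigma * eps for eps
```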
@@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
     order = min(3, len(timesteps) - 1)
     uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
     x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
-    if not to_zero:
-        x /= ns.marginal_alpha(timesteps[-1])
+    x /= ns.marginal_alpha(timesteps[-1])
     return x
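Taken together, a hypothetical end-to-end invocation of the rewritten sample_unipc might look like the following. The model object, its inner_model.alphas_cumprod attribute, and the sigma values are all assumptions for illustration, not part of the diff.

```python
import torch

# `model` is assumed to be a ComfyUI-style wrapper exposing
# inner_model.alphas_cumprod and predict_eps_sigma, as used in the hunks above.
sigmas = torch.tensor([14.61, 6.47, 3.86, 2.69, 1.88, 1.39,
                       0.96, 0.65, 0.39, 0.15, 0.0])  # descending schedule
noise = torch.randn(1, 4, 64, 64)

samples = sample_unipc(model, noise, image=None, sigmas=sigmas,
                       sampling_function=None, max_denoise=True,
                       extra_args={}, callback=None, disable=False,
                       noise_mask=None, variant='bh1')
```

A trailing 0.0 in sigmas takes the to_zero branch above, clamping the final timestep to 0.001 before sampling.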