|
|
@ -358,9 +358,6 @@ class UniPC:
|
|
|
|
thresholding=False,
|
|
|
|
thresholding=False,
|
|
|
|
max_val=1.,
|
|
|
|
max_val=1.,
|
|
|
|
variant='bh1',
|
|
|
|
variant='bh1',
|
|
|
|
noise_mask=None,
|
|
|
|
|
|
|
|
masked_image=None,
|
|
|
|
|
|
|
|
noise=None,
|
|
|
|
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""Construct a UniPC.
|
|
|
|
"""Construct a UniPC.
|
|
|
|
|
|
|
|
|
|
|
@ -372,9 +369,6 @@ class UniPC:
|
|
|
|
self.predict_x0 = predict_x0
|
|
|
|
self.predict_x0 = predict_x0
|
|
|
|
self.thresholding = thresholding
|
|
|
|
self.thresholding = thresholding
|
|
|
|
self.max_val = max_val
|
|
|
|
self.max_val = max_val
|
|
|
|
self.noise_mask = noise_mask
|
|
|
|
|
|
|
|
self.masked_image = masked_image
|
|
|
|
|
|
|
|
self.noise = noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dynamic_thresholding_fn(self, x0, t=None):
|
|
|
|
def dynamic_thresholding_fn(self, x0, t=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -391,10 +385,7 @@ class UniPC:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Return the noise prediction model.
|
|
|
|
Return the noise prediction model.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if self.noise_mask is not None:
|
|
|
|
return self.model(x, t)
|
|
|
|
return self.model(x, t) * self.noise_mask
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
return self.model(x, t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def data_prediction_fn(self, x, t):
|
|
|
|
def data_prediction_fn(self, x, t):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -409,8 +400,6 @@ class UniPC:
|
|
|
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
|
|
|
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
|
|
|
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
|
|
|
s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
|
|
|
|
x0 = torch.clamp(x0, -s, s) / s
|
|
|
|
x0 = torch.clamp(x0, -s, s) / s
|
|
|
|
if self.noise_mask is not None:
|
|
|
|
|
|
|
|
x0 = x0 * self.noise_mask + (1. - self.noise_mask) * self.masked_image
|
|
|
|
|
|
|
|
return x0
|
|
|
|
return x0
|
|
|
|
|
|
|
|
|
|
|
|
def model_fn(self, x, t):
|
|
|
|
def model_fn(self, x, t):
|
|
|
@ -723,8 +712,6 @@ class UniPC:
|
|
|
|
assert timesteps.shape[0] - 1 == steps
|
|
|
|
assert timesteps.shape[0] - 1 == steps
|
|
|
|
# with torch.no_grad():
|
|
|
|
# with torch.no_grad():
|
|
|
|
for step_index in trange(steps, disable=disable_pbar):
|
|
|
|
for step_index in trange(steps, disable=disable_pbar):
|
|
|
|
if self.noise_mask is not None:
|
|
|
|
|
|
|
|
x = x * self.noise_mask + (1. - self.noise_mask) * (self.masked_image * self.noise_schedule.marginal_alpha(timesteps[step_index]) + self.noise * self.noise_schedule.marginal_std(timesteps[step_index]))
|
|
|
|
|
|
|
|
if step_index == 0:
|
|
|
|
if step_index == 0:
|
|
|
|
vec_t = timesteps[0].expand((x.shape[0]))
|
|
|
|
vec_t = timesteps[0].expand((x.shape[0]))
|
|
|
|
model_prev_list = [self.model_fn(x, vec_t)]
|
|
|
|
model_prev_list = [self.model_fn(x, vec_t)]
|
|
|
@ -766,7 +753,7 @@ class UniPC:
|
|
|
|
model_x = self.model_fn(x, vec_t)
|
|
|
|
model_x = self.model_fn(x, vec_t)
|
|
|
|
model_prev_list[-1] = model_x
|
|
|
|
model_prev_list[-1] = model_x
|
|
|
|
if callback is not None:
|
|
|
|
if callback is not None:
|
|
|
|
callback(step_index, model_prev_list[-1], x, steps)
|
|
|
|
callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise NotImplementedError()
|
|
|
|
raise NotImplementedError()
|
|
|
|
# if denoise_to_zero:
|
|
|
|
# if denoise_to_zero:
|
|
|
@ -858,7 +845,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
|
|
|
return (input - model(input, sigma_in, **kwargs)) / sigma
|
|
|
|
return (input - model(input, sigma_in, **kwargs)) / sigma
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
|
|
|
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
|
|
|
|
timesteps = sigmas.clone()
|
|
|
|
timesteps = sigmas.clone()
|
|
|
|
if sigmas[-1] == 0:
|
|
|
|
if sigmas[-1] == 0:
|
|
|
|
timesteps = sigmas[:]
|
|
|
|
timesteps = sigmas[:]
|
|
|
@ -867,16 +854,7 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
|
|
|
timesteps = sigmas.clone()
|
|
|
|
timesteps = sigmas.clone()
|
|
|
|
ns = SigmaConvert()
|
|
|
|
ns = SigmaConvert()
|
|
|
|
|
|
|
|
|
|
|
|
if image is not None:
|
|
|
|
noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
|
|
|
|
img = image * ns.marginal_alpha(timesteps[0])
|
|
|
|
|
|
|
|
if max_denoise:
|
|
|
|
|
|
|
|
noise_mult = 1.0
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
noise_mult = ns.marginal_std(timesteps[0])
|
|
|
|
|
|
|
|
img += noise * noise_mult
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
img = noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_type = "noise"
|
|
|
|
model_type = "noise"
|
|
|
|
|
|
|
|
|
|
|
|
model_fn = model_wrapper(
|
|
|
|
model_fn = model_wrapper(
|
|
|
@ -888,7 +866,10 @@ def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, call
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
order = min(3, len(timesteps) - 2)
|
|
|
|
order = min(3, len(timesteps) - 2)
|
|
|
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant)
|
|
|
|
uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
|
|
|
|
x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
|
|
|
x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
|
|
|
|
x /= ns.marginal_alpha(timesteps[-1])
|
|
|
|
x /= ns.marginal_alpha(timesteps[-1])
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
|
|
|
|
|
|
|
|
return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
|