diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 7e88bb9..58e030d 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -713,8 +713,8 @@ class UniPC: method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False ): - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start + # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + # t_T = self.noise_schedule.T if t_start is None else t_start device = x.device steps = len(timesteps) - 1 if method == 'multistep': @@ -769,8 +769,8 @@ class UniPC: callback(step_index, model_prev_list[-1], x, steps) else: raise NotImplementedError() - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + # if denoise_to_zero: + # x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) return x @@ -833,21 +833,33 @@ def expand_dims(v, dims): return v[(...,) + (None,)*(dims - 1)] +class SigmaConvert: + schedule = "" + def marginal_log_mean_coeff(self, sigma): + return 0.5 * torch.log(1 / ((sigma * sigma) + 1)) + + def marginal_alpha(self, t): + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): - to_zero = False + timesteps = sigmas.clone() if sigmas[-1] == 0: - timesteps = torch.nn.functional.interpolate(sigmas[None,None,:-1], size=(len(sigmas),), mode='linear')[0][0] - to_zero = True + timesteps = sigmas[:] + timesteps[-1] = 0.001 else: timesteps = sigmas.clone() - - alphas_cumprod = model.inner_model.alphas_cumprod - - for s in range(timesteps.shape[0]): - timesteps[s] = (model.sigma_to_discrete_timestep(timesteps[s]) / 1000) + (1 / len(alphas_cumprod)) - - ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + ns = SigmaConvert() if image is not None: img = image * ns.marginal_alpha(timesteps[0]) @@ -859,16 +871,10 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex else: img = noise - if to_zero: - timesteps[-1] = (1 / len(alphas_cumprod)) - - device = noise.device - - model_type = "noise" model_fn = model_wrapper( - model.predict_eps_discrete_timestep, + model.predict_eps_sigma, ns, model_type=model_type, guidance_type="uncond", @@ -878,6 +884,5 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex order = min(3, len(timesteps) - 1) uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, noise_mask=noise_mask, masked_image=image, noise=noise, variant=variant) x = uni_pc.sample(img, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable) - if not to_zero: - x /= ns.marginal_alpha(timesteps[-1]) + x /= ns.marginal_alpha(timesteps[-1]) return x diff --git a/comfy/k_diffusion/external.py b/comfy/k_diffusion/external.py index c1a137d..953d3db 100644 --- a/comfy/k_diffusion/external.py +++ b/comfy/k_diffusion/external.py @@ -97,6 +97,10 @@ class DiscreteSchedule(nn.Module): input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) + def predict_eps_sigma(self, input, sigma, **kwargs): + input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5) + return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim) + class DiscreteEpsDDPMDenoiser(DiscreteSchedule): """A wrapper for discrete schedule DDPM models that output eps (the predicted noise).""" diff --git a/comfy/samplers.py b/comfy/samplers.py index 4840b6d..0b38fbd 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -739,7 +739,7 @@ class KSampler: sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral']: + if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: steps += 1 discard_penultimate_sigma = True