|
|
@ -400,38 +400,6 @@ def encode_adm(noise_augmentor, conds, batch_size, device):
|
|
|
|
|
|
|
|
|
|
|
|
return conds
|
|
|
|
return conds
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_sigmas(model, steps, scheduler, sampler):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Returns a tensor containing the sigmas corresponding to the given model, number of steps, scheduler type and sample technique
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if not (isinstance(model, CompVisVDenoiser) or isinstance(model, k_diffusion_external.CompVisDenoiser)):
|
|
|
|
|
|
|
|
model = CFGNoisePredictor(model)
|
|
|
|
|
|
|
|
if model.inner_model.parameterization == "v":
|
|
|
|
|
|
|
|
model = CompVisVDenoiser(model, quantize=True)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
model = k_diffusion_external.CompVisDenoiser(model, quantize=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sigmas = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
discard_penultimate_sigma = False
|
|
|
|
|
|
|
|
if sampler in ['dpm_2', 'dpm_2_ancestral']:
|
|
|
|
|
|
|
|
steps += 1
|
|
|
|
|
|
|
|
discard_penultimate_sigma = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if scheduler == "karras":
|
|
|
|
|
|
|
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.sigma_min), sigma_max=float(model.sigma_max))
|
|
|
|
|
|
|
|
elif scheduler == "normal":
|
|
|
|
|
|
|
|
sigmas = model.get_sigmas(steps)
|
|
|
|
|
|
|
|
elif scheduler == "simple":
|
|
|
|
|
|
|
|
sigmas = simple_scheduler(model, steps)
|
|
|
|
|
|
|
|
elif scheduler == "ddim_uniform":
|
|
|
|
|
|
|
|
sigmas = ddim_scheduler(model, steps)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
print("error invalid scheduler", scheduler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if discard_penultimate_sigma:
|
|
|
|
|
|
|
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
|
|
|
|
|
|
|
return sigmas
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KSampler:
|
|
|
|
class KSampler:
|
|
|
|
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
|
|
|
SCHEDULERS = ["karras", "normal", "simple", "ddim_uniform"]
|
|
|
@ -461,13 +429,36 @@ class KSampler:
|
|
|
|
self.denoise = denoise
|
|
|
|
self.denoise = denoise
|
|
|
|
self.model_options = model_options
|
|
|
|
self.model_options = model_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calculate_sigmas(self, steps):
|
|
|
|
|
|
|
|
sigmas = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
discard_penultimate_sigma = False
|
|
|
|
|
|
|
|
if self.sampler in ['dpm_2', 'dpm_2_ancestral']:
|
|
|
|
|
|
|
|
steps += 1
|
|
|
|
|
|
|
|
discard_penultimate_sigma = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.scheduler == "karras":
|
|
|
|
|
|
|
|
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
|
|
|
|
|
|
|
|
elif self.scheduler == "normal":
|
|
|
|
|
|
|
|
sigmas = self.model_wrap.get_sigmas(steps)
|
|
|
|
|
|
|
|
elif self.scheduler == "simple":
|
|
|
|
|
|
|
|
sigmas = simple_scheduler(self.model_wrap, steps)
|
|
|
|
|
|
|
|
elif self.scheduler == "ddim_uniform":
|
|
|
|
|
|
|
|
sigmas = ddim_scheduler(self.model_wrap, steps)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
print("error invalid scheduler", self.scheduler)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if discard_penultimate_sigma:
|
|
|
|
|
|
|
|
sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
|
|
|
|
|
|
|
|
return sigmas
|
|
|
|
|
|
|
|
|
|
|
|
def set_steps(self, steps, denoise=None):
|
|
|
|
def set_steps(self, steps, denoise=None):
|
|
|
|
self.steps = steps
|
|
|
|
self.steps = steps
|
|
|
|
if denoise is None or denoise > 0.9999:
|
|
|
|
if denoise is None or denoise > 0.9999:
|
|
|
|
self.sigmas = calculate_sigmas(self.model_wrap, steps, self.scheduler, self.sampler).to(self.device)
|
|
|
|
self.sigmas = self.calculate_sigmas(steps).to(self.device)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
new_steps = int(steps/denoise)
|
|
|
|
new_steps = int(steps/denoise)
|
|
|
|
sigmas = calculate_sigmas(self.model_wrap, new_steps, self.scheduler, self.sampler).to(self.device)
|
|
|
|
sigmas = self.calculate_sigmas(new_steps).to(self.device)
|
|
|
|
self.sigmas = sigmas[-(steps + 1):]
|
|
|
|
self.sigmas = sigmas[-(steps + 1):]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|