|
|
|
@ -313,13 +313,18 @@ def simple_scheduler(model_sampling, steps):
|
|
|
|
|
def ddim_scheduler(model_sampling, steps):
|
|
|
|
|
s = model_sampling
|
|
|
|
|
sigs = []
|
|
|
|
|
ss = max(len(s.sigmas) // steps, 1)
|
|
|
|
|
x = 1
|
|
|
|
|
if math.isclose(float(s.sigmas[x]), 0, abs_tol=0.00001):
|
|
|
|
|
steps += 1
|
|
|
|
|
sigs = []
|
|
|
|
|
else:
|
|
|
|
|
sigs = [0.0]
|
|
|
|
|
|
|
|
|
|
ss = max(len(s.sigmas) // steps, 1)
|
|
|
|
|
while x < len(s.sigmas):
|
|
|
|
|
sigs += [float(s.sigmas[x])]
|
|
|
|
|
x += ss
|
|
|
|
|
sigs = sigs[::-1]
|
|
|
|
|
sigs += [0.0]
|
|
|
|
|
return torch.FloatTensor(sigs)
|
|
|
|
|
|
|
|
|
|
def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
|
|
|
@ -327,16 +332,23 @@ def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
|
|
|
|
|
start = s.timestep(s.sigma_max)
|
|
|
|
|
end = s.timestep(s.sigma_min)
|
|
|
|
|
|
|
|
|
|
append_zero = True
|
|
|
|
|
if sgm:
|
|
|
|
|
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
|
|
|
|
else:
|
|
|
|
|
if math.isclose(float(s.sigma(end)), 0, abs_tol=0.00001):
|
|
|
|
|
steps += 1
|
|
|
|
|
append_zero = False
|
|
|
|
|
timesteps = torch.linspace(start, end, steps)
|
|
|
|
|
|
|
|
|
|
sigs = []
|
|
|
|
|
for x in range(len(timesteps)):
|
|
|
|
|
ts = timesteps[x]
|
|
|
|
|
sigs.append(s.sigma(ts))
|
|
|
|
|
sigs += [0.0]
|
|
|
|
|
sigs.append(float(s.sigma(ts)))
|
|
|
|
|
|
|
|
|
|
if append_zero:
|
|
|
|
|
sigs += [0.0]
|
|
|
|
|
|
|
|
|
|
return torch.FloatTensor(sigs)
|
|
|
|
|
|
|
|
|
|
# Implemented based on: https://arxiv.org/abs/2407.12173
|
|
|
|
|