diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index ac7e271..5a6aa7d 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -98,7 +98,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) + betas = torch.clamp(betas, min=0, max=0.999) elif schedule == "squaredcos_cap_v2": # used for karlo prior # return early @@ -113,7 +113,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() + return betas def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index cc8745c..d587002 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -1,5 +1,4 @@ import torch -import numpy as np from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule import math @@ -42,8 +41,7 @@ class ModelSamplingDiscrete(torch.nn.Module): else: betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) - # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod = torch.cumprod(alphas, dim=0) timesteps, = betas.shape self.num_timesteps = int(timesteps) @@ -58,8 +56,8 @@ class ModelSamplingDiscrete(torch.nn.Module): self.set_sigmas(sigmas) def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) + self.register_buffer('sigmas', sigmas.float()) + self.register_buffer('log_sigmas', sigmas.log().float()) @property def sigma_min(self):