diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index f42f301..ae42d81 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -142,12 +142,16 @@ class StableCascadeSampling(ModelSamplingDiscrete): else: sampling_settings = {} - self.num_timesteps = 1000 - self.shift = sampling_settings.get("shift", 1.0) - cosine_s=8e-3 + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift self.cosine_s = torch.tensor(cosine_s) - sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 1000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) for x in range(self.num_timesteps): t = x / self.num_timesteps sigmas[x] = self.sigma(t) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 541ce8f..ac7c1c1 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -99,6 +99,32 @@ class ModelSamplingDiscrete: m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -171,5 +197,6 @@ class RescaleCFG: NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, }