diff --git a/comfy/samplers.py b/comfy/samplers.py index d3cd901..dffd7fe 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -273,7 +273,8 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con max_total_area = model_management.maximum_batch_area() cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options) if "sampler_cfg_function" in model_options: - return model_options["sampler_cfg_function"](cond, uncond, cond_scale) + args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} + return model_options["sampler_cfg_function"](args) else: return uncond + (cond - uncond) * cond_scale diff --git a/comfy/sd.py b/comfy/sd.py index 4b3cb83..5237df3 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,6 +1,7 @@ import torch import contextlib import copy +import inspect from . import sd1_clip from . import sd2_clip @@ -313,8 +314,10 @@ class ModelPatcher: self.model_options["transformer_options"]["tomesd"] = {"ratio": ratio} def set_model_sampler_cfg_function(self, sampler_cfg_function): - self.model_options["sampler_cfg_function"] = sampler_cfg_function - + if len(inspect.signature(sampler_cfg_function).parameters) == 3: + self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way + else: + self.model_options["sampler_cfg_function"] = sampler_cfg_function def set_model_patch(self, patch, name): to = self.model_options["transformer_options"]