diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index fef8a48..918e608 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -2,6 +2,7 @@ import folder_paths import comfy.sd import comfy.model_sampling import comfy.latent_formats +import nodes import torch class LCM(comfy.model_sampling.EPS): @@ -174,7 +175,10 @@ class ModelSamplingFlux: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), }} RETURN_TYPES = ("MODEL",) @@ -182,9 +186,15 @@ class ModelSamplingFlux: CATEGORY = "advanced/model" - def patch(self, model, shift): + def patch(self, model, max_shift, base_shift, width, height): m = model.clone() + x1 = 256 + x2 = 4096 + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + shift = (width * height / (8 * 8 * 2 * 2)) * mm + b + sampling_base = comfy.model_sampling.ModelSamplingFlux sampling_type = comfy.model_sampling.CONST