ModelSamplingFlux now takes a resolution and adjusts the shift with it.

If you want to sample Flux dev exactly how the reference code does use
the same resolution as your image in this node.
main
comfyanonymous 7 months ago
parent f7a5107784
commit 56f3c660bf

@ -2,6 +2,7 @@ import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch
class LCM(comfy.model_sampling.EPS):
@ -174,7 +175,10 @@ class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("MODEL",)
@ -182,9 +186,15 @@ class ModelSamplingFlux:
CATEGORY = "advanced/model"
def patch(self, model, shift):
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()
x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST

Loading…
Cancel
Save