diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py new file mode 100644 index 0000000..b1823d7 --- /dev/null +++ b/comfy_extras/nodes_latent.py @@ -0,0 +1,74 @@ +import comfy.utils + +def reshape_latent_to(target_shape, latent): + if latent.shape[1:] != target_shape[1:]: + latent.movedim(1, -1) + latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") + latent.movedim(-1, 1) + return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + + +class LatentAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 + s2 + return (samples_out,) + +class LatentSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 - s2 + return (samples_out,) + +class LatentMuliply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, multiplier): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1 * multiplier + return (samples_out,) + +NODE_CLASS_MAPPINGS = { + "LatentAdd": LatentAdd, + "LatentSubtract": LatentSubtract, + "LatentMuliply": LatentMuliply, +} diff --git a/nodes.py b/nodes.py index 18d82ea..6e0d437 100644 --- a/nodes.py +++ b/nodes.py @@ -1772,6 +1772,7 @@ def load_custom_nodes(): print() def init_custom_nodes(): + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_latent.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py"))