From e33dc2b33beaa40b1a8bd0a40ea8a9a143bbe568 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 11 Mar 2023 15:28:15 -0500 Subject: [PATCH] Add a VAEEncodeTiled node. --- comfy/sd.py | 8 ++++++++ nodes.py | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 0a82038..fd434ba 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -400,6 +400,14 @@ class VAE: samples = samples.cpu() return samples + def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + model_management.unload_model() + self.first_stage_model = self.first_stage_model.to(self.device) + pixel_samples = pixel_samples.movedim(-1,1).to(self.device) + samples = utils.tiled_scale(pixel_samples, lambda a: self.first_stage_model.encode(2. * a - 1.).sample() * self.scale_factor, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4) + self.first_stage_model = self.first_stage_model.cpu() + samples = samples.cpu() + return samples def resize_image_to(tensor, target_latent_tensor, batched_number): tensor = utils.common_upscale(tensor, target_latent_tensor.shape[3] * 8, target_latent_tensor.shape[2] * 8, 'nearest-exact', "center") diff --git a/nodes.py b/nodes.py index 3d2f83d..7a9e598 100644 --- a/nodes.py +++ b/nodes.py @@ -151,6 +151,27 @@ class VAEEncode: return ({"samples":t}, ) + +class VAEEncodeTiled: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "_for_testing" + + def encode(self, vae, pixels): + x = (pixels.shape[1] // 64) * 64 + y = (pixels.shape[2] // 64) * 64 + if pixels.shape[1] != x or pixels.shape[2] != y: + pixels = pixels[:,:x,:y,:] + t = vae.encode_tiled(pixels[:,:,:,:3]) + + return ({"samples":t}, ) class VAEEncodeForInpaint: def __init__(self, device="cpu"): self.device = device @@ -946,6 +967,7 @@ NODE_CLASS_MAPPINGS = { "StyleModelLoader": StyleModelLoader, "CLIPVisionLoader": CLIPVisionLoader, "VAEDecodeTiled": VAEDecodeTiled, + "VAEEncodeTiled": VAEEncodeTiled, } def load_custom_node(module_path):