From 4225d1cb9fcbb22556ce8a69327d8c531b755945 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 4 Feb 2023 15:53:29 -0500 Subject: [PATCH] Add a basic ImageScale node. It's pretty much the same as the LatentUpscale node for now but for images in pixel space. --- nodes.py | 53 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/nodes.py b/nodes.py index 9101c6e..a58c0ce 100644 --- a/nodes.py +++ b/nodes.py @@ -186,6 +186,23 @@ class EmptyLatentImage: latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (latent, ) +def common_upscale(samples, width, height, upscale_method, crop): + if crop == "center": + old_width = samples.shape[3] + old_height = samples.shape[2] + old_aspect = old_width / old_height + new_aspect = width / height + x = 0 + y = 0 + if old_aspect > new_aspect: + x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) + elif old_aspect < new_aspect: + y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) + s = samples[:,:,y:old_height-y,x:old_width-x] + else: + s = samples + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + class LatentUpscale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -202,21 +219,7 @@ class LatentUpscale: CATEGORY = "latent" def upscale(self, samples, upscale_method, width, height, crop): - if crop == "center": - old_width = samples.shape[3] - old_height = samples.shape[2] - old_aspect = old_width / old_height - new_aspect = width / height - x = 0 - y = 0 - if old_aspect > new_aspect: - x = round((old_width - old_width * (new_aspect / old_aspect)) / 2) - elif old_aspect < new_aspect: - y = round((old_height - old_height * (old_aspect / new_aspect)) / 2) - s = samples[:,:,y:old_height-y,x:old_width-x] - else: - s = samples - s = torch.nn.functional.interpolate(s, size=(height // 8, width // 8), mode=upscale_method) + s = common_upscale(samples, width // 8, height // 8, upscale_method, crop) return (s,) class LatentRotate: @@ -505,7 +508,26 @@ class LoadImage: m.update(f.read()) return m.digest().hex() +class ImageScale: + upscale_methods = ["nearest-exact", "bilinear", "area"] + crop_methods = ["disabled", "center"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "width": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}), + "crop": (s.crop_methods,)}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image" + def upscale(self, image, upscale_method, width, height, crop): + samples = image.movedim(-1,1) + s = common_upscale(samples, width, height, upscale_method, crop) + s = s.movedim(1,-1) + return (s,) NODE_CLASS_MAPPINGS = { "KSampler": KSampler, @@ -518,6 +540,7 @@ NODE_CLASS_MAPPINGS = { "LatentUpscale": LatentUpscale, "SaveImage": SaveImage, "LoadImage": LoadImage, + "ImageScale": ImageScale, "ConditioningCombine": ConditioningCombine, "ConditioningSetArea": ConditioningSetArea, "KSamplerAdvanced": KSamplerAdvanced,