Expose grow_mask_by in VAEEncodeForInpaint.

The mask is dilated by grow_mask_by pixels after being applied to the pixel
space image. This helps reduce seams caused by inpainting. Higher value
means less seams.
main
comfyanonymous 2 years ago
parent 9c335a553f
commit 35f636b6c7

@ -5,6 +5,7 @@ import sys
import json import json
import hashlib import hashlib
import traceback import traceback
import math
from PIL import Image from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
@ -223,13 +224,13 @@ class VAEEncodeForInpaint:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", )}} return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // 64) * 64 x = (pixels.shape[1] // 64) * 64
y = (pixels.shape[2] // 64) * 64 y = (pixels.shape[2] // 64) * 64
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
@ -240,8 +241,14 @@ class VAEEncodeForInpaint:
mask = mask[:,:,:x,:y] mask = mask[:,:,:x,:y]
#grow mask by a few pixels to keep things seamless in latent space #grow mask by a few pixels to keep things seamless in latent space
kernel_tensor = torch.ones((1, 1, 6, 6)) if grow_mask_by == 0:
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=3), 0, 1) mask_erosion = mask
else:
kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
padding = math.ceil((grow_mask_by - 1) / 2)
mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
m = (1.0 - mask.round()).squeeze(1) m = (1.0 - mask.round()).squeeze(1)
for i in range(3): for i in range(3):
pixels[:,:,:,i] -= 0.5 pixels[:,:,:,i] -= 0.5

Loading…
Cancel
Save