diff --git a/nodes.py b/nodes.py index 002d022..7c88402 100644 --- a/nodes.py +++ b/nodes.py @@ -815,7 +815,7 @@ class LoadImage: CATEGORY = "image" - RETURN_TYPES = ("IMAGE",) + RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): image_path = os.path.join(self.input_dir, image) @@ -823,7 +823,12 @@ class LoadImage: image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] - return (image,) + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + return (image, mask) @classmethod def IS_CHANGED(s, image):