|
|
|
@ -815,7 +815,7 @@ class LoadImage:
|
|
|
|
|
|
|
|
|
|
CATEGORY = "image"
|
|
|
|
|
|
|
|
|
|
RETURN_TYPES = ("IMAGE",)
|
|
|
|
|
RETURN_TYPES = ("IMAGE", "MASK")
|
|
|
|
|
FUNCTION = "load_image"
|
|
|
|
|
def load_image(self, image):
|
|
|
|
|
image_path = os.path.join(self.input_dir, image)
|
|
|
|
@ -823,7 +823,12 @@ class LoadImage:
|
|
|
|
|
image = i.convert("RGB")
|
|
|
|
|
image = np.array(image).astype(np.float32) / 255.0
|
|
|
|
|
image = torch.from_numpy(image)[None,]
|
|
|
|
|
return (image,)
|
|
|
|
|
if 'A' in i.getbands():
|
|
|
|
|
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
|
|
|
|
mask = 1. - torch.from_numpy(mask)
|
|
|
|
|
else:
|
|
|
|
|
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
|
|
|
|
return (image, mask)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def IS_CHANGED(s, image):
|
|
|
|
|