diff --git a/nodes.py b/nodes.py index 40e7558..a3c0335 100644 --- a/nodes.py +++ b/nodes.py @@ -1,3 +1,5 @@ +import math + import torch import os @@ -942,28 +944,31 @@ class ImagePadForOutpaint: dtype=torch.float32, ) + t = torch.zeros( + (d2, d3), + dtype=torch.float32 + ) + if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3: - # distances to border - mi, mj = torch.meshgrid( - torch.arange(d2, dtype=torch.float32), - torch.arange(d3, dtype=torch.float32), - indexing='ij', - ) - distances = torch.minimum( - torch.minimum(mi, mj), - torch.minimum(d2 - 1 - mi, d3 - 1 - mj), - ) - # convert distances to square falloff from 1 to 0 - t = (feathering - distances) / feathering - t.clamp_(min=0) - t.square_() - - mask[top:top + d2, left:left + d3] = t - else: - mask[top:top + d2, left:left + d3] = torch.zeros( - (d2, d3), - dtype=torch.float32, - ) + + for i in range(d2): + for j in range(d3): + dt = i if top != 0 else d2 + db = d2 - i if bottom != 0 else d2 + + dl = j if left != 0 else d3 + dr = d3 - j if right != 0 else d3 + + d = min(dt, db, dl, dr) + + if d >= feathering: + continue + + v = (feathering - d) / feathering + + t[i, j] = v * v + + mask[top:top + d2, left:left + d3] = t return (new_image, mask)