|
|
@ -374,7 +374,7 @@ def resolve_cond_masks(conditions, h, w, device):
|
|
|
|
modified = c[1].copy()
|
|
|
|
modified = c[1].copy()
|
|
|
|
if len(mask.shape) == 2:
|
|
|
|
if len(mask.shape) == 2:
|
|
|
|
mask = mask.unsqueeze(0)
|
|
|
|
mask = mask.unsqueeze(0)
|
|
|
|
if mask.shape[2] != h or mask.shape[3] != w:
|
|
|
|
if mask.shape[1] != h or mask.shape[2] != w:
|
|
|
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
|
|
|
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(h, w), mode='bilinear', align_corners=False).squeeze(1)
|
|
|
|
|
|
|
|
|
|
|
|
if modified.get("set_area_to_bounds", False):
|
|
|
|
if modified.get("set_area_to_bounds", False):
|
|
|
|