|
|
@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
|
|
|
|
attn = attn.reshape(b, -1, hw1, hw2)
|
|
|
|
attn = attn.reshape(b, -1, hw1, hw2)
|
|
|
|
# Global Average Pool
|
|
|
|
# Global Average Pool
|
|
|
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
|
|
|
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
|
|
|
|
ratio = math.ceil(math.sqrt(lh * lw / hw1))
|
|
|
|
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
|
|
|
|
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
|
|
|
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
|
|
|
|
|
|
|
|
|
|
|
|
# Reshape
|
|
|
|
# Reshape
|
|
|
|