|
|
@ -83,7 +83,8 @@ def _summarize_chunk(
|
|
|
|
)
|
|
|
|
)
|
|
|
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
|
|
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
|
|
|
max_score = max_score.detach()
|
|
|
|
max_score = max_score.detach()
|
|
|
|
torch.exp(attn_weights - max_score, out=attn_weights)
|
|
|
|
attn_weights -= max_score
|
|
|
|
|
|
|
|
torch.exp(attn_weights, out=attn_weights)
|
|
|
|
exp_weights = attn_weights.to(value.dtype)
|
|
|
|
exp_weights = attn_weights.to(value.dtype)
|
|
|
|
exp_values = torch.bmm(exp_weights, value)
|
|
|
|
exp_values = torch.bmm(exp_weights, value)
|
|
|
|
max_score = max_score.squeeze(-1)
|
|
|
|
max_score = max_score.squeeze(-1)
|
|
|
|