|
|
@ -146,8 +146,17 @@ def _get_attention_scores_no_kv_chunking(
|
|
|
|
alpha=scale,
|
|
|
|
alpha=scale,
|
|
|
|
beta=0,
|
|
|
|
beta=0,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
attn_probs = attn_scores.softmax(dim=-1)
|
|
|
|
|
|
|
|
del attn_scores
|
|
|
|
try:
|
|
|
|
|
|
|
|
attn_probs = attn_scores.softmax(dim=-1)
|
|
|
|
|
|
|
|
del attn_scores
|
|
|
|
|
|
|
|
except torch.cuda.OutOfMemoryError:
|
|
|
|
|
|
|
|
print("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
|
|
|
|
|
|
|
torch.exp(attn_scores, out=attn_scores)
|
|
|
|
|
|
|
|
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
attn_scores /= summed
|
|
|
|
|
|
|
|
attn_probs = attn_scores
|
|
|
|
|
|
|
|
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
|
|
|
hidden_states_slice = torch.bmm(attn_probs, value)
|
|
|
|
return hidden_states_slice
|
|
|
|
return hidden_states_slice
|
|
|
|
|
|
|
|
|
|
|
|