@@ -351,8 +351,11 @@ else:
 optimized_attention_masked = optimized_attention
 
 def optimized_attention_for_device(device, mask=False, small_input=False):
-    if small_input and model_management.pytorch_attention_enabled():
-        return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
+    if small_input:
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
+        else:
+            return attention_basic
 
     if device == torch.device("cpu"):
         return attention_sub_quad
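For context, a minimal sketch of how the changed dispatcher behaves after this patch: with small_input=True it now always returns a small-input kernel (attention_pytorch when PyTorch attention is enabled, attention_basic otherwise) instead of falling through to the device-based selection. The dispatcher and function names come from the diff; the call site, the heads argument, and the (batch, seq_len, heads * dim_head) tensor shapes are assumptions based on how the attention functions in this module are typically invoked, not part of this patch.

    import torch

    # Hypothetical call site (not from this patch): choose an attention
    # implementation once per device, then reuse it for every layer call.
    device = torch.device("cpu")
    attn_fn = optimized_attention_for_device(device, small_input=True)
    # After this change, small_input=True yields attention_pytorch or
    # attention_basic; it no longer falls through to attention_sub_quad
    # when PyTorch attention is disabled.

    # Assumed interface: q, k, v shaped (batch, seq_len, heads * dim_head).
    q = torch.randn(1, 16, 8 * 64)
    k = torch.randn(1, 16, 8 * 64)
    v = torch.randn(1, 16, 8 * 64)
    out = attn_fn(q, k, v, heads=8)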