|
|
|
@ -979,7 +979,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|
|
|
|
if torch.version.hip:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
props = torch.cuda.get_device_properties("cuda")
|
|
|
|
|
props = torch.cuda.get_device_properties(device)
|
|
|
|
|
if props.major >= 8:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
@ -1035,7 +1035,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|
|
|
|
if is_intel_xpu():
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
props = torch.cuda.get_device_properties("cuda")
|
|
|
|
|
props = torch.cuda.get_device_properties(device)
|
|
|
|
|
if props.major >= 8:
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|