|
|
@ -185,7 +185,7 @@ def should_use_fp16():
|
|
|
|
if torch.cuda.is_bf16_supported():
|
|
|
|
if torch.cuda.is_bf16_supported():
|
|
|
|
return True
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
props = torch.cuda.get_device_properties()
|
|
|
|
props = torch.cuda.get_device_properties("cuda")
|
|
|
|
if props.major < 7:
|
|
|
|
if props.major < 7:
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|