|
|
|
@@ -630,8 +630,14 @@ def supports_dtype(device, dtype): #TODO
|
|
|
|
|
def device_supports_non_blocking(device):
    """Return True if tensor copies to *device* may pass non_blocking=True.

    MPS is the one known exception: non-blocking transfers misbehave there
    (presumably a PyTorch bug), so it is excluded. Every other backend is
    assumed to handle non-blocking copies correctly.
    """
    # Collapse the original if/return pair into a single predicate.
    return not is_device_mps(device)
|
|
|
|
|
|
|
|
|
|
def device_should_use_non_blocking(device):
    """Return True if non-blocking transfers should actually be used for *device*.

    Distinct from device_supports_non_blocking(): even on devices where
    non-blocking copies are supported, they are currently disabled.
    """
    # Guard kept so the policy below can be flipped back on per-device later.
    if not device_supports_non_blocking(device):
        return False
    # Deliberately disabled even on capable devices: enabling it caused
    # memory issues on Nvidia and possibly other backends.
    # TODO: figure out why, then re-enable (previously: `return True`).
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def cast_to_device(tensor, device, dtype, copy=False):
|
|
|
|
|
device_supports_cast = False
|
|
|
|
@@ -643,7 +649,7 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|
|
|
|
elif is_intel_xpu():
|
|
|
|
|
device_supports_cast = True
|
|
|
|
|
|
|
|
|
|
non_blocking = device_supports_non_blocking(device)
|
|
|
|
|
non_blocking = device_should_use_non_blocking(device)
|
|
|
|
|
|
|
|
|
|
if device_supports_cast:
|
|
|
|
|
if copy:
|
|
|
|
|