diff --git a/comfy/model_management.py b/comfy/model_management.py index 1050c13..8b89637 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -488,6 +488,8 @@ def cast_to_device(tensor, device, dtype, copy=False): elif tensor.dtype == torch.bfloat16: if hasattr(device, 'type') and device.type.startswith("cuda"): device_supports_cast = True + elif is_intel_xpu(): + device_supports_cast = True if device_supports_cast: if copy: