diff --git a/comfy/model_management.py b/comfy/model_management.py index c54a360..bcd86a0 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -649,12 +649,12 @@ def supports_cast(device, dtype): #TODO return True if dtype == torch.float16: return True - if is_device_mps(device): - return False if directml_enabled: #TODO: test this return False if dtype == torch.bfloat16: return True + if is_device_mps(device): + return False if dtype == torch.float8_e4m3fn: return True if dtype == torch.float8_e5m2: @@ -876,9 +876,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow return False - if device is not None: #TODO not sure about mps bf16 support + if device is not None: if is_device_mps(device): - return False + return True if FORCE_FP32: return False