@@ -668,6 +668,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None
 
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
     for dt in supported_dtypes:
         if dt == torch.float16 and fp16_supported:
             return torch.float16
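For context, a minimal runnable sketch of the dtype-selection logic this hunk touches. It models only the branches visible in the context lines; the bf16/fp32 fallbacks at the end are assumptions about the rest of the function, which the hunk does not show, and `fp16_ok`/`bf16_ok` are hypothetical stand-ins for the results of ComfyUI's `should_use_fp16(...)` and `should_use_bf16(...)` helpers. The added line re-checks fp16 support with `prioritize_performance=True` right before the loop, so the loop's `fp16_supported` flag reflects that re-check rather than whatever value was computed earlier in the function.

```python
import torch

def pick_manual_cast_dtype(weight_dtype, fp16_ok, bf16_ok,
                           supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
    # Sketch only: branches taken from the hunk's context lines; the bf16 and
    # fp32 fallbacks below are assumed, not shown in the diff.
    if bf16_ok and weight_dtype == torch.bfloat16:
        return None  # weights already run natively, no manual cast needed

    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_ok:
            return torch.float16
        if dt == torch.bfloat16 and bf16_ok:
            return torch.bfloat16
    return torch.float32

# bf16 weights on a device with weak bf16 but fast fp16 support:
# the loop picks fp16 as the manual-cast compute dtype.
assert pick_manual_cast_dtype(torch.bfloat16, fp16_ok=True, bf16_ok=False) == torch.float16
```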