diff --git a/comfy/model_management.py b/comfy/model_management.py index 994fcd8..ec80afe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor if model_params * 2 > free_model_memory: return fp8_dtype - if should_use_fp16(device=device, model_params=model_params, manual_cast=True): - if torch.float16 in supported_dtypes: - return torch.float16 - if should_use_bf16(device, model_params=model_params, manual_cast=True): - if torch.bfloat16 in supported_dtypes: - return torch.bfloat16 + for dt in supported_dtypes: + if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params): + if torch.float16 in supported_dtypes: + return torch.float16 + if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + + for dt in supported_dtypes: + if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True): + if torch.float16 in supported_dtypes: + return torch.float16 + if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 + return torch.float32 # None means no manual cast @@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo if bf16_supported and weight_dtype == torch.bfloat16: return None - if fp16_supported and torch.float16 in supported_dtypes: - return torch.float16 + for dt in supported_dtypes: + if dt == torch.float16 and fp16_supported: + return torch.float16 + if dt == torch.bfloat16 and bf16_supported: + return torch.bfloat16 - elif bf16_supported and torch.bfloat16 in supported_dtypes: - return torch.bfloat16 - else: - return torch.float32 + return torch.float32 def text_encoder_offload_device(): if args.gpu_only: