@@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype
 
-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
 
     return torch.float32
 
 # None means no manual cast
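Not part of the patch: a minimal standalone sketch of the selection order the reworked unet_dtype hunk above implements. pick_unet_dtype and its boolean arguments are hypothetical stand-ins for the real should_use_fp16/should_use_bf16 checks, so the ordering logic can be run in isolation. The first pass only accepts dtypes usable without manual casting, the second pass allows a manual cast, and within each pass the order of supported_dtypes decides ties.

import torch

def pick_unet_dtype(supported_dtypes, fp16_native, bf16_native, fp16_cast, bf16_cast):
    # Hypothetical stand-in for the new ordering logic, not part of the patched module.
    # First pass: dtypes the device can run without manual casting, tried in
    # the order the model lists them.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_native:
            return torch.float16
        if dt == torch.bfloat16 and bf16_native:
            return torch.bfloat16
    # Second pass: same order, but a manual cast at inference time is allowed.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_cast:
            return torch.float16
        if dt == torch.bfloat16 and bf16_cast:
            return torch.bfloat16
    # Nothing matched: fall back to fp32.
    return torch.float32

# A model that lists bf16 first now gets bf16 even when fp16 would also work;
# the removed standalone if-blocks always checked fp16 first.
assert pick_unet_dtype([torch.bfloat16, torch.float16], True, True, True, True) == torch.bfloat16
assert pick_unet_dtype([torch.float16, torch.bfloat16], True, True, True, True) == torch.float16
# Native checks fail but a manual cast to fp16 is acceptable: second pass picks fp16.
assert pick_unet_dtype([torch.float16, torch.bfloat16], False, False, True, False) == torch.float16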
@@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None
 
-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16
+
+    return torch.float32
 
 def text_encoder_offload_device():
     if args.gpu_only:
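Not part of the patch: the same idea for the unet_manual_cast hunk, sketched below. The old if/elif chain always preferred fp16 as the manual-cast dtype, while the new loop walks supported_dtypes so the model's preferred compute dtype wins when both fp16 and bf16 are available. pick_manual_cast is a hypothetical stand-in; fp16_supported/bf16_supported mirror the flags computed earlier in the real function.

import torch

def pick_manual_cast(supported_dtypes, fp16_supported, bf16_supported):
    # Hypothetical stand-in, not part of the patched module: walk the model's
    # preferred dtypes in order and return the first one the device can compute
    # in, mirroring the loop that replaces the fp16-first if/elif chain.
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16
        if dt == torch.bfloat16 and bf16_supported:
            return torch.bfloat16
    return torch.float32

# Model prefers bf16 and both compute dtypes are available: the old chain
# returned fp16, the reworked selection returns bf16.
assert pick_manual_cast([torch.bfloat16, torch.float16], True, True) == torch.bfloat16
assert pick_manual_cast([torch.float16, torch.bfloat16], True, True) == torch.float16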