@@ -527,6 +527,9 @@ def unet_inital_load_device(parameters, dtype):
     else:
         return cpu_dev
 
+def maximum_vram_for_weights(device=None):
+    return (get_total_memory(device) * 0.8 - minimum_inference_memory())
+
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if args.bf16_unet:
         return torch.bfloat16
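The first hunk adds `maximum_vram_for_weights`, a single budget for model weights: 80% of the device's total memory, minus the reserve `minimum_inference_memory()` keeps for activations. A minimal sketch of the arithmetic, with hypothetical stand-ins for the two helpers (the sizes below are illustrative, not from the diff):

```python
# Hypothetical stand-ins; in the real file these helpers come from
# elsewhere in model_management.py.
def get_total_memory(device=None):
    return 24 * 1024**3           # assume a 24 GiB GPU

def minimum_inference_memory():
    return 1024**3                # assume a 1 GiB reserve for activations

def maximum_vram_for_weights(device=None):
    # 80% of total VRAM minus the inference reserve is the weight budget.
    return (get_total_memory(device) * 0.8 - minimum_inference_memory())

print(maximum_vram_for_weights() / 1024**3)  # -> 18.2 GiB for weights
```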
@@ -536,6 +539,21 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet:
         return torch.float8_e5m2
+
+    fp8_dtype = None
+    try:
+        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
+            if dtype in supported_dtypes:
+                fp8_dtype = dtype
+                break
+    except:
+        pass
+
+    if fp8_dtype is not None:
+        free_model_memory = maximum_vram_for_weights(device)
+        if model_params * 2 > free_model_memory:
+            return fp8_dtype
+
     if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
         if torch.float16 in supported_dtypes:
             return torch.float16
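The block inserted into `unet_dtype` falls back to fp8 when fp16 weights would overflow the budget: `model_params * 2` is the fp16 footprint at 2 bytes per parameter, and the `try/except` guards torch builds where the float8 dtypes do not exist. A hedged sketch of the same decision as a standalone function; `pick_weight_dtype` and all the numbers are illustrative, not part of the diff:

```python
import torch  # needs a build with float8 dtypes (torch >= 2.1)

def pick_weight_dtype(model_params, free_model_memory,
                      supported_dtypes=(torch.float16, torch.float8_e4m3fn)):
    # As in the diff: take the first float8 dtype the model supports.
    fp8_dtype = None
    for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        if dtype in supported_dtypes:
            fp8_dtype = dtype
            break
    # model_params * 2 estimates fp16 weights (2 bytes per parameter);
    # fp8 halves that to 1 byte per parameter.
    if fp8_dtype is not None and model_params * 2 > free_model_memory:
        return fp8_dtype
    return torch.float16

budget = 8 * 1024**3                              # pretend 8 GiB weight budget
print(pick_weight_dtype(2_600_000_000, budget))   # fits -> torch.float16
print(pick_weight_dtype(12_000_000_000, budget))  # too big -> float8_e4m3fn
```

Note that `unet_dtype` only reaches this fallback after the explicit `args.fp8_e4m3fn_unet` / `args.fp8_e5m2_unet` overrides earlier in the function have declined to force a dtype.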
@@ -871,7 +889,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
         fp16_works = True
 
     if fp16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -920,7 +938,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
         bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
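The last two hunks point `should_use_fp16` and `should_use_bf16` at the same budget the fp8 check uses, so all three dtype decisions compare against one per-device number instead of a snapshot of currently free memory. A hypothetical comparison of the old and new formulas (all values made up):

```python
total_memory = 24 * 1024**3    # hypothetical 24 GiB device
free_memory = 20 * 1024**3     # hypothetical: 4 GiB already allocated
reserve = 1024**3              # hypothetical 1 GiB minimum_inference_memory()

old_budget = free_memory * 0.9 - reserve    # varies with current allocations
new_budget = total_memory * 0.8 - reserve   # fixed per device

# should_use_fp16/should_use_bf16 compare model_params * 4 (the fp32 weight
# footprint) against this budget, so a stable budget gives stable choices.
print(old_budget / 1024**3, new_budget / 1024**3)  # 17.0 vs ~18.2 GiB
```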