@@ -154,14 +154,18 @@ def is_nvidia():
             return True
     return False
 
-ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+ENABLE_PYTORCH_ATTENTION = False
+if args.use_pytorch_cross_attention:
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILABLE = False
+
 VAE_DTYPE = torch.float32
 
 try:
     if is_nvidia():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
-            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
             if torch.cuda.is_bf16_supported():
                 VAE_DTYPE = torch.bfloat16
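
Net effect of the hunk above: PyTorch cross attention is no longer tied
one-to-one to the --use-pytorch-cross-attention flag. The flag still forces
it on (and disables xformers), but on NVIDIA with torch 2.x it is now
auto-enabled even when xformers is available, unless split or quad attention
was requested. A minimal distillation of the patched decision, where
pytorch_attention_enabled and nvidia_torch2 are hypothetical names standing
in for the checks above, not identifiers from the diff:

def pytorch_attention_enabled(use_pytorch_flag, use_split, use_quad, nvidia_torch2):
    # Hypothetical sketch of the patched logic, not code from the diff.
    if use_pytorch_flag:                   # explicit flag always wins
        return True
    if nvidia_torch2 and not (use_split or use_quad):
        return True                        # new: no longer requires xformers to be absent
    return False
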
@@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    XFORMERS_IS_AVAILABLE = False
 
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
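
The removed assignment is redundant after the first hunk, which already
clears XFORMERS_IS_AVAILABLE when the flag is set. The three enable_*_sdp
toggles kept here opt in to all of PyTorch 2.x's scaled-dot-product-attention
backends (flash, memory-efficient, and the math fallback); PyTorch then
dispatches per call based on device, dtype, and shapes. A sketch of what that
looks like to caller code, assuming torch 2.x on CUDA (tensor shapes are
arbitrary illustrations, not values from the diff):

import torch
import torch.nn.functional as F

# Allow every SDP backend, as the patched block does.
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)

if torch.cuda.is_available():
    # (batch, heads, sequence, head_dim) fp16 tensors; sizes are arbitrary.
    q, k, v = (torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    out = F.scaled_dot_product_attention(q, k, v)  # backend picked automatically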