From 88733c997fd807a572d4a214d2c15fc5dd17b3c6 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 11 Oct 2023 21:29:03 -0400
Subject: [PATCH] pytorch_attention_enabled can now return True when xformers
 is enabled.

---
 comfy/ldm/modules/diffusionmodules/model.py | 2 +-
 comfy/model_management.py                   | 9 ++++++---
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index e6cf954..6576df4 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -355,7 +355,7 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
     assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
     if model_management.xformers_enabled_vae() and attn_type == "vanilla":
         attn_type = "vanilla-xformers"
-    if model_management.pytorch_attention_enabled() and attn_type == "vanilla":
+    elif model_management.pytorch_attention_enabled() and attn_type == "vanilla":
         attn_type = "vanilla-pytorch"
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3b43b21..3c390d9 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -154,14 +154,18 @@ def is_nvidia():
             return True
     return False

-ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
+ENABLE_PYTORCH_ATTENTION = False
+if args.use_pytorch_cross_attention:
+    ENABLE_PYTORCH_ATTENTION = True
+    XFORMERS_IS_AVAILABLE = False
+
 VAE_DTYPE = torch.float32

 try:
     if is_nvidia():
         torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
-            if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+            if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
             if torch.cuda.is_bf16_supported():
                 VAE_DTYPE = torch.bfloat16
@@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
-    XFORMERS_IS_AVAILABLE = False

 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM