--disable-xformers should not even try to import xformers.

branch: main
comfyanonymous, 2 years ago
parent ebfa749b7b
commit a256a2abde

@@ -11,11 +11,10 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 import model_management

-try:
+if model_management.xformers_enabled():
     import xformers
     import xformers.ops
-except:
-    pass

 # CrossAttn precision handling
 import os
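The guard above assumes model_management exposes an xformers_enabled() accessor; that function is not part of this diff. A minimal sketch of what it might look like, assuming it simply reports the availability flag that model_management sets at startup:

# Hypothetical sketch; not shown in this commit. Assumes the module-level
# flag XFORMERS_IS_AVAILBLE (spelling as in the repository) was set once
# when model_management was first imported.
def xformers_enabled():
    return XFORMERS_IS_AVAILBLE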

@@ -9,11 +9,9 @@ from typing import Optional, Any
 from ldm.modules.attention import MemoryEfficientCrossAttention
 import model_management

-try:
+if model_management.xformers_enabled():
     import xformers
     import xformers.ops
-except:
-    pass

 try:
     OOM_EXCEPTION = torch.cuda.OutOfMemoryError
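The trailing context lines show the same feature-detection idiom applied to the out-of-memory exception class. A self-contained sketch of that pattern; the generic fallback class here is an assumption, since the diff cuts off before the except branch:

import torch

# On torch builds that predate torch.cuda.OutOfMemoryError, fall back to a
# broader class so `except OOM_EXCEPTION:` handlers still compile and run.
try:
    OOM_EXCEPTION = torch.cuda.OutOfMemoryError
except:
    OOM_EXCEPTION = Exception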

@@ -31,6 +31,9 @@ try:
 except:
     pass

+if "--disable-xformers" in sys.argv:
+    XFORMERS_IS_AVAILBLE = False
+else:
     try:
         import xformers
         import xformers.ops
@@ -38,8 +41,6 @@
     except:
         XFORMERS_IS_AVAILBLE = False

-if "--disable-xformers" in sys.argv:
-    XFORMERS_IS_AVAILBLE = False

 if "--cpu" in sys.argv:
     vram_state = CPU
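Taken together, these two hunks reorder model_management so the command-line check runs before any import is attempted: previously `import xformers` executed unconditionally and the flag merely reset the resulting value afterwards. A minimal sketch of the post-commit control flow, assuming module-level execution at startup (variable spelling XFORMERS_IS_AVAILBLE as in the repository):

import sys

# Consult the flag first: with --disable-xformers present, the process
# never pays xformers' import cost or triggers its import-time warnings.
if "--disable-xformers" in sys.argv:
    XFORMERS_IS_AVAILBLE = False
else:
    try:
        import xformers
        import xformers.ops
        XFORMERS_IS_AVAILBLE = True
    except:
        XFORMERS_IS_AVAILBLE = False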

@@ -8,9 +8,6 @@ if os.name == "nt":
     import logging
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

-import execution
-import server
-
 if __name__ == "__main__":
     if '--help' in sys.argv:
         print("Valid Command line Arguments:")
@@ -18,6 +15,7 @@ if __name__ == "__main__":
         print("\t--port 8188\t\t\tSet the listen port.")
         print("\t--dont-upcast-attention\t\tDisable upcasting of attention \n\t\t\t\t\tcan boost speed but increase the chances of black images.\n")
         print("\t--use-split-cross-attention\tUse the split cross attention optimization instead of the sub-quadratic one.\n\t\t\t\t\tIgnored when xformers is used.")
+        print("\t--disable-xformers\t\tdisables xformers")
         print()
         print("\t--highvram\t\t\tBy default models will be unloaded to CPU memory after being used.\n\t\t\t\t\tThis option keeps them in GPU memory.\n")
         print("\t--normalvram\t\t\tUsed to force normal vram use if lowvram gets automatically enabled.")
@@ -31,6 +29,9 @@ if __name__ == "__main__":
         print("disabling upcasting of attention")
         os.environ['ATTN_PRECISION'] = "fp16"

+import execution
+import server
+
 def prompt_worker(q, server):
     e = execution.PromptExecutor(server)
     while True:
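A likely motivation for moving import execution and import server below the flag handling: os.environ['ATTN_PRECISION'] is now set before any module that reads it at import time gets loaded. A self-contained sketch of that ordering, with load_engine() as a hypothetical stand-in for the import-time work:

import os
import sys

def load_engine():
    # Stand-in for `import execution` / `import server`: returns the value
    # that module-level code would capture when it is first imported.
    return os.environ.get("ATTN_PRECISION", "fp32")

if __name__ == "__main__":
    # 1. Derive process-wide state from the command line first.
    if "--dont-upcast-attention" in sys.argv:
        os.environ["ATTN_PRECISION"] = "fp16"
    # 2. Only then load code that reads that state; run with the flag and
    #    this prints "fp16" instead of the "fp32" default.
    print(load_engine())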
