diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 947e200..d511dda 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
         )
     return r1
 
+BROKEN_XFORMERS = False
+try:
+    x_vers = xformers.__version__
+    #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
+    BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
+except:
+    pass
+
 def attention_xformers(q, k, v, heads, mask=None):
     b, _, dim_head = q.shape
     dim_head //= heads
+    if BROKEN_XFORMERS:
+        if b * heads > 65535:
+            return attention_pytorch(q, k, v, heads, mask)
     q, k, v = map(
         lambda t: t.unsqueeze(3)
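
For reference, here is a minimal standalone sketch of the fallback pattern this diff introduces: detect the affected xformers versions once at import time, then route oversized batches to a plain PyTorch attention path. The body of `attention_pytorch` below is an assumed stand-in built on `torch.nn.functional.scaled_dot_product_attention`, not ComfyUI's actual helper; shapes follow the `(batch, seq, heads * dim_head)` convention visible in the hunk.

```python
import torch

BROKEN_XFORMERS = False
try:
    import xformers
    # Versions 0.0.21-0.0.23 raise a CUDA error when the flattened batch
    # (batch * heads) handed to xformers exceeds 65535.
    BROKEN_XFORMERS = xformers.__version__.startswith(("0.0.21", "0.0.22", "0.0.23"))
except ImportError:
    pass

def attention_pytorch(q, k, v, heads, mask=None):
    # Assumed stand-in for ComfyUI's fallback path: reshape
    # (b, seq, heads * dim_head) -> (b, heads, seq, dim_head), run SDPA,
    # then reshape back. Not the actual implementation from the repo.
    b, _, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

def attention_xformers(q, k, v, heads, mask=None):
    b = q.shape[0]
    # The guard added by the diff: skip the broken xformers kernels whenever
    # the effective batch (b * heads) would cross the 65535 threshold.
    if BROKEN_XFORMERS and b * heads > 65535:
        return attention_pytorch(q, k, v, heads, mask)
    ...  # xformers.ops.memory_efficient_attention path elided
```

The 65535 cutoff matches the CUDA limit on the y/z grid dimensions (xformers kernels place the flattened batch on a grid axis), which is the likely source of the error, though the diff itself does not state this.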