@@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None):
     )
     return r1
 
+BROKEN_XFORMERS = False
+try:
+    x_vers = xformers.__version__
+    #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
+    BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
+except:
+    pass
 
 def attention_xformers(q, k, v, heads, mask=None):
     b, _, dim_head = q.shape
     dim_head //= heads
+    if BROKEN_XFORMERS:
+        if b * heads > 65535:
+            return attention_pytorch(q, k, v, heads, mask)
+
     q, k, v = map(
         lambda t: t.unsqueeze(3)
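
A minimal standalone sketch of the version guard in this hunk, assuming only that xformers is importable; the helper name use_xformers_for is hypothetical, while attention_pytorch and the b * heads > 65535 limit come from the patch itself. As a side note, str.startswith also accepts a tuple of prefixes, which condenses the chained or:

import xformers

# Same version guard as the hunk above, condensed: str.startswith takes a tuple.
BROKEN_XFORMERS = False
try:
    BROKEN_XFORMERS = xformers.__version__.startswith(("0.0.21", "0.0.22", "0.0.23"))
except Exception:
    pass

def use_xformers_for(b, heads):
    # Hypothetical helper: broken xformers releases hit a CUDA error once the
    # flattened batch dimension (b * heads) exceeds 65535, so those shapes
    # should take the attention_pytorch fallback instead.
    return not (BROKEN_XFORMERS and b * heads > 65535)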