Make xformers work with hypertile.

main
comfyanonymous 1 year ago
parent 1443caf373
commit 9906e3efe3

@ -253,12 +253,14 @@ def attention_split(q, k, v, heads, mask=None):
return r2
def attention_xformers(q, k, v, heads, mask=None):
b, _, _ = q.shape
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.unsqueeze(3)
.reshape(b, t.shape[1], heads, -1)
.reshape(b, -1, heads, dim_head)
.permute(0, 2, 1, 3)
.reshape(b * heads, t.shape[1], -1)
.reshape(b * heads, -1, dim_head)
.contiguous(),
(q, k, v),
)
@ -270,9 +272,9 @@ def attention_xformers(q, k, v, heads, mask=None):
raise NotImplementedError
out = (
out.unsqueeze(0)
.reshape(b, heads, out.shape[1], -1)
.reshape(b, heads, -1, dim_head)
.permute(0, 2, 1, 3)
.reshape(b, out.shape[1], -1)
.reshape(b, -1, heads * dim_head)
)
return out

Loading…
Cancel
Save