Make xformers work with hypertile.

main
comfyanonymous 1 year ago
parent 1443caf373
commit 9906e3efe3

@@ -253,12 +253,14 @@ def attention_split(q, k, v, heads, mask=None):
     return r2

 def attention_xformers(q, k, v, heads, mask=None):
-    b, _, _ = q.shape
+    b, _, dim_head = q.shape
+    dim_head //= heads
     q, k, v = map(
         lambda t: t.unsqueeze(3)
-        .reshape(b, t.shape[1], heads, -1)
+        .reshape(b, -1, heads, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b * heads, t.shape[1], -1)
+        .reshape(b * heads, -1, dim_head)
         .contiguous(),
         (q, k, v),
     )
@@ -270,9 +272,9 @@ def attention_xformers(q, k, v, heads, mask=None):
         raise NotImplementedError
     out = (
         out.unsqueeze(0)
-        .reshape(b, heads, out.shape[1], -1)
+        .reshape(b, heads, -1, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b, out.shape[1], -1)
+        .reshape(b, -1, heads * dim_head)
     )
     return out

Loading…
Cancel
Save