From 9906e3efe31a0fc399262766da33c210fb4e8215 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 21 Oct 2023 13:23:03 -0400
Subject: [PATCH] Make xformers work with hypertile.

---
 comfy/ldm/modules/attention.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 9cd14a5..4eda361 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -253,12 +253,14 @@ def attention_split(q, k, v, heads, mask=None):
     return r2
 
 def attention_xformers(q, k, v, heads, mask=None):
-    b, _, _ = q.shape
+    b, _, dim_head = q.shape
+    dim_head //= heads
+
     q, k, v = map(
         lambda t: t.unsqueeze(3)
-        .reshape(b, t.shape[1], heads, -1)
+        .reshape(b, -1, heads, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b * heads, t.shape[1], -1)
+        .reshape(b * heads, -1, dim_head)
         .contiguous(),
         (q, k, v),
     )
@@ -270,9 +272,9 @@ def attention_xformers(q, k, v, heads, mask=None):
         raise NotImplementedError
     out = (
         out.unsqueeze(0)
-        .reshape(b, heads, out.shape[1], -1)
+        .reshape(b, heads, -1, dim_head)
         .permute(0, 2, 1, 3)
-        .reshape(b, out.shape[1], -1)
+        .reshape(b, -1, heads * dim_head)
     )
     return out
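
A minimal sketch (not part of the patch) of why the shape handling had to change. As I understand the HyperTile patch, it rearranges q into more batch entries with fewer tokens each while leaving k and v untouched, so the old head split, which hard-coded t.shape[1] and inferred the head size with -1, silently resolved to the wrong dim_head for k and v. Fixing dim_head from the channel dimension and letting -1 infer the token count keeps all three tensors consistent. The shapes below are hypothetical; only torch is assumed:

    import torch

    heads, dim_head = 8, 40
    b, seq, tiles = 2, 1024, 4

    q = torch.randn(b * tiles, seq // tiles, heads * dim_head)  # re-tiled by HyperTile
    k = torch.randn(b, seq, heads * dim_head)                   # left untouched

    bq = q.shape[0]  # what `b, _, dim_head = q.shape` sees: 8, not 2

    def split_old(t):
        # pre-patch: token count hard-coded, head size inferred with -1
        return (t.unsqueeze(3)
                 .reshape(bq, t.shape[1], heads, -1)
                 .permute(0, 2, 1, 3)
                 .reshape(bq * heads, t.shape[1], -1))

    def split_new(t):
        # post-patch: head size fixed, token count inferred with -1
        return (t.reshape(bq, -1, heads, dim_head)
                 .permute(0, 2, 1, 3)
                 .reshape(bq * heads, -1, dim_head))

    print(split_old(q).shape, split_old(k).shape)  # (64, 256, 40) vs (64, 1024, 10): head dims disagree
    print(split_new(q).shape, split_new(k).shape)  # (64, 256, 40) vs (64, 256, 40): consistent

With the new reshape, k's extra tokens fold into the tile batches, so each tile attends over its own slice of k and v, and the shape derivation mirrors what attention_pytorch in the same file already does.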