From 6908f9c94992b32fbb96be0f6cd8c5b362d72a77 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Sat, 22 Apr 2023 14:30:39 -0400
Subject: [PATCH] This makes pytorch2.0 attention perform a bit faster.

---
 comfy/ldm/modules/attention.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 98dbda6..c27d032 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -455,11 +455,7 @@ class CrossAttentionPytorch(nn.Module):
 
         b, _, _ = q.shape
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(b, t.shape[1], self.heads, self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b * self.heads, t.shape[1], self.dim_head)
-            .contiguous(),
+            lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2),
            (q, k, v),
         )
 
@@ -468,10 +464,7 @@ class CrossAttentionPytorch(nn.Module):
         if exists(mask):
             raise NotImplementedError
         out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
+            out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
         )
         return self.to_out(out)
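
Note (not part of the patch): a minimal standalone sketch of why the new reshape path is cheaper. The old chain materializes a (b * heads, seq, dim_head) copy via .contiguous(), whereas .view() plus .transpose() only adjust stride metadata, and torch.nn.functional.scaled_dot_product_attention accepts the 4-D (batch, heads, seq, dim_head) layout directly, so the flattening and copy are unnecessary. The shapes below are illustrative, not taken from the patch:

    import torch

    b, seq_len, heads, dim_head = 2, 77, 8, 64
    t = torch.randn(b, seq_len, heads * dim_head)

    # Old path: flatten batch and heads into one leading dim,
    # forcing a data copy via .contiguous() -> (b * heads, seq_len, dim_head).
    old = (
        t.unsqueeze(3)
        .reshape(b, seq_len, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, seq_len, dim_head)
        .contiguous()
    )

    # New path: keep batch and heads as separate dims
    # -> (b, heads, seq_len, dim_head); no data is copied.
    new = t.view(b, -1, heads, dim_head).transpose(1, 2)

    # Both layouts hold the same values, just organized differently.
    assert torch.equal(old.view(b, heads, seq_len, dim_head), new)

Because the 4-D layout goes into scaled_dot_product_attention unchanged, the output side likewise shrinks to a single transpose(1, 2).reshape(b, -1, heads * dim_head) that restores (b, seq, heads * dim_head) for the output projection.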