Make sub_quad and split work with hypertile.

main
comfyanonymous 1 year ago
parent 8cfce083c4
commit e6bc42df46

@@ -124,11 +124,14 @@ def attention_basic(q, k, v, heads, mask=None):
 def attention_sub_quad(query, key, value, heads, mask=None):
-    scale = (query.shape[-1] // heads) ** -0.5
-    query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
-    key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1)
-    del key
-    value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1)
+    b, _, dim_head = query.shape
+    dim_head //= heads
+
+    scale = dim_head ** -0.5
+    query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+    value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
+    key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

     dtype = query.dtype
     upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
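A minimal standalone sketch (not part of the commit, shapes are made up) checking that the new unsqueeze/reshape/permute split in the hunk above produces the same (b*heads, tokens, dim_head) query layout and (b*heads, dim_head, tokens) key layout as the unflatten/transpose path it replaces:

import torch

b, n, heads, dim_head = 2, 16, 8, 64  # hypothetical sizes
query = torch.randn(b, n, heads * dim_head)
key = torch.randn(b, n, heads * dim_head)

# old path: (b, n, heads*dim_head) -> (b*heads, n, dim_head)
old_q = query.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)
# new path from the hunk above: same layout via reshape/permute
new_q = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)

# old code stored the pre-transposed (b*heads, dim_head, n) key as key_t
old_k = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)
# new code keeps the name "key" but uses the same transposed layout
new_k = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)

assert torch.equal(old_q, new_q)
assert torch.equal(old_k, new_k)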
@@ -137,7 +140,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     else:
         bytes_per_token = torch.finfo(query.dtype).bits//8
     batch_x_heads, q_tokens, _ = query.shape
-    _, _, k_tokens = key_t.shape
+    _, _, k_tokens = key.shape
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

     mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
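For scale, a worked example of the size estimate the hunk above reads from (hypothetical fp16 tensors, not from the commit): the attention score matrix holds batch_x_heads * q_tokens * k_tokens entries at bytes_per_token bytes each.

import torch

# hypothetical shapes: (b*heads, q_tokens, dim_head) and (b*heads, dim_head, k_tokens)
query = torch.empty(16, 4096, 64, dtype=torch.float16)
key = torch.empty(16, 64, 4096, dtype=torch.float16)

bytes_per_token = torch.finfo(query.dtype).bits // 8   # 2 bytes per element for fp16
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape

qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
print(qk_matmul_size_bytes / 2**20)                     # 16 * 2 * 4096 * 4096 bytes = 512.0 MiB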
@@ -171,7 +174,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     hidden_states = efficient_dot_product_attention(
         query,
-        key_t,
+        key,
         value,
         query_chunk_size=query_chunk_size,
         kv_chunk_size=kv_chunk_size,
@@ -186,9 +189,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
     return hidden_states

 def attention_split(q, k, v, heads, mask=None):
-    scale = (q.shape[-1] // heads) ** -0.5
+    b, _, dim_head = q.shape
+    dim_head //= heads
+    scale = dim_head ** -0.5

     h = heads
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+    q, k, v = map(
+        lambda t: t.unsqueeze(3)
+        .reshape(b, -1, heads, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b * heads, -1, dim_head)
+        .contiguous(),
+        (q, k, v),
+    )

     r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
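A minimal sketch (not part of the commit, made-up shapes) checking that the new map()-based split, including the trailing .contiguous(), matches the einops rearrange it replaces in attention_split:

import torch
from einops import rearrange

b, n, heads, dim_head = 2, 16, 8, 64  # hypothetical sizes
q, k, v = (torch.randn(b, n, heads * dim_head) for _ in range(3))

old_q, old_k, old_v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v))
new_q, new_k, new_v = map(
    lambda t: t.unsqueeze(3)
    .reshape(b, -1, heads, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b * heads, -1, dim_head)
    .contiguous(),
    (q, k, v),
)

for old, new in zip((old_q, old_k, old_v), (new_q, new_k, new_v)):
    assert torch.equal(old, new) and new.is_contiguous()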
@@ -248,9 +261,13 @@ def attention_split(q, k, v, heads, mask=None):
     del q, k, v

-    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
-    del r1
-    return r2
+    r1 = (
+        r1.unsqueeze(0)
+        .reshape(b, heads, -1, dim_head)
+        .permute(0, 2, 1, 3)
+        .reshape(b, -1, heads * dim_head)
+    )
+    return r1

 def attention_xformers(q, k, v, heads, mask=None):
     b, _, dim_head = q.shape
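And the inverse merge, a standalone sketch (made-up shapes, not part of the commit) checking that the final reshape back to (b, tokens, heads*dim_head) matches the rearrange it replaces:

import torch
from einops import rearrange

b, n, heads, dim_head = 2, 16, 8, 64  # hypothetical sizes
r1 = torch.randn(b * heads, n, dim_head)

old = rearrange(r1, '(b h) n d -> b n (h d)', h=heads)
new = (
    r1.unsqueeze(0)
    .reshape(b, heads, -1, dim_head)
    .permute(0, 2, 1, 3)
    .reshape(b, -1, heads * dim_head)
)
assert torch.equal(old, new)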
