@@ -12,8 +12,6 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 import comfy.ops
-
-from . import tomesd
 
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
@@ -519,23 +517,39 @@ class BasicTransformerBlock(nn.Module):
         self.norm2 = nn.LayerNorm(dim, dtype=dtype)
         self.norm3 = nn.LayerNorm(dim, dtype=dtype)
         self.checkpoint = checkpoint
+        self.n_heads = n_heads
+        self.d_head = d_head
 
     def forward(self, x, context=None, transformer_options={}):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
         extra_options = {}
+        block = None
+        block_index = 0
         if "current_index" in transformer_options:
             extra_options["transformer_index"] = transformer_options["current_index"]
         if "block_index" in transformer_options:
-            extra_options["block_index"] = transformer_options["block_index"]
+            block_index = transformer_options["block_index"]
+            extra_options["block_index"] = block_index
         if "original_shape" in transformer_options:
             extra_options["original_shape"] = transformer_options["original_shape"]
+        if "block" in transformer_options:
+            block = transformer_options["block"]
+            extra_options["block"] = block
         if "patches" in transformer_options:
             transformer_patches = transformer_options["patches"]
         else:
             transformer_patches = {}
 
+        extra_options["n_heads"] = self.n_heads
+        extra_options["dim_head"] = self.d_head
+
+        if "patches_replace" in transformer_options:
+            transformer_patches_replace = transformer_options["patches_replace"]
+        else:
+            transformer_patches_replace = {}
+
         n = self.norm1(x)
         if self.disable_self_attn:
             context_attn1 = context
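
# Illustrative sketch (not part of the diff): the shape of the
# transformer_options dict this hunk now consumes. The keys are the ones read
# above; the literal values are made-up examples, and the dict is normally
# assembled by the sampling code / ModelPatcher rather than written by hand.
example_transformer_options = {
    "current_index": 0,                # running index over all transformer blocks
    "block": ("input", 1),             # (block type, block number) of this block
    "block_index": 0,                  # index of this BasicTransformerBlock within the block
    "original_shape": [2, 4, 64, 64],  # latent shape before flattening to tokens
    "patches": {},                     # e.g. "attn1_patch", "attn1_output_patch", "middle_patch"
    "patches_replace": {               # new in this diff: full attention replacements
        "attn1": {},                   # {block or (block type, number, block_index): callable}
        "attn2": {},
    },
}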
@@ -551,12 +565,29 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
 
-        if "tomesd" in transformer_options:
-            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
-            n = u(self.attn1(m(n), context=context_attn1, value=value_attn1))
+        transformer_block = (block[0], block[1], block_index)
+        attn1_replace_patch = transformer_patches_replace.get("attn1", {})
+        block_attn1 = transformer_block
+        if block_attn1 not in attn1_replace_patch:
+            block_attn1 = block
+
+        if block_attn1 in attn1_replace_patch:
+            if context_attn1 is None:
+                context_attn1 = n
+                value_attn1 = n
+            n = self.attn1.to_q(n)
+            context_attn1 = self.attn1.to_k(context_attn1)
+            value_attn1 = self.attn1.to_v(value_attn1)
+            n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
+            n = self.attn1.to_out(n)
         else:
             n = self.attn1(n, context=context_attn1, value=value_attn1)
 
+        if "attn1_output_patch" in transformer_patches:
+            patch = transformer_patches["attn1_output_patch"]
+            for p in patch:
+                n = p(n, extra_options)
+
         x += n
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
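
# Sketch of what a replace patch could look like (hypothetical example, not in
# this diff). Per the hunk above, the callable receives the already-projected
# q, k, v plus extra_options, and must return the per-token output *before*
# self.attn1.to_out is applied. torch >= 2.0 is assumed for
# scaled_dot_product_attention.
import torch

def my_attn_replace(q, k, v, extra_options):
    heads = extra_options["n_heads"]
    dim_head = extra_options["dim_head"]
    b, seq, _ = q.shape
    # (b, tokens, heads * dim_head) -> (b, heads, tokens, dim_head)
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
    # plain attention as a stand-in for whatever custom behavior the patch implements
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # back to (b, tokens, heads * dim_head); to_out is applied by the caller
    return out.transpose(1, 2).reshape(b, seq, -1)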
@@ -573,7 +604,21 @@ class BasicTransformerBlock(nn.Module):
             for p in patch:
                 n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
 
-        n = self.attn2(n, context=context_attn2, value=value_attn2)
+        attn2_replace_patch = transformer_patches_replace.get("attn2", {})
+        block_attn2 = transformer_block
+        if block_attn2 not in attn2_replace_patch:
+            block_attn2 = block
+
+        if block_attn2 in attn2_replace_patch:
+            if value_attn2 is None:
+                value_attn2 = context_attn2
+            n = self.attn2.to_q(n)
+            context_attn2 = self.attn2.to_k(context_attn2)
+            value_attn2 = self.attn2.to_v(value_attn2)
+            n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
+            n = self.attn2.to_out(n)
+        else:
+            n = self.attn2(n, context=context_attn2, value=value_attn2)
 
         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
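
# How this would be driven from the other side (sketch): ComfyUI's ModelPatcher
# exposes matching set_model_attn1_replace / set_model_attn2_replace helpers
# that fill transformer_options["patches_replace"]. Treat the exact signatures
# as an assumption; the block coordinates mirror the
# (block[0], block[1], block_index) tuple used in the hunks above.
m = model.clone()  # "model" is the ModelPatcher input of a custom node
# replace attn2 for every BasicTransformerBlock in output block 3:
m.set_model_attn2_replace(my_attn_replace, "output", 3)
# replace attn1 only for transformer block_index 1 of middle block 0:
m.set_model_attn1_replace(my_attn_replace, "middle", 0, 1)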