|
|
|
@@ -8,9 +8,8 @@ from torch import Tensor, nn
|
|
|
|
|
from .math import attention, rope
|
|
|
|
|
import comfy.ops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EmbedND(nn.Module):
|
|
|
|
|
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
|
|
|
|
def __init__(self, dim: int, theta: int, axes_dim: list):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.dim = dim
|
|
|
|
|
self.theta = theta
|
|
|
|
@@ -79,7 +78,7 @@ class QKNorm(torch.nn.Module):
|
|
|
|
|
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
|
|
|
|
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
|
|
|
|
|
|
|
|
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
|
|
|
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
|
|
|
|
q = self.query_norm(q)
|
|
|
|
|
k = self.key_norm(k)
|
|
|
|
|
return q.to(v), k.to(v)
|
|
|
|
@@ -118,7 +117,7 @@ class Modulation(nn.Module):
|
|
|
|
|
self.multiplier = 6 if double else 3
|
|
|
|
|
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
|
|
|
|
|
|
|
|
|
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
|
|
|
|
def forward(self, vec: Tensor) -> tuple:
|
|
|
|
|
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
@@ -156,7 +155,7 @@ class DoubleStreamBlock(nn.Module):
|
|
|
|
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
|
|
|
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
|
|
|
|
|
img_mod1, img_mod2 = self.img_mod(vec)
|
|
|
|
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
|
|
|
|
|
|
|
|
@@ -203,7 +202,7 @@ class SingleStreamBlock(nn.Module):
|
|
|
|
|
hidden_size: int,
|
|
|
|
|
num_heads: int,
|
|
|
|
|
mlp_ratio: float = 4.0,
|
|
|
|
|
qk_scale: float | None = None,
|
|
|
|
|
qk_scale: float = None,
|
|
|
|
|
dtype=None,
|
|
|
|
|
device=None,
|
|
|
|
|
operations=None
|
|
|
|
|