Basic Flux Schnell and Flux Dev model implementation.
parent
7ad574bffd
commit
1589b58d3e
@ -0,0 +1,257 @@
|
|||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .math import attention, rope
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedND(nn.Module):
|
||||||
|
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
|
def forward(self, ids: Tensor) -> Tensor:
|
||||||
|
n_axes = ids.shape[-1]
|
||||||
|
emb = torch.cat(
|
||||||
|
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||||
|
dim=-3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return emb.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an (N, D) Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
t = time_factor * t
|
||||||
|
half = dim // 2
|
||||||
|
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||||
|
t.device
|
||||||
|
)
|
||||||
|
|
||||||
|
args = t[:, None].float() * freqs[None]
|
||||||
|
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
if torch.is_floating_point(t):
|
||||||
|
embedding = embedding.to(t)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEmbedder(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return self.out_layer(self.silu(self.in_layer(x)))
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor):
|
||||||
|
x_dtype = x.dtype
|
||||||
|
x = x.float()
|
||||||
|
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
|
class QKNorm(torch.nn.Module):
|
||||||
|
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
q = self.query_norm(q)
|
||||||
|
k = self.key_norm(k)
|
||||||
|
return q.to(v), k.to(v)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||||
|
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
qkv = self.qkv(x)
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
x = attention(q, k, v, pe=pe)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModulationOut:
|
||||||
|
shift: Tensor
|
||||||
|
scale: Tensor
|
||||||
|
gate: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.is_double = double
|
||||||
|
self.multiplier = 6 if double else 3
|
||||||
|
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
||||||
|
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
||||||
|
|
||||||
|
return (
|
||||||
|
ModulationOut(*out[:3]),
|
||||||
|
ModulationOut(*out[3:]) if self.is_double else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoubleStreamBlock(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.img_mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.txt_mlp = nn.Sequential(
|
||||||
|
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||||
|
nn.GELU(approximate="tanh"),
|
||||||
|
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
img_mod1, img_mod2 = self.img_mod(vec)
|
||||||
|
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||||
|
|
||||||
|
# prepare image for attention
|
||||||
|
img_modulated = self.img_norm1(img)
|
||||||
|
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||||
|
img_qkv = self.img_attn.qkv(img_modulated)
|
||||||
|
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||||
|
|
||||||
|
# prepare txt for attention
|
||||||
|
txt_modulated = self.txt_norm1(txt)
|
||||||
|
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||||
|
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||||
|
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||||
|
|
||||||
|
# run actual attention
|
||||||
|
q = torch.cat((txt_q, img_q), dim=2)
|
||||||
|
k = torch.cat((txt_k, img_k), dim=2)
|
||||||
|
v = torch.cat((txt_v, img_v), dim=2)
|
||||||
|
|
||||||
|
attn = attention(q, k, v, pe=pe)
|
||||||
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||||
|
|
||||||
|
# calculate the img bloks
|
||||||
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||||
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||||
|
|
||||||
|
# calculate the txt bloks
|
||||||
|
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||||
|
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||||
|
return img, txt
|
||||||
|
|
||||||
|
|
||||||
|
class SingleStreamBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
A DiT block with parallel linear layers as described in
|
||||||
|
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
mlp_ratio: float = 4.0,
|
||||||
|
qk_scale: float | None = None,
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_dim = hidden_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = hidden_size // num_heads
|
||||||
|
self.scale = qk_scale or head_dim**-0.5
|
||||||
|
|
||||||
|
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||||
|
# qkv and mlp_in
|
||||||
|
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||||
|
# proj and mlp_out
|
||||||
|
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.mlp_act = nn.GELU(approximate="tanh")
|
||||||
|
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
mod, _ = self.modulation(vec)
|
||||||
|
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||||
|
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||||
|
|
||||||
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||||
|
q, k = self.norm(q, k, v)
|
||||||
|
|
||||||
|
# compute attention
|
||||||
|
attn = attention(q, k, v, pe=pe)
|
||||||
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||||
|
return x + mod.gate * output
|
||||||
|
|
||||||
|
|
||||||
|
class LastLayer(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||||
|
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||||
|
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||||
|
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||||
|
x = self.linear(x)
|
||||||
|
return x
|
@ -0,0 +1,29 @@
|
|||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from torch import Tensor
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
|
||||||
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
||||||
|
q, k = apply_rope(q, k, pe)
|
||||||
|
|
||||||
|
heads = q.shape[1]
|
||||||
|
x = optimized_attention(q, k, v, heads, skip_reshape=True)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
|
assert dim % 2 == 0
|
||||||
|
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||||
|
omega = 1.0 / (theta**scale)
|
||||||
|
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||||
|
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||||
|
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
||||||
|
return out.float()
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
||||||
|
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||||
|
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||||
|
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||||
|
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||||
|
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
@ -0,0 +1,136 @@
|
|||||||
|
#Original code can be found on: https://github.com/black-forest-labs/flux
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from .layers import (
|
||||||
|
DoubleStreamBlock,
|
||||||
|
EmbedND,
|
||||||
|
LastLayer,
|
||||||
|
MLPEmbedder,
|
||||||
|
SingleStreamBlock,
|
||||||
|
timestep_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FluxParams:
|
||||||
|
in_channels: int
|
||||||
|
vec_in_dim: int
|
||||||
|
context_in_dim: int
|
||||||
|
hidden_size: int
|
||||||
|
mlp_ratio: float
|
||||||
|
num_heads: int
|
||||||
|
depth: int
|
||||||
|
depth_single_blocks: int
|
||||||
|
axes_dim: list[int]
|
||||||
|
theta: int
|
||||||
|
qkv_bias: bool
|
||||||
|
guidance_embed: bool
|
||||||
|
|
||||||
|
|
||||||
|
class Flux(nn.Module):
|
||||||
|
"""
|
||||||
|
Transformer model for flow matching on sequences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
params = FluxParams(**kwargs)
|
||||||
|
self.params = params
|
||||||
|
self.in_channels = params.in_channels
|
||||||
|
self.out_channels = self.in_channels
|
||||||
|
if params.hidden_size % params.num_heads != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||||
|
)
|
||||||
|
pe_dim = params.hidden_size // params.num_heads
|
||||||
|
if sum(params.axes_dim) != pe_dim:
|
||||||
|
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||||
|
self.hidden_size = params.hidden_size
|
||||||
|
self.num_heads = params.num_heads
|
||||||
|
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||||
|
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||||
|
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||||
|
self.guidance_in = (
|
||||||
|
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||||
|
)
|
||||||
|
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
self.double_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DoubleStreamBlock(
|
||||||
|
self.hidden_size,
|
||||||
|
self.num_heads,
|
||||||
|
mlp_ratio=params.mlp_ratio,
|
||||||
|
qkv_bias=params.qkv_bias,
|
||||||
|
dtype=dtype, device=device, operations=operations
|
||||||
|
)
|
||||||
|
for _ in range(params.depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.single_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||||
|
for _ in range(params.depth_single_blocks)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
def forward_orig(
|
||||||
|
self,
|
||||||
|
img: Tensor,
|
||||||
|
img_ids: Tensor,
|
||||||
|
txt: Tensor,
|
||||||
|
txt_ids: Tensor,
|
||||||
|
timesteps: Tensor,
|
||||||
|
y: Tensor,
|
||||||
|
guidance: Tensor | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
|
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||||
|
|
||||||
|
# running on sequences img
|
||||||
|
img = self.img_in(img)
|
||||||
|
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
|
||||||
|
if self.params.guidance_embed:
|
||||||
|
if guidance is None:
|
||||||
|
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||||
|
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||||
|
|
||||||
|
vec = vec + self.vector_in(y)
|
||||||
|
txt = self.txt_in(txt)
|
||||||
|
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
|
for block in self.double_blocks:
|
||||||
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||||
|
|
||||||
|
img = torch.cat((txt, img), 1)
|
||||||
|
for block in self.single_blocks:
|
||||||
|
img = block(img, vec=vec, pe=pe)
|
||||||
|
img = img[:, txt.shape[1] :, ...]
|
||||||
|
|
||||||
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
|
return img
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, y, guidance, **kwargs):
|
||||||
|
bs, c, h, w = x.shape
|
||||||
|
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
|
|
||||||
|
img_ids = torch.zeros((h // 2, w // 2, 3), device=x.device, dtype=x.dtype)
|
||||||
|
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=x.device, dtype=x.dtype)[:, None]
|
||||||
|
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=x.device, dtype=x.dtype)[None, :]
|
||||||
|
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
|
||||||
|
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h // 2, w=w // 2, ph=2, pw=2)
|
@ -0,0 +1,64 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
import comfy.text_encoders.t5
|
||||||
|
from transformers import T5TokenizerFast
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
|
||||||
|
class T5XXLModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
|
||||||
|
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
|
||||||
|
|
||||||
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
|
||||||
|
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxTokenizer:
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||||
|
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||||
|
out = {}
|
||||||
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
|
||||||
|
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def untokenize(self, token_weight_pair):
|
||||||
|
return self.clip_g.untokenize(token_weight_pair)
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class FluxClipModel(torch.nn.Module):
|
||||||
|
def __init__(self, device="cpu", dtype=None):
|
||||||
|
super().__init__()
|
||||||
|
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
||||||
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
||||||
|
self.dtypes = set([dtype])
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
self.clip_l.set_clip_options(options)
|
||||||
|
self.t5xxl.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
self.clip_l.reset_clip_options()
|
||||||
|
self.t5xxl.reset_clip_options()
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
token_weight_pairs_l = token_weight_pairs["l"]
|
||||||
|
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
|
||||||
|
|
||||||
|
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
|
||||||
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
|
||||||
|
return t5_out, l_pooled
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||||
|
return self.clip_l.load_sd(sd)
|
||||||
|
else:
|
||||||
|
return self.t5xxl.load_sd(sd)
|
||||||
|
|
Loading…
Reference in New Issue