Basic hunyuan dit implementation. (#4102)

* Let tokenizers return weights to be stored in the saved checkpoint.
* Basic hunyuan dit implementation.
* Fix some resolutions not working.
* Support hydit checkpoint save.
* Init with right dtype.
* Switch to optimized attention in pooler.
* Fix black images on hunyuan dit.

Branch: main
Parent: f87810cd3e
Commit: a5f4292f9f
@@ -0,0 +1,219 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention


def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out


class CrossAttention(nn.Module):
    """
    Use QK Normalization.
    """
    def __init__(self,
                 qdim,
                 kdim,
                 num_heads,
                 qkv_bias=True,
                 qk_norm=False,
                 attn_drop=0.0,
                 proj_drop=0.0,
                 attn_precision=None,
                 device=None,
                 dtype=None,
                 operations=None,
                 ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.attn_precision = attn_precision
        self.qdim = qdim
        self.kdim = kdim
        self.num_heads = num_heads
        assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
        self.head_dim = self.qdim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)

        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s1, c = x.shape  # [b, s1, D]
        _, s2, c = y.shape  # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)  # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)  # [b, s, h, d]
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq

        q = q.transpose(-2, -3).contiguous()  # q -> B, L1, H, C - B, H, L1, C
        k = k.transpose(-2, -3).contiguous()  # k -> B, L2, H, C - B, H, C, L2
        v = v.transpose(-2, -3).contiguous()

        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

        out = self.out_proj(context)  # context.reshape - B, L1, -1
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple


class Attention(nn.Module):
    """
    We rename some layer names to align with flash attention
    """
    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
        self.attn_precision = attn_precision
        self.dim = dim
        self.num_heads = num_heads
        assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.head_dim = self.dim // num_heads
        # This assertion is aligned with flash attention
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        self.scale = self.head_dim ** -0.5

        # qkv --> Wqkv
        self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        # TODO: eps should be 1 / 65530 if using fp16
        self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis_img=None):
        B, N, C = x.shape
        qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # [3, b, h, s, d]
        q, k, v = qkv.unbind(0)  # [b, h, s, d]
        q = self.q_norm(q)  # [b, h, s, d]
        k = self.k_norm(k)  # [b, h, s, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
            assert qq.shape == q.shape and kk.shape == k.shape, \
                f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
        x = self.out_proj(x)
        x = self.proj_drop(x)

        out_tuple = (x,)

        return out_tuple
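For reference, a minimal sketch of how the RoPE helpers above are meant to be driven. The tensor sizes are illustrative, and the import path comfy.ldm.hydit.attn_layers is an assumption about where this file lands; the cos/sin tables are built the same way as get_1d_rotary_pos_embed(..., use_real=True) later in this diff.

import torch
from comfy.ldm.hydit.attn_layers import apply_rotary_emb  # assumed module path

# Illustrative sizes: batch 2, 16 tokens, 4 heads, head_dim 64.
B, S, H, D = 2, 16, 4, 64
xq = torch.randn(B, S, H, D)
xk = torch.randn(B, S, H, D)

# Real-valued RoPE tables shaped [S, D]: one angle per channel pair, expanded by repeat_interleave.
freqs = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))   # [D/2]
angles = torch.outer(torch.arange(S).float(), freqs)           # [S, D/2]
cos = angles.cos().repeat_interleave(2, dim=1)                 # [S, D]
sin = angles.sin().repeat_interleave(2, dim=1)                 # [S, D]

# head_first=False expects q/k laid out as [B, S, H, D], as in CrossAttention.forward above.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, (cos, sin), head_first=False)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape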
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention #TODO


class AttentionPool(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
        self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
        self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device)  # (L+1)NC

        q = self.q_proj(x[:1])
        k = self.k_proj(x)
        v = self.v_proj(x)

        batch_size = q.shape[1]
        head_dim = self.embed_dim // self.num_heads
        q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
        v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)

        attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)

        attn_output = self.c_proj(attn_output)
        return attn_output.squeeze(0)
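A quick shape check for the pooler above: it prepends the mean token, attends from that single query over the whole sequence, and returns one pooled vector per batch element. This sketch passes plain torch.nn as the operations namespace (ComfyUI normally supplies its own ops class), and the import path comfy.ldm.hydit.poolers is likewise an assumption.

import torch
import torch.nn as nn
from comfy.ldm.hydit.poolers import AttentionPool  # assumed module path

# Illustrative sizes: 77 tokens of 1024-dim features pooled down to a single 1024-dim vector.
pool = AttentionPool(spacial_dim=77, embed_dim=1024, num_heads=8, output_dim=1024, operations=nn)
pool.positional_embedding.data.normal_(std=0.02)  # the parameter is created uninitialized (torch.empty)

x = torch.randn(2, 77, 1024)   # [batch, tokens, channels]
pooled = pool(x)
print(pooled.shape)            # torch.Size([2, 1024])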
@@ -0,0 +1,224 @@
import torch
import numpy as np
from typing import Union


def _to_tuple(x):
    if isinstance(x, int):
        return x, x
    else:
        return x


def get_fill_resize_and_crop(src, tgt):
    th, tw = _to_tuple(tgt)
    h, w = _to_tuple(src)

    tr = th / tw  # base resolution
    r = h / w  # target resolution

    # resize
    if r > tr:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))  # resize the target resolution down based on the base resolution

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)


def get_meshgrid(start, *args):
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start)
        start = (0, 0)
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start)
        stop = _to_tuple(args[0])
        num = (stop[0] - start[0], stop[1] - start[1])
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start)
        stop = _to_tuple(args[0])
        num = _to_tuple(args[1])
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, W, H]
    return grid

#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = get_meshgrid(start, *args)  # [2, H, w]
    # grid_h = np.arange(grid_size, dtype=np.float32)
    # grid_w = np.arange(grid_size, dtype=np.float32)
    # grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    # grid = np.stack(grid, axis=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (W,H)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443

def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
    """
    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.

    Parameters
    ----------
    embed_dim: int
        embedding dimension size
    start: int or tuple of int
        If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
        If len(args) == 2, start is start, args[0] is stop, args[1] is num.
    use_real: bool
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns
    -------
    pos_embed: torch.Tensor
        [HW, D/2]
    """
    grid = get_meshgrid(start, *args)  # [2, H, w]
    grid = grid.reshape([2, 1, *grid.shape[1:]])  # Returns a sampling matrix with the same resolution as the target resolution
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb


def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
                                   Otherwise, return complex numbers.

    Returns:
        torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    if isinstance(pos, int):
        pos = np.arange(pos)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
        return freqs_cis


def calc_sizes(rope_img, patch_size, th, tw):
    if rope_img == 'extend':
        # Expansion mode
        sub_args = [(th, tw)]
    elif rope_img.startswith('base'):
        # Based on the specified dimensions, other dimensions are obtained through interpolation.
        base_size = int(rope_img[4:]) // 8 // patch_size
        start, stop = get_fill_resize_and_crop((th, tw), base_size)
        sub_args = [start, stop, (th, tw)]
    else:
        raise ValueError(f"Unknown rope_img: {rope_img}")
    return sub_args


def init_image_posemb(rope_img,
                      resolutions,
                      patch_size,
                      hidden_size,
                      num_heads,
                      log_fn,
                      rope_real=True,
                      ):
    freqs_cis_img = {}
    for reso in resolutions:
        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
        sub_args = calc_sizes(rope_img, patch_size, th, tw)
        freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
               f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
    return freqs_cis_img
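A small sketch of how the RoPE table for an image token grid is produced and consumed together with the attention code earlier in this diff. Grid and head sizes are illustrative; the module paths are assumptions about where these files land.

import torch
from comfy.ldm.hydit.posemb_layers import get_2d_rotary_pos_embed  # assumed module path
from comfy.ldm.hydit.attn_layers import apply_rotary_emb           # assumed module path

# A 16x16 grid of patch tokens with head_dim 64 yields cos/sin tables of shape [256, 64]:
# half the channels encode the row coordinate, half the column coordinate.
head_dim, grid = 64, (16, 16)
cos, sin = get_2d_rotary_pos_embed(head_dim, grid, use_real=True)
print(cos.shape, sin.shape)  # torch.Size([256, 64]) torch.Size([256, 64])

# Applied to q laid out as [B, S, H, D] (head_first=False), matching CrossAttention.forward.
q = torch.randn(1, 256, 8, head_dim)
q_rot, _ = apply_rotary_emb(q, None, (cos, sin), head_first=False)
assert q_rot.shape == q.shape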
@@ -0,0 +1,120 @@
from comfy import sd1_clip
from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
import os

import torch
import contextlib

@contextlib.contextmanager
def use_comfy_ops(ops, device=None, dtype=None):
    old_torch_nn_linear = torch.nn.Linear
    force_device = device
    force_dtype = dtype
    def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        if force_device is not None:
            device = force_device
        if force_dtype is not None:
            dtype = force_dtype
        return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

    torch.nn.Linear = linear_with_dtype
    try:
        yield
    finally:
        torch.nn.Linear = old_torch_nn_linear


class RobertaWrapper(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = BertConfig(**config_dict)
        with use_comfy_ops(operations, device, dtype):
            with modeling_utils.no_init_weights():
                self.bert = BertModel(config, add_pooling_layer=False)

        self.num_layers = config.num_hidden_layers

    def get_input_embeddings(self):
        return self.bert.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.bert.set_input_embeddings(value)

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
        intermediate = None
        out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
        return out.last_hidden_state, intermediate, out.pooler_output

class HyditBertModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)

class HyditBertTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer)


class MT5XLModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)

class MT5XLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        #tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
        tokenizer = tokenizer_data.get("spiece_model", None)
        super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)

    def state_dict(self):
        return {"spiece_model": self.tokenizer.serialize_model()}

class HyditTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
        self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
        self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        out = {}
        out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
        out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        return self.hydit_clip.untokenize(token_weight_pair)

    def state_dict(self):
        return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}

class HyditModel(torch.nn.Module):
    def __init__(self, device="cpu", dtype=None):
        super().__init__()
        self.hydit_clip = HyditBertModel()
        self.mt5xl = MT5XLModel()

        self.dtypes = set()
        if dtype is not None:
            self.dtypes.add(dtype)

    def encode_token_weights(self, token_weight_pairs):
        hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
        mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
        return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}

    def load_sd(self, sd):
        if "bert.encoder.layer.0.attention.self.query.weight" in sd:
            return self.hydit_clip.load_sd(sd)
        else:
            return self.mt5xl.load_sd(sd)

    def set_clip_options(self, options):
        self.hydit_clip.set_clip_options(options)
        self.mt5xl.set_clip_options(options)

    def reset_clip_options(self):
        self.hydit_clip.reset_clip_options()
        self.mt5xl.reset_clip_options()
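The use_comfy_ops helper above is what lets RobertaWrapper build the BERT encoder directly with the requested dtype and device (the "Init with right dtype" item in the commit message): while the context is active, any code that constructs torch.nn.Linear gets the comfy ops implementation instead. A minimal sketch, assuming this file is importable as comfy.text_encoders.hydit and using ComfyUI's comfy.ops.disable_weight_init as the ops namespace:

import torch
import comfy.ops
from comfy.text_encoders.hydit import use_comfy_ops  # assumed module path

# Inside the context, plain torch.nn.Linear construction is rerouted to the given ops
# class with the forced device/dtype, so third-party code (here a single Linear, in the
# wrapper above the whole transformers BertModel) is created with the right dtype up front.
with use_comfy_ops(comfy.ops.disable_weight_init, device="cpu", dtype=torch.float16):
    layer = torch.nn.Linear(8, 8)

print(type(layer).__name__, layer.weight.dtype)  # Linear torch.float16
restored = torch.nn.Linear(8, 8)                 # the original torch.nn.Linear is restored on exit
print(restored.weight.dtype)                     # torch.float32

The tokenizer-side change from the commit message works the same way at a higher level: HyditTokenizer.state_dict() returns {"mt5xl.spiece_model": ...}, so the mT5 sentencepiece model can be folded back into a saved checkpoint and restored through tokenizer_data on the next load.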
@@ -0,0 +1,35 @@
{
  "_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "directionality": "bidi",
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.22.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 47020
}
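The config above is the standard transformers config for hfl/chinese-roberta-wwm-ext-large; RobertaWrapper earlier in this diff consumes a dict like this via BertConfig(**config_dict). A minimal sketch of the same loading path, assuming the JSON has been saved locally as hydit_clip.json:

import json
from transformers import BertConfig, BertModel

with open("hydit_clip.json") as f:  # assumed local copy of the config shown above
    config = BertConfig(**json.load(f))

print(config.hidden_size, config.num_hidden_layers, config.vocab_size)  # 1024 24 47020

# Mirrors RobertaWrapper: the BERT encoder is built from the config without the pooling head.
model = BertModel(config, add_pooling_layer=False)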
|
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
@@ -0,0 +1,16 @@
{
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "mask_token": "[MASK]",
  "name_or_path": "hfl/chinese-roberta-wwm-ext",
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "BertTokenizer",
  "unk_token": "[UNK]",
  "model_max_length": 77
}
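The two JSON files above, together with the vocab.txt whose diff is suppressed just below, form a standard Hugging Face tokenizer directory, so they load directly with transformers' BertTokenizer. A hedged sketch with an assumed local path into a ComfyUI checkout:

from transformers import BertTokenizer

# Assumed local path to the tokenizer directory containing tokenizer_config.json,
# special_tokens_map.json and the (suppressed) vocab.txt from this diff.
tok = BertTokenizer.from_pretrained("comfy/text_encoders/hydit_clip_tokenizer")

ids = tok("一只可爱的猫", padding="max_length", max_length=77)["input_ids"]
print(ids[:4], len(ids))  # starts with the [CLS] id 101 (the "start" token in HyditBertModel), padded to model_max_length 77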
File diff suppressed because it is too large
@@ -0,0 +1,22 @@
{
  "d_ff": 5120,
  "d_kv": 64,
  "d_model": 2048,
  "decoder_start_token_id": 0,
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "dense_act_fn": "gelu_pytorch_tanh",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "mt5",
  "num_decoder_layers": 24,
  "num_heads": 32,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "vocab_size": 250112
}