@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
 import math
+import comfy.ops

 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
             [out_features // 2, in_features], dtype=dtype, device=device))

     def forward(self, input):
-        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
+        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
         return torch.cat([f.cos(), f.sin()], dim=-1)

 # norms
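Note: the comfy.ops.cast_to_input(weight, input) helper adopted throughout these hunks is assumed to behave like the explicit .to(...) calls it replaces, i.e. cast a parameter or buffer to the dtype and device of the tensor it is combined with. A minimal sketch under that assumption (not the actual comfy.ops implementation):

    import torch

    def cast_to_input(weight, input, non_blocking=False):
        # Sketch only: move/cast the weight to match the input tensor's dtype and device.
        return weight.to(dtype=input.dtype, device=input.device, non_blocking=non_blocking)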
@@ -38,9 +39,9 @@ class LayerNorm(nn.Module):

     def forward(self, x):
         beta = self.beta
-        if self.beta is not None:
-            beta = beta.to(dtype=x.dtype, device=x.device)
-        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
+        if beta is not None:
+            beta = comfy.ops.cast_to_input(beta, x)
+        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)

 class GLU(nn.Module):
     def __init__(
@@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
         scale_base = 512,
         interpolation_factor = 1.,
         base = 10000,
-        base_rescale_factor = 1.
+        base_rescale_factor = 1.,
+        dtype=None,
+        device=None,
     ):
         super().__init__()
         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))

-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

         assert interpolation_factor >= 1.
         self.interpolation_factor = interpolation_factor
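Note: inv_freq is now registered as an uninitialized torch.empty((dim // 2,)) buffer created directly with the requested device/dtype; per the commented-out line it is presumably filled later (e.g. when the model's weights are initialized or loaded) with the usual rotary frequencies:

    # Hypothetical initialization matching the commented-out formula above:
    inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))  # shape: (dim // 2,)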
@@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):

         t = t / self.interpolation_factor

-        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
+        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
         freqs = torch.cat((freqs, freqs), dim = -1)

         if self.scale is None:
             return freqs, 1.

         power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
-        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
+        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)

         return freqs, scale
@@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
         self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()

         if rotary_pos_emb:
-            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
+            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
         else:
             self.rotary_pos_emb = None
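Note: passing device/dtype through to RotaryEmbedding here matches the widened constructor from the earlier hunk, so its inv_freq buffer is created directly with the transformer's device and dtype rather than the defaults. A hypothetical check of the allocation, assuming dim_heads = 64:

    import torch

    dim = max(64 // 2, 32)               # -> 32, as in the constructor call above
    inv_freq = torch.empty((dim // 2,))  # -> shape (16,), matching the register_buffer call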