@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 from .. import attention
 from einops import rearrange, repeat
+from .util import timestep_embedding

 def default(x, y):
     if x is not None:
@@ -230,34 +231,8 @@ class TimestepEmbedder(nn.Module):
         )
         self.frequency_embedding_size = frequency_embedding_size

-    @staticmethod
-    def timestep_embedding(t, dim, max_period=10000):
-        """
-        Create sinusoidal timestep embeddings.
-        :param t: a 1-D Tensor of N indices, one per batch element.
-                  These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an (N, D) Tensor of positional embeddings.
-        """
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period)
-            * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
-            / half
-        )
-        args = t[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat(
-                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
-            )
-        if torch.is_floating_point(t):
-            embedding = embedding.to(dtype=t.dtype)
-        return embedding
-
     def forward(self, t, dtype, **kwargs):
-        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
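For reference, the deleted static method is the standard sinusoidal timestep embedding, and this change assumes the `timestep_embedding` helper imported from `.util` computes the same thing, so only one copy of the logic remains. Below is a minimal self-contained sketch of that computation, mirroring the removed body; the standalone function name and the example values are illustrative, not part of the diff:

import math

import torch


def sinusoidal_timestep_embedding(t, dim, max_period=10000):
    """Standalone copy of the sinusoidal embedding removed above."""
    half = dim // 2
    # Geometrically decreasing frequencies, from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        / half
    )
    # Outer product of timesteps and frequencies: one phase per channel pair.
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Zero-pad one extra channel when dim is odd.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(dtype=t.dtype)
    return embedding


# Illustrative sanity check of the shape the TimestepEmbedder.forward path feeds its MLP.
t = torch.tensor([0.0, 250.0, 999.0])
print(sinusoidal_timestep_embedding(t, 256).shape)  # torch.Size([3, 256])

Deduplicating against `.util` leaves a single shared definition, so any future fix to frequency spacing or dtype handling no longer has to be mirrored in this file.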