@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from . import kornia_functions
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
@@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c
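# The Bernoulli mask above implements classifier-free-guidance dropout: with
# probability ucg_rate a sample's class id is replaced by the extra
# "unconditional" class (self.n_classes - 1). A minimal standalone sketch of
# the same masking, assuming n_classes=1001 and ucg_rate=0.1:
#     c = torch.randint(0, 1000, (4, 1))                     # (batch, 1) class ids
#     mask = 1. - torch.bernoulli(torch.full(c.shape, 0.1))  # 1 = keep, 0 = drop
#     c = (mask * c + (1 - mask) * 1000).long()              # dropped ids -> 1000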
@@ -57,18 +58,20 @@ def disabled_train(self, mode=True):
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""
    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
                 freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False
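# Usage sketch for FrozenT5Embedder, assuming its forward tokenizes with
# padding="max_length" as FrozenCLIPEmbedder does below, and that the
# google/t5-v1_1-large weights (d_model=1024) are available locally or from
# the Hugging Face hub:
#     t5 = FrozenT5Embedder(device="cpu")
#     z = t5.encode(["a photograph of an astronaut riding a horse"])
#     # z.shape == (1, 77, 1024): one caption, padded to max_length tokens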
@@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        "pooled",
        "hidden"
    ]

    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
@@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False
@@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
@@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
        return self(text)
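# The layer switch above selects what FrozenCLIPEmbedder returns: "last" is
# last_hidden_state of shape (batch, 77, 768) for clip-vit-large-patch14,
# "pooled" is expected to yield the pooler output, and "hidden" indexes into
# outputs.hidden_states at layer_idx. A usage sketch with the default args:
#     clip_text = FrozenCLIPEmbedder(device="cpu")
#     z = clip_text.encode(["a painting of a fox"])  # -> (1, 77, 768)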
class ClipImageEmbedder(nn.Module):
    def __init__(
            self,
            model,
            jit=False,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            antialias=True,
            ucg_rate=0.
    ):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)
        self.antialias = antialias
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate
    def preprocess(self, x):
        # normalize to [0,1]
        # x = kornia_functions.geometry_resize(x, (224, 224),
        #                                      interpolation='bicubic', align_corners=True,
        #                                      antialias=self.antialias)
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x
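    # preprocess maps network-space images in [-1, 1] to CLIP's expected input:
    # resize to 224x224, shift to [0, 1], then per-channel standardization with
    # the mean/std buffers registered above. For example, a red-channel value
    # of -1. becomes (0. - 0.48145466) / 0.26862954 ≈ -1.79.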
    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out
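# The per-sample Bernoulli factor above zeroes whole image embeddings with
# probability ucg_rate, the image-conditioning analogue of the class dropout
# in ClassEmbedder. A usage sketch, assuming the OpenAI clip package and the
# "ViT-L/14" checkpoint (768-dim embeddings), with inputs scaled to [-1, 1]:
#     emb = ClipImageEmbedder(model="ViT-L/14", device="cpu")
#     feats = emb(torch.rand(2, 3, 256, 256) * 2. - 1.)  # -> (2, 768)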
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="last"):
        super().__init__()
@@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
@@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
        return self(text)
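# With layer="penultimate", layer_idx becomes 1 and text_transformer_forward
# above stops one resblock before the end. Usage sketch, assuming
# open_clip_torch is installed and the laion2b_s32b_b79k weights can be
# fetched (ViT-H-14 uses a 1024-wide text tower):
#     oc_text = FrozenOpenCLIPEmbedder(layer="penultimate", device="cpu")
#     z = oc_text.encode(["a painting of a fox"])  # -> (1, 77, 1024)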
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
                 freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
                                                            pretrained=version)
        del model.transformer
        self.model = model
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1
        self.antialias = antialias
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate
    def preprocess(self, x):
        # normalize to [0,1]
        # x = kornia.geometry.resize(x, (224, 224),
        #                            interpolation='bicubic', align_corners=True,
        #                            antialias=self.antialias)
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia_functions.enhance_normalize(x, self.mean, self.std)
        return x
    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, image, no_dropout=False):
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z
    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x
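    # self.model.visual runs the full OpenCLIP vision tower including the final
    # projection, so ViT-H-14 returns pooled features of shape (batch, 1024).
    # Sketch, assuming inputs in [-1, 1] as preprocess above expects:
    #     img_emb = FrozenOpenCLIPImageEmbedder(device="cpu")
    #     z = img_emb(torch.rand(2, 3, 256, 256) * 2. - 1.)  # -> (2, 1024)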
    def encode(self, text):
        return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
                 clip_max_length=77, t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
    def encode(self, text):
        return self(text)
@@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
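# FrozenCLIPT5Encoder returns both embeddings as a list so downstream
# cross-attention can consume them separately. Usage sketch, assuming both
# checkpoints are available (768-dim CLIP tokens; t5-v1_1-xl has d_model=2048):
#     enc = FrozenCLIPT5Encoder(device="cpu")
#     clip_z, t5_z = enc.encode(["a fox"])  # (1, 77, 768) and (1, 77, 2048)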