@@ -151,7 +151,7 @@ class CLIPVisionEmbeddings(torch.nn.Module):

     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight
+        return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)


 class CLIPVision(torch.nn.Module):
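The added `.to(embeds.device)` calls guard against a device mismatch when `class_embedding` and `position_embedding` are not on the same device as the computed patch embeddings. A minimal sketch of the failure mode under an assumed setup (the patch projection already runs on the accelerator while the two embedding tensors were left on CPU); the toy tensors and names below are illustrative, not the real module:

```python
import torch

# Pick an accelerator if one is available; on a CPU-only machine everything
# stays on one device and no cast is needed.
device = "cuda" if torch.cuda.is_available() else "cpu"

embed_dim, num_patches, batch = 8, 4, 2

# Stand-in for the output of self.patch_embedding(...): already on `device`.
embeds = torch.randn(batch, num_patches, embed_dim, device=device)

# Stand-ins for class_embedding / position_embedding.weight left on CPU.
class_embedding = torch.nn.Parameter(torch.zeros(embed_dim))
position_embedding_weight = torch.zeros(num_patches + 1, embed_dim)

# Without the .to(embeds.device) casts, the cat/add below raise a RuntimeError
# on CUDA ("Expected all tensors to be on the same device"); with them it works.
out = torch.cat(
    [class_embedding.to(embeds.device).expand(batch, 1, -1), embeds],
    dim=1,
) + position_embedding_weight.to(embeds.device)

print(out.shape)  # torch.Size([2, 5, 8])
```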