@@ -140,15 +140,13 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
     def set_up_textual_embeddings(self, tokens, current_embeds):
         out_tokens = []
-        next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
+        next_new_token = token_dict_size = current_embeds.weight.shape[0]
         embedding_weights = []
 
         for x in tokens:
             tokens_temp = []
             for y in x:
                 if isinstance(y, numbers.Integral):
-                    if y == token_dict_size: #EOS token
-                        y = -1
                     tokens_temp += [int(y)]
                 else:
                     if y.shape[0] == current_embeds.weight.shape[1]:
@@ -164,11 +162,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         n = token_dict_size
         if len(embedding_weights) > 0:
             new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
-            new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
+            new_embedding.weight[:token_dict_size] = current_embeds.weight
             for x in embedding_weights:
                 new_embedding.weight[n] = x
                 n += 1
-            new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
             self.transformer.set_input_embeddings(new_embedding)
 
         processed_tokens = []
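
For reference, a standalone sketch of the pattern the updated hunks implement: the full original embedding table (EOS row included, no [:-1] slice) is copied into a larger table and the custom embedding vectors are appended after it. The sizes and sample vectors below are hypothetical; only the copy/append logic mirrors the code in the diff.

# Sketch only, not part of the diff above. Sizes are assumptions.
import torch

current_embeds = torch.nn.Embedding(49408, 768)           # assumed CLIP-sized vocabulary
embedding_weights = [torch.randn(768), torch.randn(768)]  # stand-ins for loaded custom embedding vectors

token_dict_size = current_embeds.weight.shape[0]          # full vocab, no longer shape[0] - 1
next_new_token = token_dict_size + len(embedding_weights)

new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1],
                                   device=current_embeds.weight.device,
                                   dtype=current_embeds.weight.dtype)
with torch.no_grad():  # no_grad so the in-place row copies are permitted in this sketch
    # copy the whole original table, EOS row included (the old code sliced it off with [:-1])
    new_embedding.weight[:token_dict_size] = current_embeds.weight
    n = token_dict_size
    for w in embedding_weights:                           # custom vectors land after the vocabulary
        new_embedding.weight[n] = w
        n += 1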