|
|
|
@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
tokens = torch.LongTensor(tokens).to(device)
|
|
|
|
|
|
|
|
|
|
attention_mask = None
|
|
|
|
|
if self.enable_attention_masks:
|
|
|
|
|
if self.enable_attention_masks or self.zero_out_masked:
|
|
|
|
|
attention_mask = torch.zeros_like(tokens)
|
|
|
|
|
end_token = self.special_tokens.get("end", -1)
|
|
|
|
|
for x in range(attention_mask.shape[0]):
|
|
|
|
@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
if tokens[x, y] == end_token:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
|
|
|
|
attention_mask_model = None
|
|
|
|
|
if self.enable_attention_masks:
|
|
|
|
|
attention_mask_model = attention_mask
|
|
|
|
|
|
|
|
|
|
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
|
|
|
|
|
self.transformer.set_input_embeddings(backup_embeds)
|
|
|
|
|
|
|
|
|
|
if self.layer == "last":
|
|
|
|
@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|
|
|
|
else:
|
|
|
|
|
z = outputs[1].float()
|
|
|
|
|
|
|
|
|
|
if self.zero_out_masked and attention_mask is not None:
|
|
|
|
|
if self.zero_out_masked:
|
|
|
|
|
z *= attention_mask.unsqueeze(-1).float()
|
|
|
|
|
|
|
|
|
|
pooled_output = None
|
|
|
|
|