diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 78e556b..3b812b4 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -169,7 +169,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens = torch.LongTensor(tokens).to(device) attention_mask = None - if self.enable_attention_masks: + if self.enable_attention_masks or self.zero_out_masked: attention_mask = torch.zeros_like(tokens) end_token = self.special_tokens.get("end", -1) for x in range(attention_mask.shape[0]): @@ -178,7 +178,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if tokens[x, y] == end_token: break - outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) + attention_mask_model = None + if self.enable_attention_masks: + attention_mask_model = attention_mask + + outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state) self.transformer.set_input_embeddings(backup_embeds) if self.layer == "last": @@ -186,7 +190,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: z = outputs[1].float() - if self.zero_out_masked and attention_mask is not None: + if self.zero_out_masked: z *= attention_mask.unsqueeze(-1).float() pooled_output = None