@@ -38,7 +38,9 @@ class ClipTokenWeightEncoder:
         if has_weights or sections == 0:
             to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
 
-        out, pooled = self.encode(to_encode)
+        o = self.encode(to_encode)
+        out, pooled = o[:2]
+
         if pooled is not None:
             first_pooled = pooled[0:1].to(model_management.intermediate_device())
         else:
@@ -57,8 +59,11 @@ class ClipTokenWeightEncoder:
             output.append(z)
 
         if (len(output) == 0):
-            return out[-1:].to(model_management.intermediate_device()), first_pooled
-        return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
+            r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
+        else:
+            r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
+        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+        return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
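Note on the two hunks above (reviewer annotation, not part of the diff): encode_token_weights now keeps the full tuple returned by self.encode instead of assuming exactly (out, pooled). Every extra tensor in that tuple, such as the attention mask added further down, is sliced to the real prompt sections with a[:sections], flattened, given a batch dimension, and appended to the returned tuple. A minimal sketch of how a caller could unpack the widened result; the names encoder and token_weight_pairs are placeholders, not identifiers introduced by this patch:

# Illustrative sketch only, not part of the diff.
result = encoder.encode_token_weights(token_weight_pairs)

cond, pooled = result[0], result[1]   # always present, same as before this patch
extras = result[2:]                   # extra per-token tensors, e.g. (attention_mask,)

if len(extras) > 0:
    attention_mask = extras[0]        # [1, sections * max_token_len] after flatten()/unsqueeze(dim=0)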
@@ -70,7 +75,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
-                 return_projected_pooled=True): # clip-vit-base-patch32
+                 return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -96,6 +101,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.layer_norm_hidden_state = layer_norm_hidden_state
         self.return_projected_pooled = return_projected_pooled
+        self.return_attention_masks = return_attention_masks
 
         if layer == "hidden":
             assert layer_idx is not None
@@ -169,7 +175,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = torch.LongTensor(tokens).to(device)
 
         attention_mask = None
-        if self.enable_attention_masks or self.zero_out_masked:
+        if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
             attention_mask = torch.zeros_like(tokens)
             end_token = self.special_tokens.get("end", -1)
             for x in range(attention_mask.shape[0]):
@@ -200,6 +206,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             elif outputs[2] is not None:
                 pooled_output = outputs[2].float()
 
+        if self.return_attention_masks:
+            return z, pooled_output, attention_mask
+
         return z, pooled_output
 
     def encode(self, tokens):
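Closing note (reviewer annotation, not part of the diff): with return_attention_masks=True the model builds the mask during forward() and hands it back as a third value, so callers that set the flag must unpack three values instead of two. A minimal sketch under the assumption that clip is an SDClipModel constructed with the flag set and batch_tokens is a batch of already-padded token-id lists of length max_length; both names are placeholders:

# Illustrative sketch only, not part of the diff.
out = clip(batch_tokens)

if clip.return_attention_masks:
    z, pooled, attention_mask = out   # mask is 1 through the end token, 0 on the padding
else:
    z, pooled = out                   # unchanged behaviour when the flag stays at its default False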