@@ -62,7 +62,16 @@ class ClipTokenWeightEncoder:
             r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
         else:
             r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
-        r = r + tuple(map(lambda a: a[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()), o[2:]))
+
+        if len(o) > 2:
+            extra = {}
+            for k in o[2]:
+                v = o[2][k]
+                if k == "attention_mask":
+                    v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
+                extra[k] = v
+
+            r = r + (extra,)
         return r
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
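Illustrative note on the hunk above: extra encoder outputs (currently just the attention mask) now travel in a dict appended to the returned tuple, instead of being appended as bare tensors. A minimal sketch of how a caller might unpack the new shape; `unpack_clip_output` is a hypothetical helper name, not something defined in this patch:

def unpack_clip_output(o):
    # o is what encode_token_weights returns after this change:
    # (cond, pooled) or (cond, pooled, extra), where extra is a dict.
    cond, pooled = o[:2]
    extra = o[2] if len(o) > 2 else {}
    # "attention_mask" is the only key given special handling in the hunk above.
    return cond, pooled, extra.get("attention_mask")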
@@ -206,8 +215,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             elif outputs[2] is not None:
                 pooled_output = outputs[2].float()
 
+        extra = {}
         if self.return_attention_masks:
-            return z, pooled_output, attention_mask
+            extra["attention_mask"] = attention_mask
+
+        if len(extra) > 0:
+            return z, pooled_output, extra
 
         return z, pooled_output
 
@@ -547,8 +560,8 @@ class SD1ClipModel(torch.nn.Module):
 
     def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs = token_weight_pairs[self.clip_name]
-        out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
-        return out, pooled
+        out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
+        return out
 
     def load_sd(self, sd):
         return getattr(self, self.clip).load_sd(sd)
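The SD1ClipModel hunk drops the two-name unpacking because the wrapped model can now return a third element; unpacking a 3-tuple into (out, pooled) would raise instead of passing the extra dict through. A small, self-contained illustration with stand-in values (none of these literals come from the patch):

result = ("cond", "pooled", {"attention_mask": "mask"})  # stand-in for the new 3-tuple shape
out = result                 # new behaviour: forward the whole tuple unchanged
try:
    out, pooled = result     # old behaviour applied to a 3-tuple
except ValueError as err:
    print(err)               # "too many values to unpack (expected 2)"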