@@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         if dtype is not None:
             self.transformer.to(dtype)
+            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
+            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+
         self.max_length = max_length
         if freeze:
             self.freeze()
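In this hunk, when a non-default `dtype` (e.g. `torch.float16`) is requested, the whole CLIP transformer is cast to it, but the token and position embedding tables are pinned back to float32. A minimal standalone sketch of the same pattern, assuming the Hugging Face `transformers` `CLIPTextModel` as a stand-in for the wrapped `self.transformer` (the randomly initialized config just keeps the example download-free):

```python
import torch
from transformers import CLIPTextConfig, CLIPTextModel

# Randomly initialized stand-in for the wrapped CLIP text encoder.
model = CLIPTextModel(CLIPTextConfig())

dtype = torch.float16
if dtype is not None:
    model.to(dtype)  # bulk of the weights now run in half precision
    # Pin the embedding tables back to fp32: embedding lookups are
    # precision-sensitive, and the fp32 tables cost little extra memory.
    model.text_model.embeddings.token_embedding.to(torch.float32)
    model.text_model.embeddings.position_embedding.to(torch.float32)

assert model.text_model.final_layer_norm.weight.dtype == torch.float16
assert model.text_model.embeddings.token_embedding.weight.dtype == torch.float32
```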
@@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if backup_embeds.weight.dtype != torch.float32:
+        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
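Since the first hunk forces the embedding tables to float32, `backup_embeds.weight.dtype` (the token embedding weights) always reads float32 and can no longer tell whether the model itself runs in half precision, so the check moves to `final_layer_norm.weight`, which keeps the cast dtype. A rough sketch of the resulting conditional-autocast pattern; the function name and the dtype passed to the scope are assumptions, since the `with` call site is outside this hunk:

```python
import contextlib
import torch

def encode_tokens(model, tokens, device_type="cuda"):
    # Pick a context-manager factory with a uniform (device, dtype) call
    # signature, so the `with` statement below is the same in both branches.
    if model.text_model.final_layer_norm.weight.dtype != torch.float32:
        precision_scope = torch.autocast
    else:
        # No-op stand-in that accepts the same two positional arguments.
        precision_scope = lambda a, b: contextlib.nullcontext(a)

    with precision_scope(device_type, torch.float16):  # dtype is an assumption
        return model(input_ids=tokens)
```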