diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 477d5c3..b84a384 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -60,6 +60,9 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         if dtype is not None:
             self.transformer.to(dtype)
+            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
+            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+
         self.max_length = max_length
         if freeze:
             self.freeze()
@@ -138,7 +141,7 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if backup_embeds.weight.dtype != torch.float32:
+        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
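
The sketch below is not ComfyUI code; it uses a hypothetical `TinyTextModel` in place of the real `self.transformer.text_model` (a transformers `CLIPTextModel`) to illustrate the pattern the two hunks appear to implement: cast the whole text encoder to the requested dtype, pin the token/position embeddings back to fp32, and then key the autocast decision off a weight that still carries the reduced dtype (the final layer norm), since the embedding weights are now always fp32 and can no longer indicate whether the model was cast down.

```python
# Minimal, self-contained sketch of the dtype handling in this diff.
# TinyTextModel is a stand-in for illustration only, not the real CLIP text encoder.
import contextlib
import torch
import torch.nn as nn


class TinyTextModel(nn.Module):
    """Hypothetical stand-in for a CLIP-style text encoder."""

    def __init__(self, vocab_size=49408, dim=768, max_length=77):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.position_embedding = nn.Embedding(max_length, dim)
        self.final_layer_norm = nn.LayerNorm(dim)


model = TinyTextModel()

# Mirror of the first hunk: cast the whole model to the requested dtype,
# then force the embedding tables back to fp32.
model.to(torch.float16)
model.token_embedding.to(torch.float32)
model.position_embedding.to(torch.float32)

# Mirror of the second hunk: the embedding weights are always fp32 now, so their
# dtype no longer reveals whether the rest of the model was cast down; check a
# weight that kept the requested dtype instead.
if model.final_layer_norm.weight.dtype != torch.float32:
    precision_scope = torch.autocast
else:
    precision_scope = lambda a, b: contextlib.nullcontext(a)

print(model.token_embedding.weight.dtype)    # torch.float32
print(model.final_layer_norm.weight.dtype)   # torch.float16
print(precision_scope)                       # torch.autocast, since the model is fp16
```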