Fix to get fp8 working on T5 base.

main
comfyanonymous 7 months ago
parent a5991a7aa6
commit c24f897352

@@ -236,4 +236,6 @@ class T5(torch.nn.Module):
     def forward(self, input_ids, *args, **kwargs):
         x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+        if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+            x = torch.nan_to_num(x) #Fix for fp8 T5 base
         return self.encoder(x, *args, **kwargs)
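
Not part of the commit, but a minimal standalone sketch of what the added guard does: torch.nan_to_num replaces NaN/Inf entries, which can show up in the embedding output when the shared weights are held in an fp8 dtype, with finite values before the tensor reaches the encoder. The sample values below are illustrative assumptions, not taken from the model.

import torch

# Illustrative tensor containing the kinds of non-finite entries the guard protects against.
x = torch.tensor([[0.5, float("nan"), float("inf"), -float("inf")]])

# By default torch.nan_to_num maps NaN -> 0.0, +inf -> float32 max, -inf -> float32 min,
# so only finite values propagate into the encoder.
print(torch.nan_to_num(x))
# tensor([[ 5.0000e-01,  0.0000e+00,  3.4028e+38, -3.4028e+38]])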
