Fix hunyuan dit text encoder weights always being in fp32.

main
comfyanonymous 7 months ago
parent 2c038ccef0
commit a5991a7aa6

@ -52,8 +52,8 @@ class HyditTokenizer:
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.hydit_clip = HyditBertModel()
self.mt5xl = MT5XLModel()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.dtypes = set()
if dtype is not None:

Loading…
Cancel
Save