|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from comfy import sd1_clip
|
|
|
|
|
import comfy.text_encoders.t5
|
|
|
|
|
import comfy.model_management
|
|
|
|
|
from transformers import T5TokenizerFast
|
|
|
|
|
import torch
|
|
|
|
|
import os
|
|
|
|
@ -34,11 +35,12 @@ class FluxTokenizer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FluxClipModel(torch.nn.Module):
|
|
|
|
|
def __init__(self, device="cpu", dtype=None):
|
|
|
|
|
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
|
|
|
|
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
|
|
|
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype)
|
|
|
|
|
self.dtypes = set([dtype])
|
|
|
|
|
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
|
|
|
|
self.dtypes = set([dtype, dtype_t5])
|
|
|
|
|
|
|
|
|
|
def set_clip_options(self, options):
|
|
|
|
|
self.clip_l.set_clip_options(options)
|
|
|
|
@ -62,3 +64,8 @@ class FluxClipModel(torch.nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
return self.t5xxl.load_sd(sd)
|
|
|
|
|
|
|
|
|
|
def flux_clip(dtype_t5=None):
|
|
|
|
|
class FluxClipModel_(FluxClipModel):
|
|
|
|
|
def __init__(self, device="cpu", dtype=None):
|
|
|
|
|
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
|
|
|
|
return FluxClipModel_
|
|
|
|
|