diff --git a/comfy/sd.py b/comfy/sd.py
index f3ec62b..d8c0bfa 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,4 +1,5 @@
 import torch
+from enum import Enum
 
 from comfy import model_management
 from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
     model.load_state_dict(model_data)
     return StyleModel(model)
 
+class CLIPType(Enum):
+    STABLE_DIFFUSION = 1
+    STABLE_CASCADE = 2
 
-def load_clip(ckpt_paths, embedding_directory=None):
+def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
     clip_data = []
     for p in ckpt_paths:
         clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
     clip_target.params = {}
     if len(clip_data) == 1:
         if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
-            clip_target.clip = sdxl_clip.SDXLRefinerClipModel
-            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
+            if clip_type == CLIPType.STABLE_CASCADE:
+                clip_target.clip = sdxl_clip.StableCascadeClipModel
+                clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
+            else:
+                clip_target.clip = sdxl_clip.SDXLRefinerClipModel
+                clip_target.tokenizer = sdxl_clip.SDXLTokenizer
         elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
             clip_target.clip = sd2_clip.SD2ClipModel
             clip_target.tokenizer = sd2_clip.SD2Tokenizer
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 65ea909..8287ad2 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -67,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
                  freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True):  # clip-vit-base-patch32
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
 
@@ -88,7 +88,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-        self.enable_attention_masks = False
+        self.enable_attention_masks = enable_attention_masks
 
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index b35056b..3ce5c7e 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -64,3 +64,25 @@ class SDXLClipModel(torch.nn.Module):
 class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None):
         super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
+
+
+class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, tokenizer_path=None, embedding_directory=None):
+        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+
+class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None):
+        super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
+
+class StableCascadeClipG(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None):
+        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
+
+    def load_sd(self, sd):
+        return super().load_sd(sd)
+
+class StableCascadeClipModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None):
+        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 7859bac..3a317ed 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -336,7 +336,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
         return out
 
     def clip_target(self):
-        return None
+        return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
 
 class Stable_Cascade_B(Stable_Cascade_C):
     unet_config = {
diff --git a/nodes.py b/nodes.py
index 7b8f09d..47203f4 100644
--- a/nodes.py
+++ b/nodes.py
@@ -854,15 +854,20 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ),
+                              "type": (["stable_diffusion", "stable_cascade"], ),
                              }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
 
     CATEGORY = "advanced/loaders"
 
-    def load_clip(self, clip_name):
+    def load_clip(self, clip_name, type="stable_diffusion"):
+        clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
+        if type == "stable_cascade":
+            clip_type = comfy.sd.CLIPType.STABLE_CASCADE
+
         clip_path = folder_paths.get_full_path("clip", clip_name)
-        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
         return (clip,)
 
 class DualCLIPLoader:
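
Usage sketch (illustrative, not part of the patch): how the new clip_type argument is exercised outside the CLIPLoader node. The checkpoint filename below is a placeholder; clip.tokenize and clip.encode_from_tokens are the existing methods on the CLIP wrapper that load_clip returns.

    import comfy.sd
    import folder_paths

    # Resolve a single-file bigG text encoder from models/clip (placeholder name).
    clip_path = folder_paths.get_full_path("clip", "stable_cascade_clip_g.safetensors")

    # CLIPType.STABLE_CASCADE routes the bigG-shaped state dict to the new
    # StableCascadeClipModel/StableCascadeTokenizer pair; the default
    # CLIPType.STABLE_DIFFUSION keeps the old SDXL-refiner mapping, so existing
    # workflows are unaffected.
    clip = comfy.sd.load_clip(ckpt_paths=[clip_path],
                              embedding_directory=folder_paths.get_folder_paths("embeddings"),
                              clip_type=comfy.sd.CLIPType.STABLE_CASCADE)

    tokens = clip.tokenize("a cinematic photo of a cat")
    cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)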