|
|
|
@ -75,7 +75,7 @@ class SD20(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
|
|
|
|
replace_prefix["cond_stage_model.model."] = "clip_h."
|
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_h.", "clip_h.transformer.text_model.", 24)
|
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict):
|
|
|
|
@ -134,7 +134,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.0.model."] = "clip_g."
|
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
@ -182,10 +182,8 @@ class SDXL(supported_models_base.BASE):
|
|
|
|
|
replace_prefix["conditioner.embedders.1.model."] = "clip_g."
|
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
|
|
|
|
|
|
|
|
|
state_dict = utils.transformers_convert(state_dict, "clip_g.", "clip_g.transformer.text_model.", 32)
|
|
|
|
|
keys_to_replace["clip_g.text_projection.weight"] = "clip_g.text_projection"
|
|
|
|
|
|
|
|
|
|
state_dict = utils.state_dict_key_replace(state_dict, keys_to_replace)
|
|
|
|
|
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_g.", "clip_g.transformer.")
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
def process_clip_state_dict_for_saving(self, state_dict):
|
|
|
|
@ -338,6 +336,12 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
|
|
|
|
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
def process_clip_state_dict(self, state_dict):
|
|
|
|
|
state_dict = utils.state_dict_prefix_replace(state_dict, {k: "" for k in self.text_encoder_key_prefix}, filter_keys=True)
|
|
|
|
|
if "clip_g.text_projection" in state_dict:
|
|
|
|
|
state_dict["clip_g.transformer.text_projection.weight"] = state_dict.pop("clip_g.text_projection").transpose(0, 1)
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
def get_model(self, state_dict, prefix="", device=None):
|
|
|
|
|
out = model_base.StableCascade_C(self, device=device)
|
|
|
|
|
return out
|
|
|
|
|