@ -9,6 +9,8 @@ from . import sdxl_clip
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
class SD15 ( supported_models_base . BASE ) :
unet_config = {
" context_dim " : 768 ,
@ -63,6 +65,13 @@ class SD20(supported_models_base.BASE):
state_dict = utils . transformers_convert ( state_dict , " cond_stage_model.model. " , " cond_stage_model.transformer.text_model. " , 24 )
return state_dict
def process_clip_state_dict_for_saving ( self , state_dict ) :
replace_prefix = { }
replace_prefix [ " " ] = " cond_stage_model.model. "
state_dict = supported_models_base . state_dict_prefix_replace ( state_dict , replace_prefix )
state_dict = diffusers_convert . convert_text_enc_state_dict_v20 ( state_dict )
return state_dict
def clip_target ( self ) :
return supported_models_base . ClipTarget ( sd2_clip . SD2Tokenizer , sd2_clip . SD2ClipModel )
@ -113,6 +122,13 @@ class SDXLRefiner(supported_models_base.BASE):
state_dict = supported_models_base . state_dict_key_replace ( state_dict , keys_to_replace )
return state_dict
def process_clip_state_dict_for_saving ( self , state_dict ) :
replace_prefix = { }
state_dict_g = diffusers_convert . convert_text_enc_state_dict_v20 ( state_dict , " clip_g " )
replace_prefix [ " clip_g " ] = " conditioner.embedders.0.model "
state_dict_g = supported_models_base . state_dict_prefix_replace ( state_dict_g , replace_prefix )
return state_dict_g
def clip_target ( self ) :
return supported_models_base . ClipTarget ( sdxl_clip . SDXLTokenizer , sdxl_clip . SDXLRefinerClipModel )
@ -142,6 +158,19 @@ class SDXL(supported_models_base.BASE):
state_dict = supported_models_base . state_dict_key_replace ( state_dict , keys_to_replace )
return state_dict
def process_clip_state_dict_for_saving ( self , state_dict ) :
replace_prefix = { }
keys_to_replace = { }
state_dict_g = diffusers_convert . convert_text_enc_state_dict_v20 ( state_dict , " clip_g " )
for k in state_dict :
if k . startswith ( " clip_l " ) :
state_dict_g [ k ] = state_dict [ k ]
replace_prefix [ " clip_g " ] = " conditioner.embedders.1.model "
replace_prefix [ " clip_l " ] = " conditioner.embedders.0 "
state_dict_g = supported_models_base . state_dict_prefix_replace ( state_dict_g , replace_prefix )
return state_dict_g
def clip_target ( self ) :
return supported_models_base . ClipTarget ( sdxl_clip . SDXLTokenizer , sdxl_clip . SDXLClipModel )