|
|
@ -6,7 +6,7 @@ import sd2_clip
|
|
|
|
import model_management
|
|
|
|
import model_management
|
|
|
|
from .ldm.util import instantiate_from_config
|
|
|
|
from .ldm.util import instantiate_from_config
|
|
|
|
from .ldm.models.autoencoder import AutoencoderKL
|
|
|
|
from .ldm.models.autoencoder import AutoencoderKL
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
import yaml
|
|
|
|
from .cldm import cldm
|
|
|
|
from .cldm import cldm
|
|
|
|
from .t2i_adapter import adapter
|
|
|
|
from .t2i_adapter import adapter
|
|
|
|
|
|
|
|
|
|
|
@ -726,7 +726,8 @@ def load_clip(ckpt_path, embedding_directory=None):
|
|
|
|
return clip
|
|
|
|
return clip
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
|
|
|
def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None):
|
|
|
|
config = OmegaConf.load(config_path)
|
|
|
|
with open(config_path, 'r') as stream:
|
|
|
|
|
|
|
|
config = yaml.safe_load(stream)
|
|
|
|
model_config_params = config['model']['params']
|
|
|
|
model_config_params = config['model']['params']
|
|
|
|
clip_config = model_config_params['cond_stage_config']
|
|
|
|
clip_config = model_config_params['cond_stage_config']
|
|
|
|
scale_factor = model_config_params['scale_factor']
|
|
|
|
scale_factor = model_config_params['scale_factor']
|
|
|
@ -750,7 +751,7 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e
|
|
|
|
w.cond_stage_model = clip.cond_stage_model
|
|
|
|
w.cond_stage_model = clip.cond_stage_model
|
|
|
|
load_state_dict_to = [w]
|
|
|
|
load_state_dict_to = [w]
|
|
|
|
|
|
|
|
|
|
|
|
model = instantiate_from_config(config.model)
|
|
|
|
model = instantiate_from_config(config["model"])
|
|
|
|
sd = load_torch_file(ckpt_path)
|
|
|
|
sd = load_torch_file(ckpt_path)
|
|
|
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
|
|
|
model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to)
|
|
|
|
return (ModelPatcher(model), clip, vae)
|
|
|
|
return (ModelPatcher(model), clip, vae)
|
|
|
|