|
|
|
@ -498,14 +498,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|
|
|
|
|
|
|
|
|
return (model, clip, vae)
|
|
|
|
|
|
|
|
|
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
|
|
|
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
|
|
|
|
sd = comfy.utils.load_torch_file(ckpt_path)
|
|
|
|
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model)
|
|
|
|
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
|
|
|
|
|
if out is None:
|
|
|
|
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
|
|
|
|
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
|
|
|
|
clip = None
|
|
|
|
|
clipvision = None
|
|
|
|
|
vae = None
|
|
|
|
@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|
|
|
|
if weight_dtype is not None:
|
|
|
|
|
unet_weight_dtype.append(weight_dtype)
|
|
|
|
|
|
|
|
|
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
|
|
|
|
model_config.custom_operations = model_options.get("custom_operations", None)
|
|
|
|
|
unet_dtype = model_options.get("weight_dtype", None)
|
|
|
|
|
|
|
|
|
|
if unet_dtype is None:
|
|
|
|
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
|
|
|
|
|
|
|
|
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
|
|
|
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
|
|
|
|
|
|
|
|
|