|
|
|
@ -517,7 +517,11 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
|
|
|
|
if model_config is None:
|
|
|
|
|
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
|
|
|
|
|
|
|
|
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=[weight_dtype] + model_config.supported_inference_dtypes)
|
|
|
|
|
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
|
|
|
|
if weight_dtype is not None:
|
|
|
|
|
unet_weight_dtype.append(weight_dtype)
|
|
|
|
|
|
|
|
|
|
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
|
|
|
|
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
|
|
|
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
|
|
|
|
|
|
|
|
|