From a47f609f904842a12c54c465fc93bda38257e289 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 2 Jan 2024 01:50:57 -0500 Subject: [PATCH] Auto detect out_channels from model. --- comfy/model_detection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e3af422..ad16c0f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -34,7 +34,6 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config = { "use_checkpoint": False, "image_size": 32, - "out_channels": 4, "use_spatial_transformer": True, "legacy": False } @@ -49,6 +48,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + out_channels = state_dict['{}out.2.weight'.format(key_prefix)].shape[0] num_res_blocks = [] channel_mult = [] @@ -122,6 +122,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): transformer_depth_middle = -1 unet_config["in_channels"] = in_channels + unet_config["out_channels"] = out_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks unet_config["transformer_depth"] = transformer_depth