From 58b2364f58e88e54159399b39452d060c8c609ac Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Fri, 21 Jul 2023 14:38:56 -0400
Subject: [PATCH] Properly support SDXL diffusers unet with UNETLoader node.

---
 comfy/sd.py    |  5 ++++-
 comfy/utils.py | 34 +++++++++++++++++++---------------
 2 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/comfy/sd.py b/comfy/sd.py
index c5314da..7a079da 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1139,12 +1139,14 @@ def load_unet(unet_path): #load unet in diffusers format
     fp16 = model_management.should_use_fp16(model_params=parameters)
 
     match = {}
-    match["context_dim"] = sd["down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
+    match["context_dim"] = sd["down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight"].shape[1]
     match["model_channels"] = sd["conv_in.weight"].shape[0]
     match["in_channels"] = sd["conv_in.weight"].shape[1]
     match["adm_in_channels"] = None
     if "class_embedding.linear_1.weight" in sd:
         match["adm_in_channels"] = sd["class_embedding.linear_1.weight"].shape[1]
+    elif "add_embedding.linear_1.weight" in sd:
+        match["adm_in_channels"] = sd["add_embedding.linear_1.weight"].shape[1]
 
     SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
             'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': fp16, 'in_channels': 4, 'model_channels': 320,
@@ -1198,6 +1200,7 @@ def load_unet(unet_path): #load unet in diffusers format
             model = model.to(offload_device)
             model.load_model_weights(new_sd, "")
             return ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    print("ERROR UNSUPPORTED UNET", unet_path)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
     try:
diff --git a/comfy/utils.py b/comfy/utils.py
index d410e6a..3bbe4f9 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
 }
 
 UNET_MAP_BASIC = {
-    "label_emb.0.0.weight": "class_embedding.linear_1.weight",
-    "label_emb.0.0.bias": "class_embedding.linear_1.bias",
-    "label_emb.0.2.weight": "class_embedding.linear_2.weight",
-    "label_emb.0.2.bias": "class_embedding.linear_2.bias",
-    "input_blocks.0.0.weight": "conv_in.weight",
-    "input_blocks.0.0.bias": "conv_in.bias",
-    "out.0.weight": "conv_norm_out.weight",
-    "out.0.bias": "conv_norm_out.bias",
-    "out.2.weight": "conv_out.weight",
-    "out.2.bias": "conv_out.bias",
-    "time_embed.0.weight": "time_embedding.linear_1.weight",
-    "time_embed.0.bias": "time_embedding.linear_1.bias",
-    "time_embed.2.weight": "time_embedding.linear_2.weight",
-    "time_embed.2.bias": "time_embedding.linear_2.bias"
+    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
+    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
+    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
+    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
+    ("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
+    ("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
+    ("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
+    ("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
+    ("input_blocks.0.0.weight", "conv_in.weight"),
+    ("input_blocks.0.0.bias", "conv_in.bias"),
+    ("out.0.weight", "conv_norm_out.weight"),
+    ("out.0.bias", "conv_norm_out.bias"),
+    ("out.2.weight", "conv_out.weight"),
+    ("out.2.bias", "conv_out.bias"),
+    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+    ("time_embed.2.bias", "time_embedding.linear_2.bias")
 }
 
 def unet_to_diffusers(unet_config):
@@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
                 n += 1
 
     for k in UNET_MAP_BASIC:
-        diffusers_unet_map[UNET_MAP_BASIC[k]] = k
+        diffusers_unet_map[k[1]] = k[0]
 
     return diffusers_unet_map