|
|
|
@ -120,20 +120,24 @@ UNET_MAP_RESNET = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
UNET_MAP_BASIC = {
|
|
|
|
|
"label_emb.0.0.weight": "class_embedding.linear_1.weight",
|
|
|
|
|
"label_emb.0.0.bias": "class_embedding.linear_1.bias",
|
|
|
|
|
"label_emb.0.2.weight": "class_embedding.linear_2.weight",
|
|
|
|
|
"label_emb.0.2.bias": "class_embedding.linear_2.bias",
|
|
|
|
|
"input_blocks.0.0.weight": "conv_in.weight",
|
|
|
|
|
"input_blocks.0.0.bias": "conv_in.bias",
|
|
|
|
|
"out.0.weight": "conv_norm_out.weight",
|
|
|
|
|
"out.0.bias": "conv_norm_out.bias",
|
|
|
|
|
"out.2.weight": "conv_out.weight",
|
|
|
|
|
"out.2.bias": "conv_out.bias",
|
|
|
|
|
"time_embed.0.weight": "time_embedding.linear_1.weight",
|
|
|
|
|
"time_embed.0.bias": "time_embedding.linear_1.bias",
|
|
|
|
|
"time_embed.2.weight": "time_embedding.linear_2.weight",
|
|
|
|
|
"time_embed.2.bias": "time_embedding.linear_2.bias"
|
|
|
|
|
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
|
|
|
|
|
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
|
|
|
|
|
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
|
|
|
|
|
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
|
|
|
|
|
("label_emb.0.0.weight", "add_embedding.linear_1.weight"),
|
|
|
|
|
("label_emb.0.0.bias", "add_embedding.linear_1.bias"),
|
|
|
|
|
("label_emb.0.2.weight", "add_embedding.linear_2.weight"),
|
|
|
|
|
("label_emb.0.2.bias", "add_embedding.linear_2.bias"),
|
|
|
|
|
("input_blocks.0.0.weight", "conv_in.weight"),
|
|
|
|
|
("input_blocks.0.0.bias", "conv_in.bias"),
|
|
|
|
|
("out.0.weight", "conv_norm_out.weight"),
|
|
|
|
|
("out.0.bias", "conv_norm_out.bias"),
|
|
|
|
|
("out.2.weight", "conv_out.weight"),
|
|
|
|
|
("out.2.bias", "conv_out.bias"),
|
|
|
|
|
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
|
|
|
|
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
|
|
|
|
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
|
|
|
|
("time_embed.2.bias", "time_embedding.linear_2.bias")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def unet_to_diffusers(unet_config):
|
|
|
|
@ -208,7 +212,7 @@ def unet_to_diffusers(unet_config):
|
|
|
|
|
n += 1
|
|
|
|
|
|
|
|
|
|
for k in UNET_MAP_BASIC:
|
|
|
|
|
diffusers_unet_map[UNET_MAP_BASIC[k]] = k
|
|
|
|
|
diffusers_unet_map[k[1]] = k[0]
|
|
|
|
|
|
|
|
|
|
return diffusers_unet_map
|
|
|
|
|
|
|
|
|
|