|
|
|
@ -249,6 +249,11 @@ def unet_to_diffusers(unet_config):
|
|
|
|
|
|
|
|
|
|
return diffusers_unet_map
|
|
|
|
|
|
|
|
|
|
def swap_scale_shift(weight):
|
|
|
|
|
shift, scale = weight.chunk(2, dim=0)
|
|
|
|
|
new_weight = torch.cat([scale, shift], dim=0)
|
|
|
|
|
return new_weight
|
|
|
|
|
|
|
|
|
|
MMDIT_MAP_BASIC = {
|
|
|
|
|
("context_embedder.bias", "context_embedder.bias"),
|
|
|
|
|
("context_embedder.weight", "context_embedder.weight"),
|
|
|
|
@ -263,8 +268,8 @@ MMDIT_MAP_BASIC = {
|
|
|
|
|
("y_embedder.mlp.2.bias", "time_text_embed.text_embedder.linear_2.bias"),
|
|
|
|
|
("y_embedder.mlp.2.weight", "time_text_embed.text_embedder.linear_2.weight"),
|
|
|
|
|
("pos_embed", "pos_embed.pos_embed"),
|
|
|
|
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias"),
|
|
|
|
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight"),
|
|
|
|
|
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
|
|
|
|
|
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
|
|
|
|
|
("final_layer.linear.bias", "proj_out.bias"),
|
|
|
|
|
("final_layer.linear.weight", "proj_out.weight"),
|
|
|
|
|
}
|
|
|
|
@ -313,8 +318,15 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
|
|
|
|
|
for k in MMDIT_MAP_BLOCK:
|
|
|
|
|
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
|
|
|
|
|
|
|
|
|
|
for k in MMDIT_MAP_BASIC:
|
|
|
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
|
map_basic = MMDIT_MAP_BASIC.copy()
|
|
|
|
|
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.bias".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.bias".format(depth - 1), swap_scale_shift))
|
|
|
|
|
map_basic.add(("joint_blocks.{}.context_block.adaLN_modulation.1.weight".format(depth - 1), "transformer_blocks.{}.norm1_context.linear.weight".format(depth - 1), swap_scale_shift))
|
|
|
|
|
|
|
|
|
|
for k in map_basic:
|
|
|
|
|
if len(k) > 2:
|
|
|
|
|
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
|
|
|
|
|
else:
|
|
|
|
|
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
|
|
|
|
|
|
|
|
|
|
return key_map
|
|
|
|
|
|
|
|
|
|