|
|
|
@ -415,6 +415,59 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
|
|
|
|
|
|
|
|
|
|
return key_map
|
|
|
|
|
|
|
|
|
|
def flux_to_diffusers(mmdit_config, output_prefix=""):
    """Build a map from diffusers-style Flux parameter names to native Flux names.

    Args:
        mmdit_config: Mapping queried with ``.get``. Reads ``"depth"`` (number
            of double blocks), ``"depth_single_blocks"``, ``"hidden_size"``
            (all default 0) and ``"mlp_ratio"`` (defaults to 4).
        output_prefix: String prepended to every target key.

    Returns:
        dict: Source key -> target. A target is either a plain key string, or
        a tuple ``(target_key, (a, b, c))`` pointing at part of a fused target
        tensor. NOTE(review): ``(a, b, c)`` is presumably
        ``(dim, start, length)`` as used by ``tensor.narrow`` -- confirm
        against the code that consumes this map.
    """
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
    hidden_size = mmdit_config.get("hidden_size", 0)
    # Width of the MLP part of the fused single-block projection.
    # NOTE(review): assumes the config exposes the ratio as "mlp_ratio";
    # falls back to 4x hidden_size when the key is absent.
    mlp_hidden_size = int(hidden_size * mmdit_config.get("mlp_ratio", 4))

    # q, k and v are stacked, in that order, in the fused target tensor.
    qkv_offsets = (("to_q", 0), ("to_k", hidden_size), ("to_v", hidden_size * 2))

    key_map = {}

    for index in range(n_double_layers):
        prefix_from = "transformer_blocks.{}".format(index)
        prefix_to = "{}double_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            attn_prefix = "{}.attn.".format(prefix_from)
            qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
            for name, offset in qkv_offsets:
                key_map["{}{}.{}".format(attn_prefix, name, end)] = (qkv, (0, offset, hidden_size))

        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
        }
        for src, dst in block_map.items():
            key_map["{}.{}".format(prefix_from, src)] = "{}.{}".format(prefix_to, dst)

    for index in range(n_single_layers):
        prefix_from = "single_transformer_blocks.{}".format(index)
        prefix_to = "{}single_blocks.{}".format(output_prefix, index)

        for end in ("weight", "bias"):
            attn_prefix = "{}.attn.".format(prefix_from)
            qkv = "{}.linear1.{}".format(prefix_to, end)
            for name, offset in qkv_offsets:
                key_map["{}{}.{}".format(attn_prefix, name, end)] = (qkv, (0, offset, hidden_size))
            # proj_mlp belongs to the block itself, not to its attention
            # sub-module, and fills the remainder of linear1 after q/k/v.
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, mlp_hidden_size))

        block_map = {
            # TODO: remaining per-block keys of the single blocks.
        }
        for src, dst in block_map.items():
            key_map["{}.{}".format(prefix_from, src)] = "{}.{}".format(prefix_to, dst)

    # Entries are (target_key, source_key) or (target_key, source_key, extra);
    # a three-element entry is stored as (prefixed_target, None, extra).
    MAP_BASIC = (
        # TODO: keys that live outside the repeated blocks.
    )

    for entry in MAP_BASIC:
        target = "{}{}".format(output_prefix, entry[0])
        if len(entry) > 2:
            key_map[entry[1]] = (target, None, entry[2])
        else:
            key_map[entry[1]] = target

    return key_map
|
|
|
|
|
|
|
|
|
|
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
|
|
|
|
if tensor.shape[dim] > batch_size:
|
|
|
|
|
return tensor.narrow(dim, 0, batch_size)
|
|
|
|
|