@@ -206,6 +206,21 @@ textenc_pattern = re.compile("|".join(protected.keys()))
 # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
 code2idx = {"q": 0, "k": 1, "v": 2}
 
+# This function exists because at the time of writing torch.cat can't do fp8 with cuda
+def cat_tensors(tensors):
+    x = 0
+    for t in tensors:
+        x += t.shape[0]
+
+    shape = [x] + list(tensors[0].shape)[1:]
+    out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
+
+    x = 0
+    for t in tensors:
+        out[x:x + t.shape[0]] = t
+        x += t.shape[0]
+
+    return out
 
 def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
     new_state_dict = {}
@@ -249,13 +264,13 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
 
     for k_pre, tensors in capture_qkv_bias.items():
         if None in tensors:
             raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
         relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
-        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)
+        new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
 
     return new_state_dict
 
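
For reference, a minimal usage sketch of the new helper: it is meant to behave as a drop-in for torch.cat along dim 0. The shapes, the float8_e4m3fn dtype, and fp8 availability in a given PyTorch build are illustrative assumptions here, not part of the patch:

import torch

# cat_tensors should agree with torch.cat on dtypes torch.cat supports.
a = torch.randn(2, 4)
b = torch.randn(3, 4)
assert torch.equal(cat_tensors([a, b]), torch.cat([a, b]))

# The workaround itself: preallocate the output with torch.empty and copy
# each input into its slice, sidestepping the torch.cat kernel that (at
# the time of writing) can't handle fp8 on cuda. The fp8 dtype below is
# assumed to exist in the installed PyTorch (>= 2.1).
a8 = a.to(torch.float8_e4m3fn)
b8 = b.to(torch.float8_e4m3fn)
merged = cat_tensors([a8, b8])  # shape (5, 4), dtype torch.float8_e4m3fn

Note the helper assumes a non-empty list of tensors sharing a device, a dtype, and all trailing dimensions; it only generalizes concatenation along the first dimension.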
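For context on what gets concatenated: by the time those lines run, tensors holds the q, k, and v projection matrices in the slot order fixed by code2idx, and cat_tensors stacks them along dim 0 into the fused in_proj_weight layout that the original Stable Diffusion text-encoder checkpoints expect. A shape sketch, with hidden = 1024 as an assumed illustrative size:

import torch

hidden = 1024  # illustrative; the real size comes from the checkpoint
q = torch.randn(hidden, hidden)
k = torch.randn(hidden, hidden)
v = torch.randn(hidden, hidden)

# Fill the q/k/v slots the way the conversion loop does.
tensors = [None] * 3
tensors[code2idx["q"]] = q  # index 0
tensors[code2idx["k"]] = k  # index 1
tensors[code2idx["v"]] = v  # index 2

in_proj_weight = cat_tensors(tensors)
assert in_proj_weight.shape == (3 * hidden, hidden)

The None placeholders also explain the "if None in tensors" check above: a q, k, or v entry missing from the source state dict leaves its slot unfilled, which the function treats as a corrupted model rather than concatenating an incomplete set.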