When loading the pretrained MSA Transformer, the function `esm.pretrained._load_model_and_alphabet_core_v1` contains the following section of code:
```python
elif model_data["args"].arch == "msa_transformer":
    # upgrade state dict
    pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
    prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
    prs2 = lambda s: "".join(
        s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
    )
    prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
    model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
    model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
    if model_args.get("embed_positions_msa", False):
        emb_dim = model_state["msa_position_embedding"].size(-1)
        model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1
```
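As a standalone illustration, here is roughly what the state dict renaming does to a fairseq-style key. The example key is hypothetical and only meant to show the transformation, not an exact key from the checkpoint:

```python
# Same renaming lambdas as above, reproduced standalone (prs3 is applied first,
# then prs2, then prs1, matching the model_state comprehension).
prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
prs2 = lambda s: "".join(s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s)
prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")

# Hypothetical checkpoint key, used only to show the effect of the renaming.
key = "encoder.sentence_encoder.layers.0.row_self_attention.k_proj.weight"
print(prs1(prs2(prs3(key))))
# -> layers.0.column_self_attention.k_proj.weight
```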
`prs3` swaps 'row' and 'column' in every state dict key that contains either substring. Why is this happening? The end effect seems to be that the weights of the AxialTransformerLayers are swapped, treating row attention as column attention and vice versa, which is very counterintuitive. Some insight into this would be welcome.
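To make the question concrete, a quick way to look at this renaming on a real checkpoint is a loop like the one below. The local path is an assumption; any downloaded MSA Transformer `.pt` file should do.

```python
import torch

# Path is an assumption for illustration: point it at a downloaded
# MSA Transformer checkpoint (e.g. the esm_msa1_t12_100M_UR50S weights).
model_data = torch.load("esm_msa1_t12_100M_UR50S.pt", map_location="cpu")

# The same swap that prs3 performs in _load_model_and_alphabet_core_v1.
swap = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")

# Print how every row/column attention key would be renamed before the
# state dict is loaded into the model.
for key in model_data["model"]:
    if "row" in key or "column" in key:
        print(f"{key}  ->  {swap(key)}")
```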