
Loading model fails in tutorial #97

Open
velezbeltran opened this issue Jan 12, 2024 · 4 comments

Comments

@velezbeltran

velezbeltran commented Jan 12, 2024

Hello!

I have been working through the ACDC_Main_Demo.ipynb notebook and am running into an error when I try to load a saved subgraph. Specifically, I take the following steps:

  1. I run the notebook as is, adding one extra line that saves the subgraph with the command below:
exp.save_subgraph(
    save_path,
    return_it=True,
)
  2. I then run all of the cells except the one containing the following code:
for i in range(args.max_num_epochs):
    exp.step(testing=False)

    show(
        exp.corr,
        f"ims/img_new_{i+1}.png",
        show_full_index=False,
    )

    if IN_COLAB or ipython is not None:
        # so long as we're not running this as a script, show the image!
        display(Image(f"ims/img_new_{i+1}.png"))

    print(i, "-" * 50)
    print(exp.count_no_edges())

    if i == 0:
        exp.save_edges("edges.pkl")

    if exp.current_node is None or SINGLE_STEP:
        break

exp.save_edges("another_final_edges.pkl")

if USING_WANDB:
    edges_fname = "edges.pth"
    exp.save_edges(edges_fname)
    artifact = wandb.Artifact(edges_fname, type="dataset")
    artifact.add_file(edges_fname)
    wandb.log_artifact(artifact)
    os.remove(edges_fname)
    wandb.finish()
    
  3. Finally, I load the subgraph using:
import torch as t

# Load the subgraph that was saved earlier with exp.save_subgraph(save_path, ...)
circuit = t.load(save_path)
exp.load_subgraph(circuit)

If I do this, I get the following assertion error:

AssertionError: Ensure that the dictionary includes exactly the correct keys... e.g missing [('blocks.1.hook_q_input', (None, None, 0), 'blocks.0.attn.hook_result', (None, None, 1))] and has excess stuff []

What could be causing this? Am I doing something wrong? Alternatively, what is the standard way of loading circuits?
Also, if I do run the cell that contains the `.step()` loop, I don't get this error.

Thank you!
Nicolas

@rhaps0dy
Collaborator

Possibly the TransformerLens version you're using differs from the one used to save the hypothesis, so the hook names are different. What does `exp.corr.all_edges().keys()` return?
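The comparison suggested here can be made mechanical. Below is a minimal sketch, assuming both the saved file and the current graph are reduced to tuples of the form `(receiver_hook, receiver_index, sender_hook, sender_index)`, as in the error message; `diagnose_edge_keys` is a hypothetical helper, not part of ACDC. The reasoning: if the saved keys are a strict subset of the graph's keys, edges were dropped; if the file holds keys the graph doesn't know, hook names (or indices) differ, which is what a TransformerLens version mismatch would look like.

```python
from typing import Hashable, Iterable, Tuple

# Edge key shape assumed from the error message:
# (receiver_hook, receiver_index, sender_hook, sender_index)
EdgeKey = Tuple[str, Hashable, str, Hashable]


def diagnose_edge_keys(saved: Iterable[EdgeKey], current: Iterable[EdgeKey]) -> str:
    """Classify how a saved subgraph's keys differ from the current graph's keys."""
    saved_set, current_set = set(saved), set(current)
    missing = current_set - saved_set  # in the graph, absent from the file
    excess = saved_set - current_set   # in the file, unknown to the graph
    if not missing and not excess:
        return "keys match"
    if not excess:
        # Every saved edge still exists; some graph edges were never saved.
        return f"strict subset: {len(missing)} missing, consistent with pruned edges"
    # Hook names that appear only in the file point at renamed hooks.
    saved_hooks = {name for k in excess for name in (k[0], k[2])}
    current_hooks = {name for k in current_set for name in (k[0], k[2])}
    unknown = sorted(saved_hooks - current_hooks)
    return f"{len(excess)} unknown keys, hooks not in graph: {unknown}"


# Toy keys in the same shape as the error message (values are illustrative).
current = [
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 1)),
    ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.hook_resid_pre", (None,)),
]
print(diagnose_edge_keys(current[1:], current))
# strict subset: 1 missing, consistent with pruned edges
```

Under this reading, the empty "excess stuff []" in the error above already hints at missing edges rather than renamed hooks.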

@velezbeltran
Author

velezbeltran commented Jan 12, 2024

Thanks for your lightning-fast response!

Before running the `.step()` block:

dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 7]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 5]),
('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 4]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 2]),
('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 1]), ('blocks.1.hook_resid_post', [:],
'blocks.1.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 7]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:],
'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 4]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_resid_post', [:],
'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 1]),
('blocks.1.hook_resid_post', [:], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_resid_post', [:],
'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_q', [:, :, 7]),
('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_k', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 7],
'blocks.1.attn.hook_v', [:, :, 7]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]),
('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6],
'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_q', [:, :, 5]),
('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_k', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 5],
'blocks.1.attn.hook_v', [:, :, 5]), ('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_q', [:, :, 4]),
('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_k', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 4],
'blocks.1.attn.hook_v', [:, :, 4]), ('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_q', [:, :, 3]),
('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_k', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 3],
'blocks.1.attn.hook_v', [:, :, 3]), ('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_q', [:, :, 2]),
('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_k', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 2],
'blocks.1.attn.hook_v', [:, :, 2]), ('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_q', [:, :, 1]),
('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_k', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 1],
'blocks.1.attn.hook_v', [:, :, 1]), ('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_q', [:, :, 0]),
('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_k', [:, :, 0]), ('blocks.1.attn.hook_result', [:, :, 0],
'blocks.1.attn.hook_v', [:, :, 0]), ('blocks.1.attn.hook_q', [:, :, 7], 'blocks.1.hook_q_input', [:, :, 7]),
('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]),
('blocks.1.attn.hook_q', [:, :, 5], 'blocks.1.hook_q_input', [:, :, 5]),
('blocks.1.attn.hook_q', [:, :, 4], 'blocks.1.hook_q_input', [:, :, 4]),
('blocks.1.attn.hook_q', [:, :, 3], 'blocks.1.hook_q_input', [:, :, 3]),
('blocks.1.attn.hook_q', [:, :, 2], 'blocks.1.hook_q_input', [:, :, 2]),
('blocks.1.attn.hook_q', [:, :, 1], 'blocks.1.hook_q_input', [:, :, 1]),
('blocks.1.attn.hook_q', [:, :, 0], 'blocks.1.hook_q_input', [:, :, 0]),
('blocks.1.attn.hook_k', [:, :, 7], 'blocks.1.hook_k_input', [:, :, 7]),
('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]),
('blocks.1.attn.hook_k', [:, :, 5], 'blocks.1.hook_k_input', [:, :, 5]),
('blocks.1.attn.hook_k', [:, :, 4], 'blocks.1.hook_k_input', [:, :, 4]), ('blocks.1.attn.hook_k', [:, :, 3], 'blocks.1.hook_k_input', [:, :, 3]), ('blocks.1.attn.hook_k', [:, :, 2], 'blocks.1.hook_k_input', [:, :, 2]), ('blocks.1.attn.hook_k', [:, :, 1], 'blocks.1.hook_k_input', [:, :, 1]), ('blocks.1.attn.hook_k', [:, :, 0], 'blocks.1.hook_k_input', [:, :, 0]), ('blocks.1.attn.hook_v', [:, :, 7], 'blocks.1.hook_v_input', [:, :, 7]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 5], 'blocks.1.hook_v_input', [:, :, 5]), ('blocks.1.attn.hook_v', [:, :, 4], 'blocks.1.hook_v_input', [:, :, 4]), ('blocks.1.attn.hook_v', [:, :, 3], 'blocks.1.hook_v_input', [:, :, 3]), ('blocks.1.attn.hook_v', [:, :, 2], 'blocks.1.hook_v_input', [:, :, 2]), ('blocks.1.attn.hook_v', [:, :, 1], 'blocks.1.hook_v_input', [:, :, 1]), ('blocks.1.attn.hook_v', [:, :, 0], 'blocks.1.hook_v_input', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 6], 
'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 
3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_q_input', 
[:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), 
('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', 
[:]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_k_input', [:, :, 0], 
'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 
5], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', 
[:, :, 2], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 7]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 6]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 5]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 4]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 3]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 2]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_q', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_k', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_v', [:, :, 7]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_q', [:, :, 6]), 
('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_k', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_v', [:, :, 6]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_q', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_k', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_v', [:, :, 5]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_q', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_k', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_v', [:, :, 4]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_q', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_k', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_v', [:, :, 3]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_q', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_k', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_v', [:, :, 2]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_q', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_k', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_v', [:, :, 1]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_q', [:, :, 7], 'blocks.0.hook_q_input', [:, :, 7]), ('blocks.0.attn.hook_q', [:, :, 6], 'blocks.0.hook_q_input', [:, :, 6]), ('blocks.0.attn.hook_q', [:, :, 5], 'blocks.0.hook_q_input', [:, :, 5]), ('blocks.0.attn.hook_q', [:, :, 4], 'blocks.0.hook_q_input', [:, :, 4]), ('blocks.0.attn.hook_q', [:, :, 3], 'blocks.0.hook_q_input', [:, :, 3]), ('blocks.0.attn.hook_q', [:, :, 2], 'blocks.0.hook_q_input', [:, :, 2]), 
('blocks.0.attn.hook_q', [:, :, 1], 'blocks.0.hook_q_input', [:, :, 1]), ('blocks.0.attn.hook_q', [:, :, 0], 'blocks.0.hook_q_input', [:, :, 0]), ('blocks.0.attn.hook_k', [:, :, 7], 'blocks.0.hook_k_input', [:, :, 7]), ('blocks.0.attn.hook_k', [:, :, 6], 'blocks.0.hook_k_input', [:, :, 6]), ('blocks.0.attn.hook_k', [:, :, 5], 'blocks.0.hook_k_input', [:, :, 5]), ('blocks.0.attn.hook_k', [:, :, 4], 'blocks.0.hook_k_input', [:, :, 4]), ('blocks.0.attn.hook_k', [:, :, 3], 'blocks.0.hook_k_input', [:, :, 3]), ('blocks.0.attn.hook_k', [:, :, 2], 'blocks.0.hook_k_input', [:, :, 2]), ('blocks.0.attn.hook_k', [:, :, 1], 'blocks.0.hook_k_input', [:, :, 1]), ('blocks.0.attn.hook_k', [:, :, 0], 'blocks.0.hook_k_input', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 7], 'blocks.0.hook_v_input', [:, :, 7]), ('blocks.0.attn.hook_v', [:, :, 6], 'blocks.0.hook_v_input', [:, :, 6]), ('blocks.0.attn.hook_v', [:, :, 5], 'blocks.0.hook_v_input', [:, :, 5]), ('blocks.0.attn.hook_v', [:, :, 4], 'blocks.0.hook_v_input', [:, :, 4]), ('blocks.0.attn.hook_v', [:, :, 3], 'blocks.0.hook_v_input', [:, :, 3]), ('blocks.0.attn.hook_v', [:, :, 2], 'blocks.0.hook_v_input', [:, :, 2]), ('blocks.0.attn.hook_v', [:, :, 1], 'blocks.0.hook_v_input', [:, :, 1]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_q_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_q_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 6], 
'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_k_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 7], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 5], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 4], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 3], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 2], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 1], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])

After running the `.step()` block:

dict_keys([('blocks.1.hook_resid_post', [:], 'blocks.1.attn.hook_result', [:, :, 6]), ('blocks.1.hook_resid_post', [:], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]), ('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_v', [:, :, 6]), ('blocks.1.attn.hook_q', [:, :, 6], 'blocks.1.hook_q_input', [:, :, 6]), ('blocks.1.attn.hook_k', [:, :, 6], 'blocks.1.hook_k_input', [:, :, 6]), ('blocks.1.attn.hook_v', [:, :, 6], 'blocks.1.hook_v_input', [:, :, 6]), ('blocks.1.hook_q_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.1.hook_k_input', [:, :, 6], 'blocks.0.attn.hook_result', [:, :, 0]), ('blocks.1.hook_v_input', [:, :, 6], 'blocks.0.hook_resid_pre', [:]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]), ('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]), ('blocks.0.attn.hook_v', [:, :, 0], 'blocks.0.hook_v_input', [:, :, 0]), ('blocks.0.hook_v_input', [:, :, 0], 'blocks.0.hook_resid_pre', [:])])
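The error message prints index positions as tuples such as `(None, None, 0)`, while the dumps above print them as `[:, :, 0]`. Assuming `None` stands for a full slice (`:`), which is consistent with the two printouts, a small converter makes it easy to search the dumps for the key named in the error. `index_to_str` and `key_to_str` are hypothetical helpers for illustration only:

```python
from typing import Optional, Sequence, Tuple


def index_to_str(hashable_index: Sequence[Optional[int]]) -> str:
    """Render an index tuple the way the dumps print it, e.g. (None, None, 6) -> '[:, :, 6]'."""
    parts = [":" if i is None else str(i) for i in hashable_index]
    return "[" + ", ".join(parts) + "]"


def key_to_str(key: Tuple[str, Sequence[Optional[int]], str, Sequence[Optional[int]]]) -> str:
    """Render a full edge key (receiver, receiver_index, sender, sender_index)."""
    receiver, r_idx, sender, s_idx = key
    return f"('{receiver}', {index_to_str(r_idx)}, '{sender}', {index_to_str(s_idx)})"


missing = ("blocks.1.hook_q_input", (None, None, 0), "blocks.0.attn.hook_result", (None, None, 1))
print(key_to_str(missing))
# ('blocks.1.hook_q_input', [:, :, 0], 'blocks.0.attn.hook_result', [:, :, 1])
```

Rendered this way, the missing key appears in the first dump but not in the second, which points to edges dropped by `.step()` rather than to renamed hooks.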

I don't think the issue is a TransformerLens version mismatch, because I can reproduce all of this within the same notebook in Colab.

Thank you

@rhaps0dy
Collaborator

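It turns out the explanation is this: the ACDC algorithm literally removes pruned edges (i.e. deletes them from the correspondence dictionaries) rather than setting `edge.present = False`. A subgraph saved after pruning therefore contains only the surviving edges, while a freshly built experiment still has every edge, so the exact-key check fails when loading.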
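To make the difference concrete, here is a toy model of the two pruning styles. It does not use ACDC at all: `Edge`, `save`, and `keys_match` are hypothetical stand-ins that mimic saving a `{edge_key: present}` dict and checking it against a fresh graph with an exact-key comparison, as the assertion in the error message does.

```python
from dataclasses import dataclass


@dataclass
class Edge:
    """Toy stand-in for an edge; only the `present` flag matters here."""
    present: bool = True


def prune_by_deleting(edges: dict, key: str) -> None:
    # Pruning as described above: the key vanishes from the dictionary.
    del edges[key]


def prune_by_flagging(edges: dict, key: str) -> None:
    # Alternative: keep the key and mark the edge as absent.
    edges[key].present = False


def save(edges: dict) -> dict:
    # Serialize as {edge_key: present}, like a saved subgraph.
    return {key: edge.present for key, edge in edges.items()}


def keys_match(saved: dict, graph: dict) -> bool:
    # Exact-key check in the spirit of the assertion in the error.
    return set(saved) == set(graph)


keys = ["e0", "e1", "e2"]
deleted = {k: Edge() for k in keys}
flagged = {k: Edge() for k in keys}
prune_by_deleting(deleted, "e1")
prune_by_flagging(flagged, "e1")

fresh = {k: Edge() for k in keys}  # what a newly built experiment starts with
print(keys_match(save(deleted), fresh))  # False: "e1" is missing
print(keys_match(save(flagged), fresh))  # True: "e1" is saved as False
```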

The loading code should be changed to fix this.
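One possible load-side fix, sketched under assumptions: before calling the loader, extend the pruned dict so it contains every key of the fresh graph, marking edges absent from the file as `False`. `pad_subgraph` is a hypothetical helper, not an ACDC function, and it refuses keys the graph does not know, since those would indicate a different problem.

```python
from collections import OrderedDict
from typing import Dict, Hashable, Iterable


def pad_subgraph(saved: Dict[Hashable, bool], all_keys: Iterable[Hashable]) -> Dict[Hashable, bool]:
    """Extend a saved subgraph so it covers every edge key in `all_keys`.

    Surviving edges keep their saved `present` flag; edges that were pruned
    away (and so are absent from `saved`) are marked False.
    """
    all_keys = list(all_keys)
    unknown = set(saved) - set(all_keys)
    if unknown:
        # Keys the graph has never heard of point to a different problem
        # (for example renamed hooks), so refuse to drop them silently.
        example = next(iter(unknown))
        raise KeyError(f"{len(unknown)} saved keys are not in the graph, e.g. {example!r}")
    return OrderedDict((key, bool(saved.get(key, False))) for key in all_keys)


saved = {("e0",): True, ("e2",): True}  # pruned save: ("e1",) is gone
padded = pad_subgraph(saved, [("e0",), ("e1",), ("e2",)])
print(dict(padded))  # {('e0',): True, ('e1',): False, ('e2',): True}
```

In the notebook, `all_keys` would presumably be the fresh experiment's `exp.corr.all_edges()` keys converted to the same hashable form the assertion compares against (index tuples like `(None, None, 0)`), and the padded dict would then go to `exp.load_subgraph(...)`. That conversion step is an assumption to check against the ACDC source; this is untested there.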

@Iust1n2

Iust1n2 commented Jul 2, 2024

@velezbeltran Would you be so kind as to share your working code for loading the saved subgraph edges (`edges.pth`) for inference? I didn't quite catch from @rhaps0dy what the modification should be, or where it should go. Thanks!
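As for where the change could go: besides patching the loader, the fix could plausibly be applied at save time instead. The sketch below assumes the edge keys can be listed once before the `.step()` loop and once after it; `record_all_keys` and `full_subgraph` are hypothetical helpers, not ACDC functions.

```python
from typing import Dict, FrozenSet, Hashable, Iterable


def record_all_keys(edge_keys: Iterable[Hashable]) -> FrozenSet[Hashable]:
    # Call before pruning, while the graph still contains every edge.
    return frozenset(edge_keys)


def full_subgraph(all_keys: FrozenSet[Hashable], remaining_keys: Iterable[Hashable]) -> Dict[Hashable, bool]:
    # Call after pruning: every original key is kept, True only if it survived.
    remaining = set(remaining_keys)
    return {key: key in remaining for key in sorted(all_keys, key=repr)}


before = record_all_keys(["e0", "e1", "e2"])
after = ["e0", "e2"]  # "e1" was pruned away
print(full_subgraph(before, after))  # {'e0': True, 'e1': False, 'e2': True}
```

The resulting dict, with keys in whatever hashable form the loader expects, could then be written with `torch.save`, so that a fresh experiment sees every key it expects. This is a design sketch, not a tested patch.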
