Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue in adapting roc_plot_generator to custom task for locally saved runs #101

Open
Iust1n2 opened this issue Jul 16, 2024 · 0 comments
Open

Comments

@Iust1n2
Copy link

Iust1n2 commented Jul 16, 2024

I set up ACDC for my custom task and saved the runs as log.txt files, since I do not have a wandb account. I'm trying to adapt the get_acdc_runs function in roc_plot_generator.py:

def get_acdc_runs(
    exp,
    project_path: str,  # Local path to the project directory
    pre_run_filter: Optional[Callable[[str], bool]] = None,
    run_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
    clip: Optional[int] = None,
    return_ids: bool = False,
):
    """Load locally saved ACDC runs and score the resulting correspondences.

    Sub-directories of ``project_path`` (other than ``logs``) are treated as
    runs. For each one, ``logs/metrics.json`` and ``logs/log.txt`` under
    ``project_path`` are read, ``exp`` is re-loaded from the log text, and a
    copy of ``exp.corr`` is pruned using the per-edge results. The best
    candidate per threshold (most steps) is then evaluated on the test
    metrics.

    Args:
        exp: Experiment object; must provide ``load_from_local_run``,
            ``corr`` and ``call_metric_with_corr``.
        project_path: Directory containing the run directories and ``logs``.
        pre_run_filter: Optional predicate applied to each run directory name.
        run_filter: Optional predicate applied after ``pre_run_filter``.
            NOTE(review): annotated as taking a dict, but it is called with
            directory-name strings here — confirm the intended signature.
        clip: Maximum number of run directories to consider
            (``None`` means 100_000).
        return_ids: If true, also return the threshold of every candidate.

    Returns:
        A list of ``(corr, score_d)`` tuples, or ``(that_list, thresholds)``
        when ``return_ids`` is true.

    Note:
        ``things`` is not a parameter; it is read from the enclosing/global
        scope.
    """
    if clip is None:
        clip = 100_000

    # List all run directories, excluding known non-run directories like 'logs'
    run_dirs = [d for d in os.listdir(project_path) if os.path.isdir(os.path.join(project_path, d)) and d != 'logs']
    
    if pre_run_filter is not None:
        run_dirs = list(filter(pre_run_filter, run_dirs))

    # Clipping happens before run_filter, so fewer than `clip` runs may remain.
    if run_filter is None:
        filtered_run_dirs = run_dirs[:clip]
    else:
        filtered_run_dirs = list(filter(run_filter, tqdm(run_dirs[:clip])))

    print(f"Loading {len(filtered_run_dirs)} runs with filter {pre_run_filter} and {run_filter}")

    # Best candidate seen so far for each threshold value.
    threshold_to_run_map: Dict[float, AcdcRunCandidate] = {}

    def add_run_for_processing(candidate: AcdcRunCandidate):
        """Keep `candidate` if its threshold is new or it has strictly more steps than the stored one."""
        if candidate.threshold not in threshold_to_run_map:
            threshold_to_run_map[candidate.threshold] = candidate
        else:
            if candidate.steps > threshold_to_run_map[candidate.threshold].steps:
                threshold_to_run_map[candidate.threshold] = candidate

    for run_dir in filtered_run_dirs:
        # run_path = os.path.join(project_path, run_dir)
        # NOTE(review): this path does not depend on run_dir, so every iteration
        # reads the same metrics file — confirm this matches the on-disk layout.
        metrics_path = os.path.join(project_path, 'logs', 'metrics.json')
        
        # Skip directories without the necessary files
        if not os.path.exists(metrics_path):
            print(f"Metrics file not found for run {run_dir}")
            continue

        try:
            with open(metrics_path, 'r') as f:
                metrics_data = json.load(f)
                parents = metrics_data["list_of_parents_evaluated"]
                children = metrics_data["list_of_children_evaluated"]
                results = metrics_data["results"]
                # "steps" is optional; an empty list triggers the fallback below.
                steps = metrics_data.get("steps", [])
        except KeyError:
            # Only missing keys are handled; malformed JSON (json.JSONDecodeError)
            # would propagate to the caller.
            print(f"Required keys not found in metrics file for run {run_dir}")
            continue
        try:
            # NOTE(review): like metrics_path, this is the same file for every run_dir.
            log_path = os.path.join(project_path, 'logs', 'log.txt')
            with open(log_path) as f:
                log_file = f.read()
            # Mutates `exp` in place; exp.corr is copied right after for pruning.
            exp.load_from_local_run(log_file)
        except FileNotFoundError:
            print(f"Log file not found for run {run_dir}")
            continue

        # NOTE(review): because the threshold is constant, threshold_to_run_map
        # can hold at most one candidate — confirm this is intended.
        threshold = 0.05  # Fixed threshold value
        corr = deepcopy(exp.corr)

        # zip stops at the shortest of the three lists.
        for parent, child, result in zip(parents, children, results):
            parent_node = parse_interpnode(parent)
            child_node = parse_interpnode(child)

            # Edges are indexed child-first, then parent.
            if result < threshold:
                # Below-threshold edges are marked absent and then removed from the copy.
                corr.edges[child_node.name][child_node.index][parent_node.name][parent_node.index].present = False
                corr.remove_edge(child_node.name, child_node.index, parent_node.name, parent_node.index)
            else:
                corr.edges[child_node.name][child_node.index][parent_node.name][parent_node.index].present = True

        # Use the number of steps in the results if available, otherwise default to length of results
        num_steps = steps[-1] if steps else len(results)
        score_d = {"steps": num_steps, "score": threshold}

        candidate = AcdcRunCandidate(
            threshold=threshold,
            steps=num_steps,
            score_d=score_d,
            corr=corr
        )
        add_run_for_processing(candidate)
        print(f"Added candidate: {candidate}")

    # Added to handle the test functions
    def all_test_fns(data: torch.Tensor) -> dict[str, float]:
        """Evaluate every entry of things.test_metrics on `data`, keyed as 'test_<name>'."""
        return {f"test_{name}": fn(data).item() for name, fn in things.test_metrics.items()}

    all_candidates = list(threshold_to_run_map.values())
    for candidate in all_candidates:
        # Merge the test metrics into the candidate's score dictionary in place.
        test_metrics = exp.call_metric_with_corr(candidate.corr, all_test_fns, things.test_data)
        candidate.score_d.update(test_metrics)
        print(f"Processed candidate: {candidate}")

    corrs = [(candidate.corr, candidate.score_d) for candidate in all_candidates]
    if return_ids:
        # The return type changes to a 2-tuple here: (corrs, thresholds).
        return corrs, [candidate.threshold for candidate in all_candidates]  # Corrected to return threshold
    return corrs

When I run:

# Load the locally saved ACDC runs unless ACDC is being skipped.
if not SKIP_ACDC: 
    project_path = "acdc/hybridretrieval/acdc_results/kbicr_indirect_kl_div_0.05"
    # `exp` is only passed when `things` is available; in TESTING mode at most one run directory is considered.
    acdc_corrs = get_acdc_runs(None if things is None else exp, project_path, clip=1 if TESTING else None, return_ids=True)
    # NOTE(review): with return_ids=True, get_acdc_runs returns a (corrs, thresholds)
    # tuple, so len(acdc_corrs) is always 2 and this assert cannot fail — confirm
    # whether the tuple should be unpacked before checking the number of runs.
    assert len(acdc_corrs) > 1
    print("acdc_corrs", len(acdc_corrs), acdc_corrs)

I get an error on exp.load_from_local_run(log_file) saying that:

File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/TLACDCExperiment.py:889 in load_from_local_run
    parent_name, parent_list, current_name, current_list = extract_info(previous_line)

  File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/acdc_utils.py:219 in extract_info
    parent_list = [ast.literal_eval(item if item != "COL" else "None") for item in parent_list_items]

  File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/acdc_utils.py:219 in <listcomp>
    parent_list = [ast.literal_eval(item if item != "COL" else "None") for item in parent_list_items]

  File ~/.conda/envs/acdc/lib/python3.10/ast.py:62 in literal_eval
    node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')

  File ~/.conda/envs/acdc/lib/python3.10/ast.py:50 in parse
    return compile(source, filename, mode, flags,
...
  File <unknown>:1
    :
    ^
SyntaxError: unexpected EOF while parsing 

And adding a print statement to load_from_local_run in TLACDCExperiment class:

  try:
        parent_name, parent_list, current_name, current_list = extract_info(previous_line)
        if parent_name and parent_list and current_name and current_list:
            parent_torch_index, current_torch_index = TorchIndex(parent_list), TorchIndex(current_list)
            # Debugging print statements
            print(f"parent_torch_index: {parent_torch_index}")
            print(f"current_torch_index: {current_torch_index}")
            print(f"self.corr.edges keys: {self.corr.edges.keys()}")
            if current_name in self.corr.edges:
                print(f"self.corr.edges[{current_name}] keys: {self.corr.edges[current_name].keys()}")
            if current_name in self.corr.edges and current_torch_index in self.corr.edges[current_name]:
                print(f"self.corr.edges[{current_name}][{current_torch_index}] keys: {self.corr.edges[current_name][current_torch_index].keys()}")
            self.corr.edges[current_name][current_torch_index][parent_name][parent_torch_index].present = keeping_connection
        else:
            print(f"Invalid data extracted from line: {previous_line}")
    except SyntaxError as e:
        print(f"Error parsing line: {previous_line}")
        print(f"Exception: {e}")
        continue  # Skip this line and move to the next one
    except KeyError as e:
        print(f"KeyError for line: {previous_line}")
        print(f"Exception: {e}")
        continue  # Skip this line and move to the next one

assert found_at_least_one_readable_line, f"No readable lines found in the log file. Is this formatted correctly ??? {lines=}"

returns this error for all edges:

Processing previous_line: Node: cur_parent=TLACDCInterpNode(blocks.11.hook_mlp_out, [:]) (self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:]))
Error parsing line: Node: cur_parent=TLACDCInterpNode(blocks.11.hook_mlp_out, [:]) (self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:]))
Exception: unexpected EOF while parsing (<unknown>, line 1)

Every run was saved to a log.txt with this command:

python main.py --task hybrid-retrieval --zero-ablation --threshold 0.15 --indices-mode reverse --first-cache-cpu False --second-cache-cpu False --max-num-epochs 100000 > log.txt 2>&1

I would really appreciate some help!

EDIT: I think the error is caused by the format of the edge correspondence. The [:] index in the cur_parent part of the line is what fails to parse, producing the "unexpected EOF" SyntaxError. Is extract_info supposed to select the parent node from each line and iteratively add parents to a list (and do the same for the current node)? If so, how can I work around the symbol causing the error without modifying the function?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant