You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I set up ACDC for my custom task and saved runs as log.txt files, as I do not have a wandb account. I'm trying to adapt the get_acdc_runs function in roc_plot_generator.py:
def get_acdc_runs(
    exp,
    project_path: str,  # Local path to the project directory
    pre_run_filter: Optional[Callable[[str], bool]] = None,
    run_filter: Optional[Callable[[Dict[str, Any]], bool]] = None,
    clip: Optional[int] = None,
    return_ids: bool = False,
    threshold: float = 0.05,
):
    """Load locally saved ACDC runs and rebuild one correspondence per threshold.

    Args:
        exp: Experiment object. This function calls
            ``exp.load_from_local_run(log_text)``, reads ``exp.corr`` and calls
            ``exp.call_metric_with_corr(corr, fn, data)``.
        project_path: Directory whose sub-directories (except ``logs``) are
            treated as runs. ``logs/metrics.json`` and ``logs/log.txt`` are read
            from directly under this directory.
        pre_run_filter: Optional predicate applied to each run directory name.
        run_filter: Optional second predicate, applied to the (clipped) list of
            run directory names.
            NOTE(review): the annotation says ``Dict[str, Any]`` but the values
            passed in are directory-name strings.
        clip: Maximum number of run directories to consider (default 100_000).
        return_ids: When true, also return the threshold of every candidate.
        threshold: Results strictly below this value mark an edge as absent.
            Defaults to the previously hard-coded 0.05.

    Returns:
        A list of ``(corr, score_d)`` pairs, or ``(that list, thresholds)`` when
        ``return_ids`` is true.

    Raises:
        FileNotFoundError: If ``project_path`` does not exist (from
            ``os.listdir``).
    """
    if clip is None:
        clip = 100_000

    # List all run directories, excluding known non-run directories like 'logs'.
    run_dirs = [
        d
        for d in os.listdir(project_path)
        if os.path.isdir(os.path.join(project_path, d)) and d != 'logs'
    ]
    if pre_run_filter is not None:
        run_dirs = list(filter(pre_run_filter, run_dirs))

    if run_filter is None:
        filtered_run_dirs = run_dirs[:clip]
    else:
        filtered_run_dirs = list(filter(run_filter, tqdm(run_dirs[:clip])))
    print(f"Loading {len(filtered_run_dirs)} runs with filter {pre_run_filter} and {run_filter}")

    threshold_to_run_map: Dict[float, "AcdcRunCandidate"] = {}

    def add_run_for_processing(candidate: "AcdcRunCandidate"):
        # Keep a single candidate per threshold: the one with the most steps
        # (on a tie the earlier candidate is kept).
        if candidate.threshold not in threshold_to_run_map:
            threshold_to_run_map[candidate.threshold] = candidate
        else:
            if candidate.steps > threshold_to_run_map[candidate.threshold].steps:
                threshold_to_run_map[candidate.threshold] = candidate

    for run_dir in filtered_run_dirs:
        # NOTE(review): both paths below are built from project_path/logs and do
        # not include run_dir, so every iteration reads the same two files.
        # Confirm this matches the on-disk layout of the saved runs.
        metrics_path = os.path.join(project_path, 'logs', 'metrics.json')

        # Skip directories without the necessary files.
        if not os.path.exists(metrics_path):
            print(f"Metrics file not found for run {run_dir}")
            continue

        try:
            with open(metrics_path, 'r') as f:
                metrics_data = json.load(f)
            parents = metrics_data["list_of_parents_evaluated"]
            children = metrics_data["list_of_children_evaluated"]
            results = metrics_data["results"]
            steps = metrics_data.get("steps", [])
        except json.JSONDecodeError as e:
            # A truncated or corrupt metrics file skips this run, like the
            # other failure modes, instead of aborting the whole load.
            print(f"Could not parse metrics file for run {run_dir}: {e}")
            continue
        except KeyError:
            print(f"Required keys not found in metrics file for run {run_dir}")
            continue

        try:
            log_path = os.path.join(project_path, 'logs', 'log.txt')
            with open(log_path) as f:
                log_file = f.read()
            exp.load_from_local_run(log_file)
        except FileNotFoundError:
            print(f"Log file not found for run {run_dir}")
            continue

        # Work on a copy so edits for this run do not leak into exp.corr.
        corr = deepcopy(exp.corr)
        for parent, child, result in zip(parents, children, results):
            parent_node = parse_interpnode(parent)
            child_node = parse_interpnode(child)
            edge_owner = corr.edges[child_node.name][child_node.index][parent_node.name]
            if result < threshold:
                edge_owner[parent_node.index].present = False
                corr.remove_edge(child_node.name, child_node.index, parent_node.name, parent_node.index)
            else:
                edge_owner[parent_node.index].present = True

        # Use the number of steps in the results if available, otherwise
        # default to the length of results.
        num_steps = steps[-1] if steps else len(results)
        score_d = {"steps": num_steps, "score": threshold}
        candidate = AcdcRunCandidate(
            threshold=threshold,
            steps=num_steps,
            score_d=score_d,
            corr=corr,
        )
        add_run_for_processing(candidate)
        print(f"Added candidate: {candidate}")

    # Added to handle the test functions. `things` is read from module scope.
    def all_test_fns(data: "torch.Tensor") -> dict[str, float]:
        return {f"test_{name}": fn(data).item() for name, fn in things.test_metrics.items()}

    all_candidates = list(threshold_to_run_map.values())
    for candidate in all_candidates:
        test_metrics = exp.call_metric_with_corr(candidate.corr, all_test_fns, things.test_data)
        candidate.score_d.update(test_metrics)
        print(f"Processed candidate: {candidate}")

    corrs = [(candidate.corr, candidate.score_d) for candidate in all_candidates]
    if return_ids:
        # The "ids" are the candidates' thresholds.
        return corrs, [candidate.threshold for candidate in all_candidates]
    return corrs
When I run:
# Load the locally saved ACDC runs unless ACDC is being skipped.
if not SKIP_ACDC:
    project_path = "acdc/hybridretrieval/acdc_results/kbicr_indirect_kl_div_0.05"
    # Pass None as the experiment when `things` is None; clip to a single run
    # directory when TESTING is set.
    acdc_corrs = get_acdc_runs(None if things is None else exp, project_path, clip=1 if TESTING else None, return_ids=True)
    # NOTE(review): with return_ids=True the call returns a two-element tuple,
    # so this length check passes regardless of how many runs were loaded.
    # Confirm whether the intent is to check the number of loaded runs.
    assert len(acdc_corrs) > 1
    print("acdc_corrs", len(acdc_corrs), acdc_corrs)
I get an error on exp.load_from_local_run(log_file) saying that:
File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/TLACDCExperiment.py:889 in load_from_local_run
parent_name, parent_list, current_name, current_list = extract_info(previous_line)
File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/acdc_utils.py:219 in extract_info
parent_list = [ast.literal_eval(item if item != "COL" else "None") for item in parent_list_items]
File ~/Mech-Interp/Automatic-Circuit-Discovery/acdc/acdc_utils.py:219 in <listcomp>
parent_list = [ast.literal_eval(item if item != "COL" else "None") for item in parent_list_items]
File ~/.conda/envs/acdc/lib/python3.10/ast.py:62 in literal_eval
node_or_string = parse(node_or_string.lstrip(" \t"), mode='eval')
File ~/.conda/envs/acdc/lib/python3.10/ast.py:50 in parse
return compile(source, filename, mode, flags,
...
File <unknown>:1
:
^
SyntaxError: unexpected EOF while parsing
And adding a print statement to load_from_local_run in TLACDCExperiment class:
try:
parent_name, parent_list, current_name, current_list = extract_info(previous_line)
if parent_name and parent_list and current_name and current_list:
parent_torch_index, current_torch_index = TorchIndex(parent_list), TorchIndex(current_list)
# Debugging print statements
print(f"parent_torch_index: {parent_torch_index}")
print(f"current_torch_index: {current_torch_index}")
print(f"self.corr.edges keys: {self.corr.edges.keys()}")
if current_name in self.corr.edges:
print(f"self.corr.edges[{current_name}] keys: {self.corr.edges[current_name].keys()}")
if current_name in self.corr.edges and current_torch_index in self.corr.edges[current_name]:
print(f"self.corr.edges[{current_name}][{current_torch_index}] keys: {self.corr.edges[current_name][current_torch_index].keys()}")
self.corr.edges[current_name][current_torch_index][parent_name][parent_torch_index].present = keeping_connection
else:
print(f"Invalid data extracted from line: {previous_line}")
except SyntaxError as e:
print(f"Error parsing line: {previous_line}")
print(f"Exception: {e}")
continue # Skip this line and move to the next one
except KeyError as e:
print(f"KeyError for line: {previous_line}")
print(f"Exception: {e}")
continue # Skip this line and move to the next one
assert found_at_least_one_readable_line, f"No readable lines found in the log file. Is this formatted correctly ??? {lines=}"
EDIT: I think the error is caused by the format of the edge correspondence. The [:] symbol is parsed as an unexpected EOF in the cur_parent part of the string. Is extract_info supposed to select the parent node from each line and iteratively add parents to a list (and do the same for the current node)? If so, how can I work around the symbol causing the error without modifying the function?
The text was updated successfully, but these errors were encountered:
I set up ACDC for my custom task and saved runs as `log.txt` files, as I do not have a wandb account. I'm trying to adapt the `get_acdc_runs` function in `roc_plot_generator.py`:
When I run:
I get an error on `exp.load_from_local_run(log_file)` saying that:
And adding a print statement to `load_from_local_run` in TLACDCExperiment class:
returns this error for all edges:
Every run was saved to a `log.txt` with this command:
I would really appreciate some help!
EDIT: I think the error is caused by the format of the edge correspondence. The `[:]` symbol is parsed as an unexpected EOF in the `cur_parent` part of the string. Is `extract_info` supposed to select the parent node from each line and iteratively add parents to a list (and do the same for the current node)? If so, how can I work around the symbol causing the error without modifying the function?
The text was updated successfully, but these errors were encountered: