Commit

update get_result_web
xingzhongyu committed Dec 25, 2024
1 parent 814f39e commit ef30006
Showing 2 changed files with 303 additions and 61 deletions.
226 changes: 165 additions & 61 deletions examples/get_result_web.py
@@ -1,3 +1,4 @@
import argparse
import json
import os
from pathlib import Path
@@ -9,20 +10,30 @@
from sympy import im
from tqdm import tqdm

from dance.settings import METADIR
from dance.utils import try_import

# Get the YAML config of the best-performing method


def check_identical_strings(string_list):
"""
Check if all strings in the list are identical
Args:
string_list: List of strings to compare
Returns:
"""Compare strings in a list to check if they are identical.
Parameters
----------
string_list : list
List of strings to compare
Returns
-------
str
The common string if all strings are identical
Raises
------
ValueError
If list is empty or strings are different
"""
if not string_list:
raise ValueError("The list is empty")
@@ -43,13 +54,20 @@ def check_identical_strings(string_list):


def get_sweep_url(step_csv: pd.DataFrame, single=True):
"""
Extract wandb sweep URL from a DataFrame containing run IDs
Args:
step_csv: DataFrame containing run IDs
single: If True, only process the first run
Returns:
The sweep URL
"""Extract Weights & Biases sweep URL from a DataFrame containing run IDs.
Parameters
----------
step_csv : pd.DataFrame
DataFrame containing wandb run IDs in an 'id' column
single : bool, optional
If True, only process the first run, by default True
Returns
-------
str
The wandb sweep URL
"""
ids = step_csv["id"]
sweep_urls = []
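# Hedged sketch (not in the original file): one plausible way a run id maps to its sweep URL
# through the public wandb API; assumes the module-level `entity` and `project` globals and
# that the run belongs to a sweep.
# api = wandb.Api()
# run = api.run(f"{entity}/{project}/{ids.iloc[0]}")
# sweep_url = run.sweep.url if run.sweep is not None else None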
@@ -68,12 +86,19 @@ def get_sweep_url(step_csv: pd.DataFrame, single=True):


def spilt_web(url: str):
"""
Parse wandb URL to extract entity, project and sweep ID
Args:
url: wandb sweep URL
Returns:
Tuple of (entity, project, sweep_id) or None if parsing fails
"""Parse Weights & Biases URL to extract entity, project and sweep components.
Parameters
----------
url : str
Complete wandb sweep URL
Returns
-------
tuple or None
Tuple of (entity, project, sweep_id) if parsing succeeds
None if parsing fails
"""
pattern = r"https://wandb\.ai/([^/]+)/([^/]+)/sweeps/([^/]+)"

@@ -94,13 +119,23 @@ def spilt_web(url: str):


def get_best_method(urls, metric_col="test_acc"):
"""
Find the best performing method across multiple sweeps
Args:
urls: List of sweep URLs to compare
metric_col: Metric column name to use for comparison
Returns:
Tuple of (best_step_name, best_run, best_metric_value)
"""Find the best performing method across multiple wandb sweeps.
Parameters
----------
urls : list
List of wandb sweep URLs to compare
metric_col : str, optional
Metric column name to use for comparison, by default "test_acc"
Returns
-------
tuple
(best_step_name, best_run, best_metric_value) where:
- best_step_name: name of the step with best performance
- best_run: wandb run object of best performing run
- best_metric_value: value of the metric for best run
"""
all_best_run = None
all_best_step_name = None
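# Hedged sketch (not in the original file): a plausible core loop for get_best_method,
# fetching each sweep through the public wandb API and keeping the run with the highest
# metric_col; spilt_web and the module-level entity/project globals are assumed.
# api = wandb.Api()
# for url in urls:
#     _, _, sweep_id = spilt_web(url)
#     sweep = api.sweep(f"{entity}/{project}/{sweep_id}")
#     best = max(sweep.runs, key=lambda r: r.summary.get(metric_col, float("-inf")))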
@@ -137,14 +172,22 @@ def get_metric(run):


def get_best_yaml(step_name, best_run, file_path):
"""
Generate YAML configuration for the best performing run
Args:
step_name: Name of the step ('step2' or 'step3_X')
best_run: Best wandb run object
file_path: Path to configuration files
Returns:
"""Generate YAML configuration for the best performing wandb run.
Parameters
----------
step_name : str
Name of the step ('step2' or 'step3_X')
best_run : wandb.Run
Best performing wandb run object
file_path : str
Path to configuration files
Returns
-------
str
YAML string containing the best configuration
"""
if step_name == "step2":
conf = OmegaConf.load(f"{file_path}/pipeline_params_tuning_config.yaml")
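# Hedged sketch (not in the original file): a plausible way the best run's hyperparameters
# could be folded into the loaded config and rendered as YAML; the exact key handling in
# this repository may differ.
# overrides = {k: v for k, v in best_run.config.items() if not k.startswith("_")}
# merged = OmegaConf.merge(conf, OmegaConf.create(overrides))
# yaml_str = OmegaConf.to_yaml(merged)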
@@ -182,12 +225,18 @@ def get_best_yaml(step_name, best_run, file_path):


def check_exist(file_path):
"""
Check if results directory exists and contains multiple files
Args:
file_path: Path to check
Returns:
Boolean indicating if valid results exist
"""Check if results directory exists and contains multiple result files.
Parameters
----------
file_path : str
Path to check for results
Returns
-------
bool
True if valid results exist (directory exists and contains >1 file)
"""
file_path = f"{file_path}/results/params/"
if os.path.exists(file_path) and os.path.isdir(file_path):
@@ -197,19 +246,16 @@ def check_exist(file_path):
return False
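# Hedged sketch (not in the original file): per the docstring, the branch elided above
# presumably returns True only when the params directory holds more than one file, e.g.:
# return len(os.listdir(file_path)) > 1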


def get_new_ans(tissue):
ans = []
collect_datasets = [
collect_dataset.split(tissue)[1].split("_")[0]
for collect_dataset in all_datasets[all_datasets["tissue"] == tissue]["data_fname"].tolist()
]

for method_folder in tqdm(collect_datasets):
for dataset_id in collect_datasets[method_folder]:
file_path = f"{file_root}/{method_folder}/{dataset_id}"
file_path = f"tuning/{method_folder}/{dataset_id}"
if not check_exist(file_path):
continue
step2_url = get_sweep_url(pd.read_csv(f"{file_path}/results/pipeline/best_test_acc.csv"))
@@ -232,20 +278,78 @@ def write_ans(tissue):
})
# with open('temp_ans.json', 'w') as f:
# json.dump(ans, f,indent=4)
new_df = pd.DataFrame(ans)
return new_df
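# Illustrative example (not in the original file): how the list comprehension in get_new_ans
# extracts a dataset id from a data_fname value; the file name below is hypothetical.
_example_fname = "human_heart1234_5678_data.csv"  # hypothetical data_fname
_example_id = _example_fname.split("heart")[1].split("_")[0]  # -> "1234"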


def write_ans(tissue, new_df):
"""Process and write results for a specific tissue type to CSV.
Parameters
----------
tissue : str
Name of the tissue to process
new_df : pd.DataFrame
Newly collected results for the tissue, as returned by get_new_ans
Notes
-----
Writes results to 'atlas/sweep_results/{tissue}_ans.csv', merging with any existing file, containing:
- Dataset IDs
- Sweep URLs for each step
- Best performing YAML configurations
- Best result metrics
"""
# Check whether an existing results file is present
output_file = f"atlas/sweep_results/{tissue}_ans.csv"
if os.path.exists(output_file):
existing_df = pd.read_csv(output_file, index_col=0)

# Set Dataset_id as the index to make merging easier
existing_df.set_index('Dataset_id', inplace=True)
new_df.set_index('Dataset_id', inplace=True)

# Check for overlapping Dataset_id values
common_indices = existing_df.index.intersection(new_df.index)

# For each overlapping Dataset_id, check for conflicts
for idx in common_indices:
for col in existing_df.columns.intersection(new_df.columns):
if not str(col).endswith("_best_res"):
continue
existing_value = existing_df.loc[idx, col]
new_value = new_df.loc[idx, col]

# If both values are non-NaN and differ
if (pd.notna(existing_value) and pd.notna(new_value)
and (not isinstance(existing_value, float) or abs(existing_value - new_value) > 1e-10)):
raise ValueError(f"结果冲突: Dataset {idx}, Column {col}\n"
f"现有值: {existing_value}\n新值: {new_value}")

# Merge the data
# 1. For overlapping indices, use update to fill in non-NaN values from new_df
existing_df.update(new_df)

# 2. Append rows that only exist in new_df
new_indices = new_df.index.difference(existing_df.index)
if len(new_indices) > 0:
existing_df = pd.concat([existing_df, new_df.loc[new_indices]])

# Save the merged results
existing_df.to_csv(output_file)
else:
# If the file does not exist, write the new results directly
new_df.to_csv(output_file)
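# Illustrative example (not in the original file): the merge semantics used above on two tiny
# hypothetical frames -- overlapping rows are filled with non-NaN values from the new frame,
# and rows only present in the new frame are appended.
_old = pd.DataFrame({"step2_best_res": [0.9, None]}, index=pd.Index(["d1", "d2"], name="Dataset_id"))
_new = pd.DataFrame({"step2_best_res": [None, 0.8, 0.7]}, index=pd.Index(["d1", "d2", "d3"], name="Dataset_id"))
_old.update(_new)  # d2 becomes 0.8; d1 keeps 0.9 because NaN in _new is ignored
_merged = pd.concat([_old, _new.loc[_new.index.difference(_old.index)]])  # appends d3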


if __name__ == "__main__":
# Initialize wandb and set global configuration
wandb = try_import("wandb")
entity = "xzy11632"
project = "dance-dev"

# Load dataset configuration and process results for tissue
all_datasets = pd.read_csv(METADIR / "scdeepsort.csv", header=0, skiprows=[i for i in range(1, 69)])
args = argparse.ArgumentParser()
args.add_argument("--tissue", type=str, default="heart")
args = args.parse_args()
new_df = get_new_ans(args.tissue)
write_ans(args.tissue, new_df)
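# Example invocation (not in the original file); assumes wandb credentials are configured
# and tuning results exist under tuning/<method_folder>/<dataset_id>:
#   python examples/get_result_web.py --tissue heart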