diff --git a/examples/get_result_web.py b/examples/get_result_web.py
index 2c0413da..bffe48fc 100644
--- a/examples/get_result_web.py
+++ b/examples/get_result_web.py
@@ -1,3 +1,4 @@
+import argparse
 import json
 import os
 from pathlib import Path
@@ -9,20 +10,30 @@
 from sympy import im
 from tqdm import tqdm

+from dance.settings import METADIR
 from dance.utils import try_import


 # get yaml of best method
 def check_identical_strings(string_list):
-    """
-    Check if all strings in the list are identical
-    Args:
-        string_list: List of strings to compare
-    Returns:
+    """Compare strings in a list to check if they are identical.
+
+    Parameters
+    ----------
+    string_list : list
+        List of strings to compare
+
+    Returns
+    -------
+    str
         The common string if all strings are identical
-    Raises:
-        ValueError if list is empty or strings are different
+
+    Raises
+    ------
+    ValueError
+        If list is empty or strings are different
+
     """
     if not string_list:
         raise ValueError("The list is empty")
@@ -43,13 +54,20 @@ def check_identical_strings(string_list):


 def get_sweep_url(step_csv: pd.DataFrame, single=True):
-    """
-    Extract wandb sweep URL from a DataFrame containing run IDs
-    Args:
-        step_csv: DataFrame containing run IDs
-        single: If True, only process the first run
-    Returns:
-        The sweep URL
+    """Extract Weights & Biases sweep URL from a DataFrame containing run IDs.
+
+    Parameters
+    ----------
+    step_csv : pd.DataFrame
+        DataFrame containing wandb run IDs in an 'id' column
+    single : bool, optional
+        If True, only process the first run, by default True
+
+    Returns
+    -------
+    str
+        The wandb sweep URL
+
     """
     ids = step_csv["id"]
     sweep_urls = []
@@ -68,12 +86,19 @@ def get_sweep_url(step_csv: pd.DataFrame, single=True):


 def spilt_web(url: str):
-    """
-    Parse wandb URL to extract entity, project and sweep ID
-    Args:
-        url: wandb sweep URL
-    Returns:
-        Tuple of (entity, project, sweep_id) or None if parsing fails
+    """Parse Weights & Biases URL to extract entity, project and sweep components.
+
+    Parameters
+    ----------
+    url : str
+        Complete wandb sweep URL
+
+    Returns
+    -------
+    tuple or None
+        Tuple of (entity, project, sweep_id) if parsing succeeds
+        None if parsing fails
+
     """
     pattern = r"https://wandb\.ai/([^/]+)/([^/]+)/sweeps/([^/]+)"
@@ -94,13 +119,23 @@ def spilt_web(url: str):


 def get_best_method(urls, metric_col="test_acc"):
-    """
-    Find the best performing method across multiple sweeps
-    Args:
-        urls: List of sweep URLs to compare
-        metric_col: Metric column name to use for comparison
-    Returns:
-        Tuple of (best_step_name, best_run, best_metric_value)
+    """Find the best performing method across multiple wandb sweeps.
+
+    Parameters
+    ----------
+    urls : list
+        List of wandb sweep URLs to compare
+    metric_col : str, optional
+        Metric column name to use for comparison, by default "test_acc"
+
+    Returns
+    -------
+    tuple
+        (best_step_name, best_run, best_metric_value) where:
+        - best_step_name: name of the step with best performance
+        - best_run: wandb run object of best performing run
+        - best_metric_value: value of the metric for best run
+
     """
     all_best_run = None
     all_best_step_name = None
@@ -137,14 +172,22 @@ def get_metric(run):


 def get_best_yaml(step_name, best_run, file_path):
-    """
-    Generate YAML configuration for the best performing run
-    Args:
-        step_name: Name of the step ('step2' or 'step3_X')
-        best_run: Best wandb run object
-        file_path: Path to configuration files
-    Returns:
+    """Generate YAML configuration for the best performing wandb run.
+
+    Parameters
+    ----------
+    step_name : str
+        Name of the step ('step2' or 'step3_X')
+    best_run : wandb.Run
+        Best performing wandb run object
+    file_path : str
+        Path to configuration files
+
+    Returns
+    -------
+    str
         YAML string containing the best configuration
+
     """
     if step_name == "step2":
         conf = OmegaConf.load(f"{file_path}/pipeline_params_tuning_config.yaml")
@@ -182,12 +225,18 @@ def get_best_yaml(step_name, best_run, file_path):


 def check_exist(file_path):
-    """
-    Check if results directory exists and contains multiple files
-    Args:
-        file_path: Path to check
-    Returns:
-        Boolean indicating if valid results exist
+    """Check if results directory exists and contains multiple result files.
+
+    Parameters
+    ----------
+    file_path : str
+        Path to check for results
+
+    Returns
+    -------
+    bool
+        True if valid results exist (directory exists and contains >1 file)
+
     """
     file_path = f"{file_path}/results/params/"
     if os.path.exists(file_path) and os.path.isdir(file_path):
@@ -197,19 +246,16 @@ def check_exist(file_path):
     return False


-def write_ans(tissue):
-    """
-    Process results for a specific tissue type and write to CSV
-    Args:
-        tissue: Name of the tissue to process
-    Writes results to a CSV file named '{tissue}_ans.csv'
-    """
+def get_new_ans(tissue):
+    """Collect the best sweep results for all datasets of a tissue into a DataFrame."""
     ans = []
-    collect_datasets = all_datasets[tissue]
+    collect_datasets = [
+        collect_dataset.split(tissue)[1].split("_")[0]
+        for collect_dataset in all_datasets[all_datasets["tissue"] == tissue]["data_fname"].tolist()
+    ]
     for method_folder in tqdm(collect_datasets):
         for dataset_id in collect_datasets[method_folder]:
-            file_path = f"{file_root}/{method_folder}/{dataset_id}"
+            file_path = f"tuning/{method_folder}/{dataset_id}"
             if not check_exist(file_path):
                 continue
             step2_url = get_sweep_url(pd.read_csv(f"{file_path}/results/pipeline/best_test_acc.csv"))
@@ -232,20 +278,78 @@ def write_ans(tissue):
             })
     # with open('temp_ans.json', 'w') as f:
     #     json.dump(ans, f,indent=4)
-    pd.DataFrame(ans).to_csv(f"{tissue}_ans.csv")
+    new_df = pd.DataFrame(ans)
+    return new_df
+
+
+def write_ans(tissue, new_df):
+    """Process and write results for a specific tissue type to CSV.
+
+    Parameters
+    ----------
+    tissue : str
+        Name of the tissue to process
+    new_df : pd.DataFrame
+        Newly collected results to merge into the existing CSV
+
+    Notes
+    -----
+    Writes results to 'atlas/sweep_results/{tissue}_ans.csv' containing:
+    - Dataset IDs
+    - Sweep URLs for each step
+    - Best performing YAML configurations
+    - Best result metrics
+
+    """
+    # Check whether a results file already exists
+    output_file = f"atlas/sweep_results/{tissue}_ans.csv"
+    if os.path.exists(output_file):
+        existing_df = pd.read_csv(output_file, index_col=0)
+        # Index by Dataset_id to make merging easier
+        if 'Dataset_id' in existing_df.columns:
+            existing_df.set_index('Dataset_id', inplace=True)
+        new_df.set_index('Dataset_id', inplace=True)
+        # Find overlapping Dataset_ids
+        common_indices = existing_df.index.intersection(new_df.index)

+        # For each overlapping Dataset_id, check for conflicting results
+        for idx in common_indices:
+            for col in existing_df.columns.intersection(new_df.columns):
+                if not str(col).endswith("_best_res"):
+                    continue
+                existing_value = existing_df.loc[idx, col]
+                new_value = new_df.loc[idx, col]
+
+                # Raise if both values are non-NaN and differ
+                if (pd.notna(existing_value) and pd.notna(new_value)
+                        and (not isinstance(existing_value, float) or abs(existing_value - new_value) > 1e-10)):
+                    raise ValueError(f"Result conflict: Dataset {idx}, Column {col}\n"
+                                     f"Existing value: {existing_value}\nNew value: {new_value}")
+
+        # Merge the data
+        # 1. For overlapping indices, update non-NaN values from new_df
+        existing_df.update(new_df)
+
+        # 2. Append rows that exist only in new_df
+        new_indices = new_df.index.difference(existing_df.index)
+        if len(new_indices) > 0:
+            existing_df = pd.concat([existing_df, new_df.loc[new_indices]])
+
+        # Save the merged results
+        existing_df.to_csv(output_file)
+    else:
+        # If the file does not exist, write the new results directly
+        new_df.set_index('Dataset_id').to_csv(output_file)
+
+
+wandb = try_import("wandb")
+entity = "xzy11632"
+project = "dance-dev"


 if __name__ == "__main__":
     # Initialize wandb and set global configuration
-    wandb = try_import("wandb")
-    entity = "xzy11632"
-    project = "dance-dev"
-    # Load dataset configuration and process results for tissue
-    file_root = str(Path(__file__).resolve().parent)
-    with open(f"{file_root}/dataset_server.json") as f:
-        all_datasets = json.load(f)
-    file_root = "./tuning"
-    tissues = ["heart"]
-    for tissue in tissues:
-        write_ans(tissue)
+    all_datasets = pd.read_csv(METADIR / "scdeepsort.csv", header=0, skiprows=[i for i in range(1, 69)])
+    args = argparse.ArgumentParser()
+    args.add_argument("--tissue", type=str, default="heart")
+    args = args.parse_args()
+    new_df = get_new_ans(args.tissue)
+    write_ans(args.tissue, new_df)
diff --git a/tests/test_get_result_web.py b/tests/test_get_result_web.py
new file mode 100644
index 00000000..257c74d7
--- /dev/null
+++ b/tests/test_get_result_web.py
@@ -0,0 +1,138 @@
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from examples.get_result_web import check_exist, check_identical_strings, spilt_web
+
+
+# Tests for check_identical_strings
+def test_check_identical_strings():
+    # Identical strings
+    assert check_identical_strings(["test", "test", "test"]) == "test"
+
+    # Empty list
+    with pytest.raises(ValueError, match="The list is empty"):
+        check_identical_strings([])
+
+    # Different strings
+    with pytest.raises(ValueError, match="Different strings found"):
+        check_identical_strings(["test1", "test2"])
+
+
+# Tests for spilt_web
+def test_spilt_web():
+    # Valid URL
+    url = "https://wandb.ai/user123/project456/sweeps/abc789"
+    result = spilt_web(url)
+    assert result == ("user123", "project456", "abc789")
+
+    # Invalid URL
+    invalid_url = "https://invalid-url.com"
+    assert spilt_web(invalid_url) is None
+
+
+# Tests for check_exist
+def test_check_exist(tmp_path):
+    # Create a temporary results directory
+    results_dir = tmp_path / "results" / "params"
+    results_dir.mkdir(parents=True)
+
+    # Empty directory
+    assert check_exist(str(tmp_path)) is False
+
+    # Create test files
+    (results_dir / "file1.txt").touch()
+    (results_dir / "file2.txt").touch()
+
+    # Directory with multiple files
+    assert check_exist(str(tmp_path)) is True
+
+
+# Test fixtures
+@pytest.fixture
+def sample_df():
+    return pd.DataFrame({"id": ["run1", "run2", "run3"], "metric": [0.8, 0.9, 0.7]})
+
+
+# Use mock if the wandb API needs to be simulated
+@pytest.fixture
+def mock_wandb(mocker):
+    mock_api = mocker.patch("wandb.Api")
+    # Mock return values can be configured here
+    return mock_api
+
+
+def test_write_ans(tmp_path, monkeypatch):
+    # Set up the atlas/sweep_results directory
+    sweep_results_dir = tmp_path / "atlas" / "sweep_results"
+    sweep_results_dir.mkdir(parents=True)
+    # Run from tmp_path so write_ans resolves its relative output path there
+    monkeypatch.chdir(tmp_path)
+
+    # Create test data
+    existing_data = pd.DataFrame({
+        'Dataset_id': ['dataset1', 'dataset2', 'dataset3'],
+        'method1': ['url1', 'url2', 'url3'],
+        'method1_best_yaml': ['yaml1', 'yaml2', 'yaml3'],
+        'method1_best_res': [0.8, 0.9, 0.7]
+    })
+
+    new_data = pd.DataFrame({
+        'Dataset_id': ['dataset2', 'dataset3', 'dataset4'],  # partially overlaps the existing data
+        'method1': ['url2_new', 'url3_new', 'url4'],
+        'method1_best_yaml': ['yaml2_new', 'yaml3_new', 'yaml4'],
+        'method1_best_res': [0.9, 0.7, 0.85]  # results for dataset2 and dataset3 match the existing data
+    })
+
+    # Write the existing data
+    output_file = sweep_results_dir / "heart_ans.csv"
+    existing_data.to_csv(output_file)
+
+    # Write the new data
+    from examples.get_result_web import write_ans
+    write_ans("heart", new_data)
+
+    # Read the merged results
+    merged_df = pd.read_csv(output_file, index_col=0)
+
+    # Verify the results
+    assert len(merged_df) == 4  # four unique Dataset_ids expected
+    assert 'dataset4' in merged_df.index  # the new row was added
+    assert merged_df.loc['dataset2', 'method1'] == 'url2_new'  # the existing row was updated
+
+    # Conflicting results
+    conflicting_data = pd.DataFrame({
+        'Dataset_id': ['dataset1'],
+        'method1': ['url1_new'],
+        'method1_best_yaml': ['yaml1_new'],
+        'method1_best_res': [0.95]  # different result value
+    })
+
+    # Conflicting data should raise an exception
+    with pytest.raises(ValueError, match="Result conflict"):
+        write_ans("heart", conflicting_data)
+
+
+# Test writing completely new data (file does not exist yet)
+def test_write_ans_new_file(tmp_path, monkeypatch):
+    # Set up the atlas/sweep_results directory
+    sweep_results_dir = tmp_path / "atlas" / "sweep_results"
+    sweep_results_dir.mkdir(parents=True)
+    # Run from tmp_path so write_ans resolves its relative output path there
+    monkeypatch.chdir(tmp_path)
+
+    new_data = pd.DataFrame({
+        'Dataset_id': ['dataset1', 'dataset2'],
+        'method1': ['url1', 'url2'],
+        'method1_best_yaml': ['yaml1', 'yaml2'],
+        'method1_best_res': [0.8, 0.9]
+    })
+
+    # Write to a new file
+    from examples.get_result_web import write_ans
+    write_ans("heart", new_data)
+
+    # Verify the file was created with the correct data
+    output_file = sweep_results_dir / "heart_ans.csv"
+    assert output_file.exists()
+
+    written_df = pd.read_csv(output_file, index_col=0)
+    assert len(written_df) == 2
+    assert all(written_df.index == ['dataset1', 'dataset2'])
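
A minimal usage sketch for the updated script and tests (assuming wandb credentials are configured and the commands are run from the repository root; the `--tissue` flag defaults to "heart"):

    # Collect sweep results for a tissue and merge them into atlas/sweep_results/<tissue>_ans.csv
    python examples/get_result_web.py --tissue heart

    # Run the new unit tests
    python -m pytest tests/test_get_result_web.py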