-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkgrag4sm_main.py
122 lines (100 loc) · 4.5 KB
/
kgrag4sm_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pandas as pd
import sys
import logging
import argparse
from datetime import datetime
from tqdm import tqdm
from src.utils import get_devices, setup_kgrag_logging, extract_label, calculate_metrics
from src.llm import initialize_llm_model
from src.kgrag4sm import KGRAG_for_Schema_Matching
def get_data_file(dataset: str) -> str:
    """Return the reproduction Excel file path for a known dataset name.

    Known datasets are cms, mimic, synthea, and emed; any other name
    yields None (mirrors dict.get's missing-key behavior).
    """
    known_datasets = ("cms", "mimic", "synthea", "emed")
    if dataset in known_datasets:
        # All reproduction files follow the same naming scheme.
        return f"datasets/reproduce/test_{dataset}_q_with_paths.xlsx"
    return None
def get_paths_column(retrieved_paths: str) -> int:
    """Map a retrieved-paths strategy name to its 0-based Excel column index.

    The strategies occupy consecutive columns starting at index 10; an
    unrecognized name returns None.
    """
    ordered_strategies = [
        "llm_entity_retrieval_bfs_paths",    # column 10
        "llm_paths",                         # column 11
        "vector_entity_retrieval_bfs_paths", # column 12
        "ver_bfs_top1",                      # column 13
        "ver_bfs_top2",                      # column 14
        "vector_kg_triples_retrieval_paths", # column 15
        "v_kg_tr_top1",                      # column 16
        "v_kg_tr_top2",                      # column 17
    ]
    try:
        return ordered_strategies.index(retrieved_paths) + 10
    except ValueError:
        return None
def main():
    """Run KG-RAG schema matching over a dataset and log evaluation metrics.

    Parses CLI arguments, loads the chosen Excel dataset, queries the
    backbone LLM for each question (augmented with retrieved KG paths),
    extracts predicted labels, and logs precision/recall/F1/accuracy plus
    timing information. Exits with status 1 on initialization or load errors.
    """
    parser = argparse.ArgumentParser()
    # FIX: help previously read "mimic, mimic, synthea, emed" — duplicated
    # mimic and omitted the actual default dataset, cms.
    parser.add_argument('--dataset', type=str, default='cms', help='dataset: cms, mimic, synthea, emed')
    parser.add_argument('--backbone_llm_model', type=str, default='jellyfish-8b',
                        help='Model name (e.g., gpt-4o-mini, jellyfish-8b, jellyfish-7b, mistral-7b)')
    parser.add_argument('--log_dir', type=str, default='logs/', help='Directory for storing logs')
    parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
    parser.add_argument('--retrieved_paths', type=str, default="llm_entity_retrieval_bfs_paths",
                        help='llm_paths, llm_entity_retrieval_bfs_paths, vector_entity_retrieval_bfs_paths, ver_bfs_top1, ver_bfs_top2, vector_kg_triples_retrieval_paths, v_kg_tr_top1, v_kg_tr_top2')
    args = parser.parse_args()
    start_time = datetime.now()
    devices = get_devices(args.device)
    # Setup logging
    log_filename = setup_kgrag_logging(args.log_dir, args.dataset, args.backbone_llm_model, args.retrieved_paths)
    # Initialize llm model; abort on failure since nothing else can proceed.
    try:
        model = initialize_llm_model(args.backbone_llm_model, devices)
    except Exception as e:
        logging.error(f"Error initializing llm model: {e}")
        sys.exit(1)
    # Load data from the dataset-specific Excel file.
    try:
        data_file = get_data_file(args.dataset)
        if not data_file:
            raise ValueError(f"Invalid dataset: {args.dataset}")
        reader = pd.read_excel(data_file)
    except Exception as e:
        logging.error(f"Error reading data file: {e}")
        sys.exit(1)
    # Resolve which Excel column holds the retrieved KG paths.
    paths_column = get_paths_column(args.retrieved_paths)
    if paths_column is None:
        logging.error(f"Invalid retrieved paths: {args.retrieved_paths}")
        sys.exit(1)
    kgrag4sm = KGRAG_for_Schema_Matching()
    y_true = []
    y_pred = []
    for i in tqdm(range(reader.shape[0]), desc="Processing questions"):
        try:
            # Fixed column layout: 9 = question text, 4 = ground-truth label;
            # paths_column varies with the retrieval strategy.
            question = reader.iloc[i, 9]
            paths = reader.iloc[i, paths_column]
            ground_truth = reader.iloc[i, 4]
            if pd.isna(paths):
                paths = None  # query without KG augmentation when no paths exist
            _, _, response = kgrag4sm.kgrag_query_for_schema_matching(question, paths, model)
            label = extract_label(response)
            # Only score rows where both a label was extracted (-1 means
            # extraction failed) and a ground truth is present.
            if label != -1 and not pd.isna(ground_truth):
                y_true.append(int(ground_truth))
                y_pred.append(label)
            logging.info(f"\nResults for row {i+1}:")
            logging.info(f"Response: {response}")
            logging.info(f"Extracted Label: {label}")
            logging.info(f"Ground Truth: {ground_truth}")
        except Exception as e:
            # Best-effort processing: a bad row should not abort the whole run.
            logging.error(f"Error processing row {i+1}: {e}")
            continue
    # Calculate and log metrics
    precision, recall, f1, accuracy = calculate_metrics(y_true, y_pred)
    logging.info("\nFinal Metrics:")
    logging.info(f"Precision: {precision:.4f}")
    logging.info(f"Recall: {recall:.4f}")
    logging.info(f"F1 Score: {f1:.4f}")
    logging.info(f"Accuracy: {accuracy:.4f}")
    end_time = datetime.now()
    duration = end_time - start_time
    logging.info("\nExecution Time Summary:")
    logging.info(f"Start Time: {start_time}")
    logging.info(f"End Time: {end_time}")
    logging.info(f"Total Duration: {duration}")
    # FIX: typo "Resutls" and missing space after the colon in the original.
    print(f"Results are recorded in the following log file: {log_filename}")
# Script entry point: run the KG-RAG schema-matching pipeline only when
# executed directly, not when imported as a module.
if __name__ == "__main__":
    main()