forked from techthiyanes/lm-eval2
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_eval.py
executable file
·90 lines (71 loc) · 2.44 KB
/
main_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import json
import logging
import fnmatch
import wandb
from pathlib import Path
from typing import Union
import yaml
from pydantic import BaseModel
from lm_eval import tasks, evaluator
# Raise the threshold of the logger named "openai" to WARNING so records below
# that severity are dropped (presumably this is the OpenAI client library's
# logger -- confirm against the installed package).
logging.getLogger("openai").setLevel(logging.WARNING)
def load_config(path: Union[str, Path]):
    """Read the YAML file at *path* and return its parsed content.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document. For a
        configuration file that is expected to be a mapping of option names
        to values; an empty document yields ``None``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML. The error is still
            printed, but it is now re-raised instead of being swallowed and
            turned into an implicit ``None`` return that only fails later,
            far from the real cause.
    """
    # Decode as UTF-8 explicitly rather than relying on the platform's default
    # text encoding, which differs between systems.
    with open(path, "r", encoding="utf-8") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise
class EvalPipelineConfig(BaseModel):
    """Validated settings for one evaluation run, built from the YAML config.

    Only ``model`` is required. Every field whose default is ``None`` is
    declared as ``Union[<type>, None]`` so the annotation agrees with its
    default and an explicit ``null`` in the config file validates as well.
    """

    # Forwarded unchanged to ``evaluator.simple_evaluate``.
    model: str
    model_args: str = ""
    # Comma-separated task-name patterns (fnmatch-style wildcards).
    # ``None`` means "use every task in ``tasks.ALL_TASKS``".
    tasks: Union[str, None] = None
    num_fewshot: int = 0
    batch_size: Union[int, None] = None
    device: Union[str, None] = None
    limit: Union[int, None] = None
    decontamination_ngrams_path: Union[str, None] = None
    check_integrity: bool = False
    # Weights & Biases reporting. When ``wandb_log`` is true, both
    # ``wandb_project`` and ``wandb_run_name`` must be provided.
    wandb_log: bool = False
    wandb_project: Union[str, None] = None
    wandb_run_name: Union[str, None] = None
def pattern_match(patterns, source_list):
    """Return the entries of *source_list* matched by at least one pattern.

    Args:
        patterns: Iterable of shell-style wildcard patterns as understood by
            the ``fnmatch`` module (``*``, ``?``, ``[seq]``).
        source_list: Iterable of candidate names to select from.

    Returns:
        A list of the matching names, without duplicates, in sorted order.
    """
    # Materialise once so a one-shot iterable is not exhausted after the
    # first pattern has been applied.
    candidates = list(source_list)

    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(candidates, pattern))

    # A plain list(set) comes back in hash order, which for strings changes
    # from one interpreter run to the next; sorting keeps the selection
    # reproducible.
    return sorted(matched)
def main(config_path: str) -> None:
    """Run the evaluation described by a YAML config file and print the results.

    The config is loaded and validated into an ``EvalPipelineConfig``, the
    requested tasks are resolved against ``tasks.ALL_TASKS``, the evaluation
    is delegated to ``evaluator.simple_evaluate`` and the returned results are
    printed as indented JSON. When ``wandb_log`` is set, the per-task metrics
    are also logged to Weights & Biases.

    Args:
        config_path: Path to the YAML configuration file.

    Raises:
        ValueError: If the config file does not contain a mapping, or if
            ``wandb_log`` is enabled without both ``wandb_project`` and
            ``wandb_run_name``.
    """
    raw_config = load_config(config_path)
    if not isinstance(raw_config, dict):
        # An empty (or unparsable) file yields None; fail with a clear message
        # instead of an opaque TypeError from the ``**`` expansion below.
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(raw_config).__name__}"
        )
    args = EvalPipelineConfig(**raw_config)

    if args.wandb_log:
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would let a bad config through.
        if args.wandb_project is None or args.wandb_run_name is None:
            raise ValueError(
                "wandb_log is enabled, so both wandb_project and "
                "wandb_run_name must be set"
            )
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args)

    if args.tasks is None:
        task_names = tasks.ALL_TASKS
    else:
        # Strip whitespace so "task_a, task_b" selects both tasks, and drop
        # empty entries left by a stray or trailing comma.
        patterns = [part.strip() for part in args.tasks.split(",") if part.strip()]
        task_names = pattern_match(patterns, tasks.ALL_TASKS)

    print(f"Selected Tasks: {task_names}")

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        device=args.device,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
    )

    dumped = json.dumps(results, indent=2)
    print(dumped)

    if args.wandb_log:
        # Only the first whitespace-separated token of each result key is used
        # as the W&B metric name.
        # TODO: where is "filter" coming from?
        for task, metrics in results["results"].items():
            wandb.log({task.split()[0]: metrics})
if __name__ == "__main__":
    # Script entry point: a single positional argument names the config file.
    cli = argparse.ArgumentParser()
    cli.add_argument("config_path", help="The full path to the YAML config file.")
    main(cli.parse_args().config_path)