# Snakefile (forked from OpenBioLink/SimulateGPT)
from pathlib import Path
import itertools
import os
# Load the workflow settings; the code below reads the "experiment_name" and
# "system_names" keys from it.
configfile: "config.yaml"

experiment_name = config["experiment_name"]

# One human prompt per regular file in the experiment's prompt directory,
# identified by its file stem. The names are sorted because Path.iterdir()
# yields entries in arbitrary order, which would otherwise let the job list
# and the row order of the aggregated CSV differ between runs or machines.
# NOTE(review): .stem drops a trailing suffix, while rule simulate looks up
# "prompts/{human}" verbatim -- presumably prompt files carry no extension;
# confirm.
human_names = sorted(
    name.stem
    for name in Path(f"experiments/{experiment_name}/prompts").iterdir()
    if name.is_file()
)

# Every (system message, human prompt) pair handled by the workflow.
combinations = list(itertools.product(config["system_names"], human_names))
# First rule in the file: requests the aggregated reference table of the
# configured experiment.
rule all:
    input:
        f"reference_analysis/{experiment_name}_all.csv"
# Runs src/rule_simulate.py once per (system message, human prompt) pair and
# stores the result under the experiment's ai_messages directory.
rule simulate:
    input:
        # Both prompt files are wrapped in ancient().
        system_message=ancient("system_messages/{sys}"),
        human_message=ancient("experiments/{experiment_name}/prompts/{human}")
    output:
        # The generated message is marked protected().
        protected("experiments/{experiment_name}/ai_messages/{sys}--{human}")
    retries: 3
    conda: "env.yml"
    script:
        "src/rule_simulate.py"
# Runs src/check_references.py on a single AI message and writes one
# cross-reference CSV per (experiment, system, human) combination.
rule analyze_references:
    conda: "env.yml"
    input:
        # The AI message is wrapped in ancient().
        ancient("experiments/{experiment_name}/ai_messages/{sys}--{human}")
    output:
        "reference_analysis/cross_ref/{experiment_name}--{sys}--{human}.csv"
    script:
        "src/check_references.py"
# Runs src/match_references.py on the cross-reference CSV produced by
# analyze_references and writes the matched CSV for the same combination.
rule match_references:
    conda: "env.yml"
    input:
        # Reuses the output pattern of analyze_references as this rule's input.
        rules.analyze_references.output
    output:
        "reference_analysis/matched/{experiment_name}--{sys}--{human}.csv"
    script:
        "src/match_references.py"
# Concatenates the matched-reference CSVs of every (system, human) combination
# into one table for the configured experiment.
rule aggregate_matched_references:
    input:
        [
            "reference_analysis/matched/{}--{}--{}.csv".format(
                experiment_name, system_name, human_name
            )
            for system_name, human_name in combinations
        ]
    output:
        f"reference_analysis/{experiment_name}_all.csv"
    run:
        import pandas as pd

        # Read the per-combination tables in input order, then stack them.
        frames = []
        for csv_path in input:
            frames.append(pd.read_csv(csv_path))
        merged = pd.concat(frames)
        # The index is not written to the output file.
        merged.to_csv(output[0], index=False)