-
Notifications
You must be signed in to change notification settings - Fork 25
/
run.py
95 lines (75 loc) · 2.82 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Run a task defined in tasks.
Example usage:
python run.py reverse_config
python run.py dyck_config
python run.py dyck_config --model BufferedModel
"""
import argparse
from copy import copy
from tasks import Task
from models import *
from controllers import *
from structs import *
from configs import *
from visualization import *
def get_args():
    """Parse the command line for this script.

    One positional argument, ``config``, names the configuration to run.
    Every other option is an optional string override that defaults to None
    when not supplied on the command line.

    Returns:
        argparse.Namespace with attributes ``config``, ``model``,
        ``controller``, ``struct``, ``visualizer``, ``loadpath`` and
        ``savepath``.
    """
    cli = argparse.ArgumentParser(
        description="Run a task and customize hyperparameters.")
    cli.add_argument("config", type=str)
    # Component overrides (--model/--controller/--struct/--visualizer) take
    # precedence over the config; --loadpath/--savepath are paths for loading
    # and saving models. All are registered in this fixed order.
    for flag in ("--model", "--controller", "--struct", "--visualizer",
                 "--loadpath", "--savepath"):
        cli.add_argument(flag, type=str, default=None)
    return cli.parse_args()
def get_object_from_arg(arg, superclass, default=None):
    """Resolve a command-line name to an object in this module's namespace.

    The name is looked up in this module's ``globals()``, which includes
    every name brought in by the star imports at the top of the file. The
    object is accepted if it is an instance of ``superclass`` (e.g. a config
    dict) or a class that subclasses ``superclass`` (e.g. a model type).

    Args:
        arg: Name of a module-level object, or None.
        superclass: Type the object must be an instance or subclass of.
        default: Value returned unchanged when ``arg`` is None.

    Returns:
        The resolved object, or ``default`` when ``arg`` is None.

    Raises:
        ValueError: If no global named ``arg`` exists.
        TypeError: If the object is neither an instance of ``superclass``
            nor a class derived from it.
    """
    if arg is None:
        return default
    namespace = globals()
    if arg not in namespace:
        raise ValueError("Invalid argument {}".format(arg))
    obj = namespace[arg]
    # issubclass() raises its own TypeError for non-class objects
    # (functions, modules, ...), so only ask it once obj is known to be a
    # class; otherwise the caller would see "issubclass() arg 1 must be a
    # class" instead of the intended message below.
    is_instance = isinstance(obj, superclass)
    is_subclass = isinstance(obj, type) and issubclass(obj, superclass)
    if not (is_instance or is_subclass):
        raise TypeError("{} is not a {}".format(arg, str(superclass)))
    return obj
def main(config,
         model_type=None,
         controller_type=None,
         struct_type=None,
         visualizer_type=None,
         load_path=None,
         save_path=None):
    """Build a Task from ``config`` plus any overrides, run it, return metrics.

    ``config`` is shallow-copied, so the caller's mapping is never modified.
    Each override that is not None replaces the same-named entry
    (``model_type``, ``controller_type``, ``struct_type``, ``load_path``,
    ``save_path``) in the copy before it is handed to
    ``Task.from_config_dict``.

    Args:
        config: Mapping of task settings; must support ``copy()`` and item
            assignment (presumably one of the config dicts — verify).
        model_type, controller_type, struct_type: Optional replacements for
            the corresponding config entries.
        visualizer_type: Optional callable taking the task; when given, its
            result's ``visualize_generic_example()`` is called after the run.
        load_path, save_path: Optional replacements for the config's
            ``load_path`` / ``save_path`` entries.

    Returns:
        Whatever ``task.run_experiment()`` returns.
    """
    # Applied in this fixed order; None means "keep the config's value".
    overrides = (
        ("model_type", model_type),
        ("controller_type", controller_type),
        ("struct_type", struct_type),
        ("load_path", load_path),
        ("save_path", save_path),
    )
    merged = copy(config)
    for key, value in overrides:
        if value is not None:
            merged[key] = value

    task = Task.from_config_dict(merged)
    metrics = task.run_experiment()

    if visualizer_type is not None:
        visualizer_type(task).visualize_generic_example()
    return metrics
if __name__ == "__main__":
    # Command-line entry point: parse arguments, resolve each named object,
    # then hand everything to main().
    # NOTE(review): these assignments run at module level, so `args`,
    # `config`, `model_type`, ... become module globals as they are bound.
    # get_object_from_arg resolves names via globals(), so later lookups can
    # see them. Keep the names and order as-is unless that lookup changes.
    args = get_args()
    print("Loading {}".format(args.config))
    # Each option left unset on the command line is None and resolves to
    # None here, which main() treats as "not given". Model,
    # SimpleStructController, Struct and Visualizer are not defined in this
    # file; presumably they come from the star imports above.
    config = get_object_from_arg(args.config, dict)
    model_type = get_object_from_arg(args.model, Model)
    controller_type = get_object_from_arg(args.controller, SimpleStructController)
    struct_type = get_object_from_arg(args.struct, Struct)
    visualizer_type = get_object_from_arg(args.visualizer, Visualizer)
    main(config, model_type, controller_type, struct_type, visualizer_type,
         args.loadpath, args.savepath)