-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
34 lines (29 loc) · 952 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
from config import *
from shared_storage import SharedStorage
from self_play import run_selfplay, run_eval
from CartPole import make_cartpole_config
from LunarLander import make_lunarlander_config
from FrozenLake import make_frozenlake_config
import torch
class MuZero:
    """Evaluate a previously trained MuZero model for a single task configuration."""

    def __init__(self, config):
        """Keep the task configuration for later use by :meth:`test`.

        Args:
            config: task configuration object. ``test`` calls its
                ``new_network()`` and reads its ``result_path``.
        """
        self.config = config

    def test(self):
        """Restore saved weights into a fresh network, evaluate it, and print the score."""
        cfg = self.config
        network = cfg.new_network()
        # NOTE(review): torch.load unpickles the file at result_path, so only
        # load checkpoints from a trusted source.
        saved_weights = torch.load(cfg.result_path)
        network.set_weights(saved_weights)
        # The same network object is passed for both SharedStorage arguments;
        # presumably these are separate network slots that coincide at eval time — confirm.
        shared = SharedStorage(network, network)
        result = run_eval(cfg, shared)
        print(f"Test score: {result}")
if __name__ == '__main__':
    # Menu key typed by the user -> factory that builds that task's config.
    task = {
        "0": make_cartpole_config,
        "1": make_lunarlander_config,
        "2": make_frozenlake_config,
    }
    try:
        # strip() tolerates stray surrounding whitespace such as "1 ".
        choice = input("choose a task (0: CartPole, 1: LunarLander, 2: Frozenlake): ").strip()
        make_config = task[choice]
    except (KeyError, EOFError):
        # Unknown menu key, or stdin closed before a choice was entered.
        print("invalid input")
    else:
        # Kept outside the try on purpose: failures while building the config,
        # loading the checkpoint, or evaluating must surface with a traceback
        # instead of being misreported as "invalid input" (as the old bare
        # `except:` did, even for Ctrl-C).
        alg = MuZero(make_config())
        alg.test()