-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_v_network.py
33 lines (30 loc) · 1 KB
/
test_v_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from MCTS.sim import *
from RL.SupervisedPolicy import *
from RL.SupervisedValueNetwork import *
from MCTS.game import *
import numpy as np
import scipy.io as sio
import os
from tester import test_policy_vs_MCTS, test_policy_scenarios
import tensorflow as tf
from keras import backend as K
from amcts import *
import MCTS.mcts as mcts
def main(*, v_network_weight=0.1, verbose=True):
    """Evaluate an AMCTS player guided by the supervised value network.

    Builds a ``SupervisedValueNetworkAgent``, loads its saved training results,
    wraps it in an ``AMCTSPlayer`` and runs ``test_policy_vs_MCTS`` on that
    player.

    Args:
        v_network_weight: Forwarded to ``AMCTSPlayer`` as ``v_network_weight``.
            Defaults to 0.1, the value this script previously hard-coded.
        verbose: Forwarded to ``test_policy_vs_MCTS``.

    Returns:
        The ``(score, episode)`` pair produced by ``test_policy_vs_MCTS``.
        The result was previously discarded. Returning it lets callers
        inspect it, and it does not affect the ``__main__`` entry point.
    """
    # NOTE(review): (144, 144, 3) is presumably the input shape the value
    # network was trained with. Confirm against the training script.
    rl_agent = SupervisedValueNetworkAgent((144, 144, 3))
    rl_agent.load_train_results()

    # NOTE(review): the meaning of the positional 0.2 depends on
    # AMCTSPlayer's signature (defined in `amcts`). Confirm before changing it.
    amcts_v_1 = AMCTSPlayer(
        'AMCTS_v_p1s',
        0.2,
        value_agent=rl_agent,
        v_network_weight=v_network_weight,
    )
    score, episode = test_policy_vs_MCTS(amcts_v_1, verbose=verbose)
    return score, episode
# Run the evaluation only when executed as a script, not when imported.
if '__main__' == __name__:
    main()