-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: random_test.py
87 lines (76 loc) · 2.6 KB
/
random_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from utils.mouselab_flat import MouselabEnv
from distributions import Normal, Categorical
import random
import math
import time
import pandas as pd
from itertools import compress
import numpy as np
import argparse
if __name__ == '__main__':
    # Evaluate a uniform-random baseline policy on the Mouselab environment
    # for a fixed number of episodes, then save per-episode results (CSV) and
    # the mean return / wall-clock time (.npy) under <data dir>/Random_Results/.
    parser = argparse.ArgumentParser()
    # Number of goal options; accepted as a string and converted below so the
    # CLI contract matches the original script.
    parser.add_argument('no_goals', type=str)
    args = parser.parse_args()
    NO_OPTION = int(args.no_goals)

    # Experiment data is expected in ./<goals>_<goals*18>/ relative to the
    # working directory: the tree structure, per-node type labels, and the
    # option set.
    cwd = os.getcwd()
    cwd += '/' + str(NO_OPTION) + '_' + str(NO_OPTION * 18)
    TREE = list(np.load(cwd + '/tree.npy'))        # was a manual append loop
    DIST = np.load(cwd + '/dist.npy')
    OPTION_SET = np.load(cwd + '/option_set.npy')

    # Environment hyper-parameters (each was assigned twice with identical
    # values in the original; the duplicates are removed).
    BRANCH_COST = 1
    SWITCH_COST = 1
    SEED = 0
    TAU = 20
    NO_BINS = 4
    # NOTE(review): this overrides the command-line value parsed above, so the
    # environment is always built with 2 options even though the data
    # directory path was derived from args.no_goals — confirm this is
    # intentional and not leftover debug code.
    NO_OPTION = 2

    # Per-node type labels ('V1'..'V4', 'G1'..'G5') indexed by node id.
    node_types = list(DIST)

    def reward(i):
        """Return the prior reward distribution for node ``i``: a zero-mean
        Normal whose sigma is looked up from the node's type label."""
        sigma_val = {'V1': 5, 'V2': 10, 'V3': 20, 'V4': 40,
                     'G1': 100, 'G2': 120, 'G3': 140, 'G4': 160, 'G5': 180}
        return Normal(mu=0, sigma=sigma_val[node_types[i]])

    # Create the output directory if needed. exist_ok replaces the original
    # bare `except: pass`, which would also have hidden permission errors.
    os.makedirs(cwd + '/Random_Results', exist_ok=True)

    num_episodes = 100
    cum_reward = 0
    tic = time.time()
    df = pd.DataFrame(columns=['i', 'return', 'actions', 'Actual Path',
                               'Time', 'ground_truth'])
    for i in range(num_episodes):
        print("i = {}".format(i))
        # Fresh environment per episode with a distinct, reproducible seed.
        env = MouselabEnv.new(NO_OPTION, TREE, reward=reward,
                              branch_cost=BRANCH_COST,
                              switch_cost=SWITCH_COST, tau=TAU,
                              seed=SEED + i)
        env_tic = time.time()
        exp_reward = 0
        actions = []
        while True:
            # Uniform-random baseline: pick any currently legal action.
            action = random.choice(list(env.actions(env._state)))
            actions.append(action)
            _, rew, done, _ = env._step_actual(action)
            exp_reward += rew
            if done:
                break
        env_toc = time.time()
        df.loc[i] = [i, exp_reward, actions, env.actual_path(env._state),
                     env_toc - env_tic, env.ground_truth]
        cum_reward += exp_reward

    df.to_csv(cwd + '/Random_Results/random_' + str(NO_BINS) + '.csv')
    np.save(cwd + '/Random_Results/CumResult_' + str(NO_BINS),
            cum_reward / num_episodes)
    print(cum_reward / num_episodes)
    toc = time.time()
    np.save(cwd + '/Random_Results/Eval_Time_' + str(NO_BINS), toc - tic)