sim.py
# show off all combinations of environments, sensors and encoders
import argparse
import numpy as np
import gymnasium as gym
from sb3_contrib.tqc.policies import MultiInputPolicy
from robosuite_envs import *
from pointcloud_vision import *

# registries mapping CLI names to sensor and encoder classes
sensors = {
    'default': None,
    'passthru': PassthroughSensor,
    'PC': PointCloudSensor,
}
encoders = {
    'default': None,
    'passthru': PassthroughEncoder,
    'AE': GlobalAEEncoder,
    'Seg': GlobalSegmenterEncoder,
    'MultiSeg': MultiSegmenterEncoder,
    'StatePred': StatePredictor,
    'StatePredVisGoal': StatePredictorVisualGoal,
}

# parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('env', type=str, help='environment ID')
parser.add_argument('--horizon', type=int, default=100, help='episode horizon (max steps per episode)')
parser.add_argument('--sensor', default='default', choices=list(sensors.keys()), help='sensor')
parser.add_argument('--encoder', default='default', choices=list(encoders.keys()), help='observation encoder')
parser.add_argument('--passive_encoder', default='', choices=list(encoders.keys()), help='passive encoder just for goal checking and visualization')
parser.add_argument('--policy', default='', type=str, help='path to policy file')
parser.add_argument('--benchmark', default=None, type=int, help='number of episodes to run for benchmarking')
a = parser.parse_args()
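# example invocation (hypothetical environment ID and policy path, for illustration only):
#   python sim.py RobosuiteReach-v0 --sensor PC --encoder AE --policy models/reach_policy.zip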

# load environment
kwargs = {'sensor': sensors[a.sensor], 'encoder': encoders[a.encoder]}
# vision-based encoders require point cloud observations, so force the point cloud sensor
if (kwargs['encoder'] and kwargs['encoder'].requires_vision) or (a.passive_encoder and encoders[a.passive_encoder].requires_vision):
    kwargs['sensor'] = PointCloudSensor
# drop None entries so the environment falls back to its default sensor/encoder
env = gym.make(a.env, render_mode='human', max_episode_steps=a.horizon, **{k: v for k, v in kwargs.items() if v})

# create passive encoder, used only for goal checking and visualization
if a.passive_encoder and encoders[a.passive_encoder]:
    env.reset()  # to get first obs
    pe = encoders[a.passive_encoder](env, env.obs_keys, env.goal_keys)
    if type(pe) is StatePredictor:
        pe.passthrough_goal = False
    pe_goal = pe.encode_goal(env.goal_obs)

    def show_success(h, w):
        # swap out the encoders temporarily to evaluate success with the passive encoder
        env.unwrapped.encoder, orig = pe, env.encoder
        pe_achieved = pe.encode_goal(env.observation)
        pe_succ = env.check_success(pe_achieved, pe_goal, info=None)
        env.unwrapped.encoder = orig  # restore original encoder
        # draw a green (success) or red (failure) bar along the bottom edge of the frame
        overlay = np.zeros((h, w, 3))
        overlay[h-2:h, :, :] = [0, 1, 0] if pe_succ else [1, 0, 0]
        return overlay

    env.unwrapped.overlay = show_success
else:
    pe = None

# load policy (random actions are used if no policy file is given)
if a.policy:
    agent = MultiInputPolicy.load(a.policy)
else:
    agent = None
    agent_input_dim = env.observation_space['observation'].shape[0] + env.observation_space['desired_goal'].shape[0]
    agent_output_dim = env.action_space.shape[0]
    assert all(-env.action_space.low == env.action_space.high), 'action space must be symmetric'
    agent_action_limit = env.action_space.high
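# note: the policy file is assumed to be a policy saved with Stable-Baselines3,
# e.g. model.policy.save('policy.zip') after training a TQC model (hypothetical path)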

# simulation
ep_rewards = []
ep_is_success = []

def print_stats():
    # print benchmark statistics over all completed episodes
    print('episodes =', len(ep_rewards))
    print(f'mean reward = {np.mean(ep_rewards)}')
    print(f'median reward = {np.median(ep_rewards)}')
    print(f'success rate = {np.mean(ep_is_success)}')

run = True
while run:
    obs, info = env.reset()
    if pe:
        pe_goal = pe.encode_goal(env.goal_obs)
    total_reward = 0
    for t in range(a.horizon):
        # select action
        if agent:
            action, _states = agent.predict(obs, deterministic=True)
        else:
            # random actions, clipped to the action space limits
            action = np.clip(np.random.randn(agent_output_dim), -agent_action_limit, agent_action_limit)
        # take action in the environment
        obs, reward, terminated, truncated, info = env.step(action)
        # update results
        total_reward += reward
        # viewer keyboard shortcuts
        if env.viewer.is_pressed('g'):  # show goal state
            env.show_frame(env.goal_state, None)
        if env.viewer.is_pressed('v'):  # save visual goal
            # pickle the current raw robosuite observation
            import pickle
            with open(f'pointcloud_vision/input/{env.scene}/{a.env}_visual_goal.pkl', 'wb') as f:
                pickle.dump(env.raw_state, f)
            print('saved visual goal state')
        if env.viewer.is_pressed('b'):  # print mean reward and success rate so far
            print_stats()
        if terminated or truncated:
            break
    ep_rewards.append(total_reward)
    ep_is_success.append(info['is_success'])
    if a.benchmark and len(ep_rewards) >= a.benchmark:
        print_stats()
        run = False
    if not a.benchmark:
        print(f"\ntotal_reward = {total_reward}\nis_success = {info['is_success']}")