# env_model.py
import torch
import torch.nn as nn
from gym.envs.classic_control import PendulumEnv

from envs import *


def getEnvModel(env, obs_space):
    env = env.unwrapped

    if isinstance(env, LetterEnv):
        return LetterEnvModel(obs_space)
    if isinstance(env, MinigridEnv):
        return MinigridEnvModel(obs_space)
    if isinstance(env, ZonesEnv):
        return ZonesEnvModel(obs_space)
    if isinstance(env, PendulumEnv):
        return PendulumEnvModel(obs_space)

    # Add your EnvModel here...

    # The default case (no environment observations) - SimpleLTLEnv uses this
    return EnvModel(obs_space)
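
# Usage sketch (illustrative; the actual call site lives in the agent/model
# setup code, not in this module). `env` is assumed to be a gym environment
# and `obs_space` a dict-like mapping such as {"image": (H, W, C)} built by
# the project's preprocessing code:
#
#   env_model = getEnvModel(env, obs_space)
#   embedding_dim = env_model.size()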
"""
This class is in charge of embedding the environment part of the observations.
Every environment has its own set of observations ('image', 'direction', etc) which is handeled
here by associated EnvModel subclass.
How to subclass this:
1. Call the super().__init__() from your init
2. In your __init__ after building the compute graph set the self.embedding_size appropriately
3. In your forward() method call the super().forward as the default case.
4. Add the if statement in the getEnvModel() method
"""
class EnvModel(nn.Module):
    def __init__(self, obs_space):
        super().__init__()
        self.embedding_size = 0

    def forward(self, obs):
        return None

    def size(self):
        return self.embedding_size
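
# A minimal sketch of the four subclassing steps from the docstring above.
# `MyEnvModel`, `MyEnv`, and the "features" observation key are hypothetical
# names for illustration only; they are not part of this project.
#
# class MyEnvModel(EnvModel):
#     def __init__(self, obs_space):
#         super().__init__(obs_space)                      # step 1
#         if "features" in obs_space.keys():
#             n = obs_space["features"][0]
#             self.net_ = nn.Sequential(nn.Linear(n, 32), nn.ReLU())
#             self.embedding_size = 32                     # step 2
#
#     def forward(self, obs):
#         if "features" in obs.keys():
#             return self.net_(obs.features)
#         return super().forward(obs)                      # step 3
#
# # Step 4: add `if isinstance(env, MyEnv): return MyEnvModel(obs_space)`
# # to getEnvModel() above.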

class LetterEnvModel(EnvModel):
    def __init__(self, obs_space):
        super().__init__(obs_space)
        if "image" in obs_space.keys():
            n = obs_space["image"][0]
            m = obs_space["image"][1]
            k = obs_space["image"][2]
            self.image_conv = nn.Sequential(
                nn.Conv2d(k, 16, (2, 2)),
                nn.ReLU(),
                nn.Conv2d(16, 32, (2, 2)),
                nn.ReLU(),
                nn.Conv2d(32, 64, (2, 2)),
                nn.ReLU()
            )
            # Each unpadded 2x2 conv shrinks height/width by 1; three of them
            # leave an (n-3) x (m-3) map with 64 channels.
            self.embedding_size = (n-3)*(m-3)*64

    def forward(self, obs):
        if "image" in obs.keys():
            x = obs.image.transpose(1, 3).transpose(2, 3)  # NHWC -> NCHW
            x = self.image_conv(x)
            x = x.reshape(x.shape[0], -1)
            return x
        return super().forward(obs)

class MinigridEnvModel(EnvModel):
    def __init__(self, obs_space):
        super().__init__(obs_space)
        if "image" in obs_space.keys():
            n = obs_space["image"][0]
            m = obs_space["image"][1]
            k = obs_space["image"][2]
            self.image_conv = nn.Sequential(
                nn.Conv2d(k, 16, (2, 2)),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),
                nn.Conv2d(16, 32, (2, 2)),
                nn.ReLU(),
                nn.Conv2d(32, 64, (2, 2)),
                nn.ReLU()
            )
            # First conv: n -> n-1; 2x2 max-pool halves it to (n-1)//2;
            # the remaining two convs subtract 1 each, hence (n-1)//2 - 2.
            self.embedding_size = ((n-1)//2-2)*((m-1)//2-2)*64

    def forward(self, obs):
        if "image" in obs.keys():
            x = obs.image.transpose(1, 3).transpose(2, 3)  # NHWC -> NCHW
            x = self.image_conv(x)
            x = x.reshape(x.shape[0], -1)
            return x
        return super().forward(obs)

class ZonesEnvModel(EnvModel):
    def __init__(self, obs_space):
        super().__init__(obs_space)
        if "image" in obs_space.keys():
            n = obs_space["image"][0]
            lidar_num_bins = 16
            # Alternative sizing: (n-12)//lidar_num_bins + 4, i.e. the number
            # of propositional lidars plus the 4 regular sensors.
            self.embedding_size = 64
            self.net_ = nn.Sequential(
                nn.Linear(n, 128),
                nn.ReLU(),
                nn.Linear(128, self.embedding_size),
                nn.ReLU()
            )

    def forward(self, obs):
        if "image" in obs.keys():
            return self.net_(obs.image)
        return super().forward(obs)

class PendulumEnvModel(EnvModel):
    def __init__(self, obs_space):
        super().__init__(obs_space)
        if "image" in obs_space.keys():
            self.net_ = nn.Sequential(
                nn.Linear(obs_space["image"][0], 3),
                nn.Tanh(),
                # nn.Linear(3, 3),
                # nn.Tanh()
            )
            self.embedding_size = 3

    def forward(self, obs):
        if "image" in obs.keys():
            x = obs.image
            # x = torch.cat((x, x * x), 1)
            x = self.net_(x)
            return x
        return super().forward(obs)
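

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the real training code passes a
    # torch_ac-style DictList). `_Obs` is a hypothetical stand-in that,
    # like DictList, supports both .keys() and attribute access.
    class _Obs(dict):
        __getattr__ = dict.__getitem__

    obs_space = {"image": (7, 7, 3)}
    model = LetterEnvModel(obs_space)
    obs = _Obs(image=torch.zeros(2, 7, 7, 3))
    out = model(obs)
    # Three 2x2 convs: (7-3) * (7-3) * 64 = 1024 features per sample.
    assert out.shape == (2, model.size())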