"""
This is the description of the deep NN currently being used.
It is a small CNN for the features with an GRU encoding of the LTL task.
The features and LTL are preprocessed by utils.format.get_obss_preprocessor(...) function:
- In that function, I transformed the LTL tuple representation into a text representation:
- Input: ('until',('not','a'),('and', 'b', ('until',('not','c'),'d')))
- output: ['until', 'not', 'a', 'and', 'b', 'until', 'not', 'c', 'd']
Each of those tokens get a one-hot embedding representation by the utils.format.Vocabulary class.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
import torch_ac
from gym.spaces import Box, Discrete
from gnns.graphs.GCN import *
from gnns.graphs.GNN import GNNMaker
from env_model import getEnvModel
from policy_network import PolicyNetwork
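
# Illustrative sketch only (not part of the original module): a minimal recursive
# flattening of the nested LTL tuple into the token list described in the module
# docstring. The actual preprocessing lives in utils.format.get_obss_preprocessor.
def _flatten_ltl_example(formula):
    """Flatten a nested LTL tuple such as ('until', ('not', 'a'), ...) into a token list."""
    if isinstance(formula, str):
        return [formula]
    tokens = []
    for part in formula:
        tokens.extend(_flatten_ltl_example(part))
    return tokens

# _flatten_ltl_example(('until', ('not', 'a'), ('and', 'b', ('until', ('not', 'c'), 'd'))))
# == ['until', 'not', 'a', 'and', 'b', 'until', 'not', 'c', 'd']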
# Function from https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
def init_params(m):
    classname = m.__class__.__name__
    if classname.find("TypedLinear") != -1:
        weight = m.get_weight()
        weight.data.normal_(0, 1)
        weight.data *= 1 / torch.sqrt(weight.data.pow(2).sum(1, keepdim=True))
    elif classname.find("Linear") != -1:
        m.weight.data.normal_(0, 1)
        m.weight.data *= 1 / torch.sqrt(m.weight.data.pow(2).sum(1, keepdim=True))
        if m.bias is not None:
            m.bias.data.fill_(0)
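
# Illustrative usage sketch (not part of the original module): init_params is applied
# below via ACModel's self.apply(init_params). It draws Gaussian weights and rescales
# every row of each (Typed)Linear weight matrix to unit L2 norm, zeroing the bias.
def _demo_init_params():
    layer = nn.Linear(8, 4)
    layer.apply(init_params)
    row_norms = layer.weight.data.pow(2).sum(1).sqrt()
    assert torch.allclose(row_norms, torch.ones(4), atol=1e-5)  # each row has unit norm
    assert torch.all(layer.bias.data == 0)                      # biases start at zero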
class ACModel(nn.Module, torch_ac.ACModel):
    def __init__(self, env, obs_space, action_space, ignoreLTL, gnn_type, dumb_ac, freeze_ltl):
        super().__init__()

        # Decide which components are enabled
        self.use_progression_info = "progress_info" in obs_space
        self.use_text = not ignoreLTL and (gnn_type == "GRU" or gnn_type == "LSTM") and "text" in obs_space
        self.use_ast = not ignoreLTL and ("GCN" in gnn_type or "Transformer" in gnn_type or "GATv2Conv" in gnn_type) and "text" in obs_space
        self.gnn_type = gnn_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.action_space = action_space
        self.dumb_ac = dumb_ac

        self.freeze_pretrained_params = freeze_ltl
        if self.freeze_pretrained_params:
            print("Freezing the LTL module.")

        self.env_model = getEnvModel(env, obs_space)

        # Define text embedding
        if self.use_progression_info:
            self.text_embedding_size = 32
            self.simple_encoder = nn.Sequential(
                nn.Linear(obs_space["progress_info"], 64),
                nn.Tanh(),
                nn.Linear(64, self.text_embedding_size),
                nn.Tanh()
            ).to(self.device)
            print("Linear encoder Number of parameters:", sum(p.numel() for p in self.simple_encoder.parameters() if p.requires_grad))

        elif self.use_text:
            self.word_embedding_size = 32
            self.text_embedding_size = 32
            if self.gnn_type == "GRU":
                self.text_rnn = GRUModel(obs_space["text"], self.word_embedding_size, 16, self.text_embedding_size).to(self.device)
            else:
                assert(self.gnn_type == "LSTM")
                self.text_rnn = LSTMModel(obs_space["text"], self.word_embedding_size, 16, self.text_embedding_size).to(self.device)
            print("RNN Number of parameters:", sum(p.numel() for p in self.text_rnn.parameters() if p.requires_grad))

        elif self.use_ast:
            hidden_dim = 32
            self.text_embedding_size = 32
            self.gnn = GNNMaker(self.gnn_type, obs_space["text"], self.text_embedding_size).to(self.device)
            print("GNN Number of parameters:", sum(p.numel() for p in self.gnn.parameters() if p.requires_grad))

        # Resize image embedding
        self.embedding_size = self.env_model.size()
        print("embedding size:", self.embedding_size)
        if self.use_text or self.use_ast or self.use_progression_info:
            self.embedding_size += self.text_embedding_size

        if self.dumb_ac:
            # Define actor's model
            self.actor = PolicyNetwork(self.embedding_size, self.action_space)

            # Define critic's model
            self.critic = nn.Sequential(
                nn.Linear(self.embedding_size, 1)
            )
        else:
            # Define actor's model
            self.actor = PolicyNetwork(self.embedding_size, self.action_space, hiddens=[64, 64, 64], activation=nn.ReLU())

            # Define critic's model
            self.critic = nn.Sequential(
                nn.Linear(self.embedding_size, 64),
                nn.Tanh(),
                nn.Linear(64, 64),
                nn.Tanh(),
                nn.Linear(64, 1)
            )

        # Initialize parameters correctly
        self.apply(init_params)
    def forward(self, obs):
        embedding = self.env_model(obs)

        if self.use_progression_info:
            embed_ltl = self.simple_encoder(obs.progress_info)
            embedding = torch.cat((embedding, embed_ltl), dim=1) if embedding is not None else embed_ltl

        # Adding Text
        elif self.use_text:
            embed_text = self.text_rnn(obs.text)
            embedding = torch.cat((embedding, embed_text), dim=1) if embedding is not None else embed_text

        # Adding GNN
        elif self.use_ast:
            embed_gnn = self.gnn(obs.text)
            embedding = torch.cat((embedding, embed_gnn), dim=1) if embedding is not None else embed_gnn

        # Actor
        dist = self.actor(embedding)

        # Critic
        x = self.critic(embedding)
        value = x.squeeze(1)

        return dist, value
    def load_pretrained_gnn(self, model_state):
        # We delete all keys relating to the actor/critic.
        new_model_state = model_state.copy()

        for key in model_state.keys():
            if key.find("actor") != -1 or key.find("critic") != -1:
                del new_model_state[key]

        self.load_state_dict(new_model_state, strict=False)

        if self.freeze_pretrained_params:
            target = self.text_rnn if self.gnn_type == "GRU" or self.gnn_type == "LSTM" else self.gnn
            for param in target.parameters():
                param.requires_grad = False
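
# Illustrative usage sketch (not part of the original module), assuming `model` is an
# ACModel and `preprocessed_obs` comes from the obss preprocessor mentioned in the module
# docstring: the forward pass returns an action distribution and a value estimate, which
# an actor-critic training loop (e.g. torch_ac's PPO) consumes.
def _demo_acmodel_forward(model, preprocessed_obs):
    dist, value = model(preprocessed_obs)  # dist: torch.distributions object from the actor, value: (batch,)
    action = dist.sample()                 # sample one action per environment in the batch
    return action, dist.log_prob(action), value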
class LSTMModel(nn.Module):
    def __init__(self, obs_size, word_embedding_size=32, hidden_dim=32, text_embedding_size=32):
        super().__init__()
        # For all our experiments we want the embedding to be a fixed size so we can "transfer".
        self.word_embedding = nn.Embedding(obs_size, word_embedding_size)
        self.lstm = nn.LSTM(word_embedding_size, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        self.output_layer = nn.Linear(2*hidden_dim, text_embedding_size)

    def forward(self, text):
        hidden, _ = self.lstm(self.word_embedding(text))
        return self.output_layer(hidden[:, -1, :])
class GRUModel(nn.Module):
    def __init__(self, obs_size, word_embedding_size=32, hidden_dim=32, text_embedding_size=32):
        super().__init__()
        self.word_embedding = nn.Embedding(obs_size, word_embedding_size)
        self.gru = nn.GRU(word_embedding_size, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        self.output_layer = nn.Linear(2*hidden_dim, text_embedding_size)

    def forward(self, text):
        hidden, _ = self.gru(self.word_embedding(text))
        return self.output_layer(hidden[:, -1, :])
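
# Illustrative sketch (not part of the original module): expected tensor shapes for the
# RNN text encoders above. The vocabulary size, batch size, and sequence length are made up.
def _demo_text_encoder_shapes():
    vocab_size, batch, seq_len = 10, 4, 9
    tokens = torch.randint(0, vocab_size, (batch, seq_len))  # LongTensor of token indices
    encoder = GRUModel(vocab_size)                           # defaults: 32-dim embeddings and output
    out = encoder(tokens)
    assert out.shape == (batch, 32)                          # one fixed-size embedding per sequence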