relabel.py
import abc


class TrajectoryExperience(object):
    """An experience that holds a reference to the trajectory it came from.

    This should be substitutable wherever Experience is used.

    In particular, it holds:
        - trajectory (list[Experience]): the in-order trajectory that this
          experience is part of
        - index (int): the index of this experience within that trajectory
    """
    def __init__(self, experience, trajectory, index):
        self._experience = experience
        self._trajectory = trajectory
        self._index = index

    def __getattr__(self, attr):
        # Delegate all public attribute lookups to the wrapped experience so
        # that this class is a drop-in substitute for Experience.
        if attr[0] == "_" and attr != "_replace":
            raise AttributeError("accessing private attribute '{}'".format(attr))
        return getattr(self._experience, attr)
    @property
    def trajectory(self):
        return self._trajectory

    @property
    def index(self):
        return self._index

    @property
    def experience(self):
        return self._experience

    def cpu(self):
        return TrajectoryExperience(
            self.experience.cpu(), self.trajectory, self.index)

    def cuda(self):
        return TrajectoryExperience(
            self.experience.cuda(), self.trajectory, self.index)
    @classmethod
    def episode_to_device(cls, episode, cpu=True):
        """Wraps each experience in an episode as a TrajectoryExperience.

        Also makes sure each experience is on the requested device.

        Args:
            episode (list[Experience]): list of experiences to wrap and move.
            cpu (bool): if True, moves experiences to the CPU; otherwise to
                the GPU.

        Returns:
            list[TrajectoryExperience]: the wrapped episode, where each
            element holds a reference to the shared on-device trajectory.
        """
        new_episode = []
        trajectory = []
        for idx, exp in enumerate(episode):
            if cpu:
                exp_on_device = exp.cpu()
            else:
                exp_on_device = exp.cuda()
            new_episode.append(exp_on_device)
            trajectory.append(
                TrajectoryExperience(exp_on_device, new_episode, idx))
        return trajectory
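
# Usage sketch (illustrative; not from the original file). Assumes an
# `Experience` type whose instances implement `.cpu()` and `.cuda()`,
# e.g. a namedtuple of torch tensors:
#
#     wrapped = TrajectoryExperience.episode_to_device(episode, cpu=True)
#     first = wrapped[0]
#     first.index        # 0: position within the episode
#     first.trajectory   # the full list of on-device experiences
#     first.state        # any public attribute is delegated to the
#                        # wrapped experience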

class RewardLabeler(abc.ABC):
    """Computes rewards for trajectories on the fly."""

    @abc.abstractmethod
    def label_rewards(self, trajectories):
        r"""Computes rewards for each experience in the trajectory.

        Args:
            trajectories (list[list[TrajectoryExperience]]): batch of
                trajectories.

        Returns:
            rewards (torch.FloatTensor): of shape (batch_size, max_seq_len)
                where rewards[i][j] is the reward for the experience
                trajectories[i][j]. This is padded with zeros and is detached
                from the graph.
            distances (torch.FloatTensor): of shape (batch_size, max_seq_len + 1)
                equal to ||f(e) - g(\tau^e_{:t})|| for each t.
        """