# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pkg_resources

from tensordict.nn.probabilistic import (  # noqa
    interaction_mode as exploration_mode,
    set_interaction_mode as set_exploration_mode,
)
from tensordict.tensordict import TensorDictBase
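
# The two helpers re-exported above control how probabilistic modules sample.
# A minimal usage sketch (assuming the tensordict API of this release; the
# policy/observation names are illustrative, not part of this file):
#
#     >>> with set_exploration_mode("random"):
#     ...     action_td = policy(obs_td)  # stochastic sampling inside the context
#     >>> exploration_mode()  # query the currently active mode
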
AVAILABLE_LIBRARIES = {pkg.key for pkg in pkg_resources.working_set}
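
# ``pkg_resources`` is deprecated in recent setuptools releases. A sketch of an
# equivalent check built on the standard library (an assumption, not what this
# file ships):
#
#     from importlib.metadata import distributions
#     available = {dist.metadata["Name"].lower() for dist in distributions()}
#
# Note that ``pkg.key`` normalizes project names (lowercase, hyphenated), so
# the two sets may differ for packages using underscores or mixed case.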


class _classproperty(property):
    """A property that is computed on the class itself rather than on an instance."""

    def __get__(self, cls, owner):
        return classmethod(self.fget).__get__(None, owner)()
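
# Illustrative sketch of how ``_classproperty`` is meant to be used (the
# ``Backend`` class below is hypothetical, not part of this file):
#
#     class Backend:
#         @_classproperty
#         def name(cls):
#             return cls.__name__.lower()
#
#     Backend.name  # "backend" -- computed on the class, no instance required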


def step_mdp(
    tensordict: TensorDictBase,
    next_tensordict: TensorDictBase = None,
    keep_other: bool = True,
    exclude_reward: bool = True,
    exclude_done: bool = True,
    exclude_action: bool = True,
    _run_check: bool = True,
) -> TensorDictBase:
    """Creates a new tensordict that reflects a step in time of the input tensordict.

    Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.

    Args:
        tensordict (TensorDictBase): tensordict with keys to be renamed
        next_tensordict (TensorDictBase, optional): destination tensordict
        keep_other (bool, optional): if True, all keys other than :obj:`"next"`
            (and the excluded keys below) will be kept. Default is True.
        exclude_reward (bool, optional): if True, the :obj:`"reward"` key will be
            discarded from the resulting tensordict. Default is True.
        exclude_done (bool, optional): if True, the :obj:`"done"` key will be
            discarded from the resulting tensordict. Default is True.
        exclude_action (bool, optional): if True, the :obj:`"action"` key will be
            discarded from the resulting tensordict. Default is True.

    Returns:
        A new tensordict (or next_tensordict) containing the tensors of the t+1 step.

    Examples:
        This function allows for this kind of loop to be used:

        >>> td_out = []
        >>> env = make_env()
        >>> policy = make_policy()
        >>> td = env.reset()
        >>> for i in range(max_steps):
        ...     td = env.step(td)
        ...     next_td = step_mdp(td)
        ...     assert next_td is not td  # make sure that keys are not overwritten
        ...     td_out.append(td)
        ...     td = next_td
        >>> td_out = torch.stack(td_out, 0)
        >>> print(td_out)  # should contain keys 'observation', 'next_observation', 'action', 'reward', 'done' or similar
    """
    other_keys = []
    prohibited = set()
    if exclude_done:
        prohibited.add("done")
    else:
        other_keys.append("done")
    if exclude_reward:
        prohibited.add("reward")
    else:
        other_keys.append("reward")
    if exclude_action:
        prohibited.add("action")
    else:
        other_keys.append("action")
    prohibited.add("next")
    if keep_other:
        other_keys = [key for key in tensordict.keys() if key not in prohibited]
    select_tensordict = tensordict.select(*other_keys)
    select_tensordict = select_tensordict.update(tensordict.get("next"))

    if next_tensordict is not None:
        return next_tensordict.update(select_tensordict)
    else:
        return select_tensordict
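
# A minimal sketch of what ``step_mdp`` does (assumes the ``tensordict``
# package; the tensors and keys below are illustrative):
#
#     >>> import torch
#     >>> from tensordict import TensorDict
#     >>> td = TensorDict(
#     ...     {
#     ...         "observation": torch.zeros(3),
#     ...         "action": torch.ones(1),
#     ...         "next": TensorDict({"observation": torch.ones(3)}, []),
#     ...     },
#     ...     batch_size=[],
#     ... )
#     >>> out = step_mdp(td)
#     >>> # "next"/"observation" has been promoted to the root "observation";
#     >>> # "action" was dropped because exclude_action defaults to True.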


def get_available_libraries():
    """Returns all the supported libraries."""
    return SUPPORTED_LIBRARIES


def _check_gym():
    """Returns True if the gym library is installed."""
    return "gym" in AVAILABLE_LIBRARIES


def _check_gym_atari():
    """Returns True if the gym library is installed and atari envs can be found."""
    if not _check_gym():
        return False
    return "atari-py" in AVAILABLE_LIBRARIES


def _check_mario():
    """Returns True if the "gym-super-mario-bros" library is installed."""
    return "gym-super-mario-bros" in AVAILABLE_LIBRARIES


def _check_dmcontrol():
    """Returns True if the "dm-control" library is installed."""
    return "dm-control" in AVAILABLE_LIBRARIES


def _check_dmlab():
    """Returns True if the "deepmind-lab" library is installed."""
    return "deepmind-lab" in AVAILABLE_LIBRARIES


SUPPORTED_LIBRARIES = {
    "gym": _check_gym(),  # OpenAI
    "gym[atari]": _check_gym_atari(),
    "dm_control": _check_dmcontrol(),  # DM, https://github.com/deepmind/dm_control
    "habitat": None,  # FB, https://github.com/facebookresearch/habitat-sim
    "gym-super-mario-bros": _check_mario(),
    # "vizdoom": None,  # gym based, https://github.com/mwydmuch/ViZDoom
    # "openspiel": None,  # DM, https://github.com/deepmind/open_spiel
    # "pysc2": None,  # DM, https://github.com/deepmind/pysc2
    # "deepmind_lab": _check_dmlab(),
    # DM, https://github.com/deepmind/lab, https://github.com/deepmind/lab/tree/master/python/pip_package
    # "serpent.ai": None,  # https://github.com/SerpentAI/SerpentAI
    # "gfootball": None,  # 2.8k G, https://github.com/google-research/football
    # "meta-world": None,  # https://github.com/rlworkgroup/metaworld
    # "minerl": None,  # https://github.com/minerllabs/minerl
    # "multi-agent-emergence-environments": None,
    # OpenAI, https://github.com/openai/multi-agent-emergence-environments
    # "procgen": None,  # OpenAI, https://github.com/openai/procgen
    # "pybullet": None,  # https://github.com/benelot/pybullet-gym
    # "realworld_rl_suite": None,
    # G, https://github.com/google-research/realworldrl_suite
    # "rlcard": None,  # https://github.com/datamllab/rlcard
    # "screeps": None,  # https://github.com/screeps/screeps
    # "ml-agents": None,
}
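
# A short usage sketch (illustrative, not part of the module's public docs):
#
#     >>> libs = get_available_libraries()
#     >>> libs["gym"]  # True when gym is installed in the current environment
#     >>> sorted(k for k, v in libs.items() if v)  # names of detected backends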