From c6f38d990991f49c257ffe7616dd91867711525f Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Mon, 28 Nov 2022 18:30:01 +0100 Subject: [PATCH 01/10] [Feature] Initial commit --- test/test_libs.py | 92 ++++++++++++++++++- torchrl/data/__init__.py | 2 + torchrl/data/tensor_specs.py | 91 +++++++++++++++++++ torchrl/envs/libs/smac.py | 170 +++++++++++++++++++++++++++++++++++ 4 files changed, 353 insertions(+), 2 deletions(-) create mode 100644 torchrl/envs/libs/smac.py diff --git a/test/test_libs.py b/test/test_libs.py index 4e4b4b811b2..b0caeb2c8ef 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -25,6 +25,7 @@ from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.libs.smac import SC2Env, _has_smac if _has_gym: import gym @@ -44,7 +45,7 @@ IS_OSX = platform == "darwin" -@pytest.mark.skipif(not _has_gym, reason="no gym library found") +@pytest.mark.skipif(_has_gym, reason="no gym library found") @pytest.mark.parametrize( "env_name", [ @@ -149,7 +150,7 @@ def _make_gym_environment(env_name): # noqa: F811 return gym.make(env_name, render_mode="rgb_array") -@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found") +@pytest.mark.skipif(_has_dmc, reason="no dm_control library found") @pytest.mark.parametrize("env_name,task", [["cheetah", "run"]]) @pytest.mark.parametrize("frame_skip", [1, 3]) @pytest.mark.parametrize( @@ -425,6 +426,93 @@ def test_jumanji_consistency(self, envname, batch_size): ) +@pytest.mark.skipif(not _has_smac, reason="smac not installed") +@pytest.mark.parametrize("map_name", ["8m"]) # TODO: add more maps +class TestSmac: + def test_smac_seeding(self, map_name): + """Verifies deterministic behaviour of the environment.""" + final_seed = [] + tdreset = [] + tdrollout = [] + seed = 0 + for _ in range(2): + torch.manual_seed(seed) + np.random.seed(seed) + env = SC2Env(map_name, seed) + # final_seed.append(env.set_seed(0)) + tdreset.append(env.reset()) + tdrollout.append(env.rollout(max_steps=50)) + env.close() + del env + # TODO: in this case seed is always static? + # assert final_seed[0] == final_seed[1] + assert_allclose_td(*tdreset) + assert_allclose_td(*tdrollout) + + # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + # def test_jumanji_batch_size(self, envname, batch_size): + # env = JumanjiEnv(envname, batch_size=batch_size) + # env.set_seed(0) + # tdreset = env.reset() + # tdrollout = env.rollout(max_steps=50) + # env.close() + # del env + # assert tdreset.batch_size == batch_size + # assert tdrollout.batch_size[:-1] == batch_size + # + # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + # def test_jumanji_spec_rollout(self, envname, batch_size): + # env = JumanjiEnv(envname, batch_size=batch_size) + # env.set_seed(0) + # _test_fake_tensordict(env) + # + # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) + # def test_jumanji_consistency(self, envname, batch_size): + # import jax + # import jax.numpy as jnp + # import numpy as onp + # + # env = JumanjiEnv(envname, batch_size=batch_size) + # obs_keys = list(env.observation_spec.keys(True)) + # env.set_seed(1) + # rollout = env.rollout(10) + # + # env.set_seed(1) + # key = env.key + # base_env = env._env + # key, *keys = jax.random.split(key, np.prod(batch_size) + 1) + # state, timestep = jax.vmap(base_env.reset)(jnp.stack(keys)) + # # state = env._reshape(state) + # # timesteps.append(timestep) + # for i in range(rollout.shape[-1]): + # action = rollout[..., i]["action"] + # # state = env._flatten(state) + # action = env._flatten(env.read_action(action)) + # state, timestep = jax.vmap(base_env.step)(state, action) + # # state = env._reshape(state) + # # timesteps.append(timestep) + # checked = False + # for _key in obs_keys: + # if isinstance(_key, str): + # _key = (_key,) + # try: + # t2 = getattr(timestep, _key[0]) + # except AttributeError: + # try: + # t2 = getattr(timestep.observation, _key[0]) + # except AttributeError: + # continue + # t1 = rollout[..., i][("next", *_key)] + # for __key in _key[1:]: + # t2 = getattr(t2, _key) + # t2 = torch.tensor(onp.asarray(t2)).view_as(t1) + # torch.testing.assert_close(t1, t2) + # checked = True + # if not checked: + # raise AttributeError( + # f"None of the keys matched: {rollout}, {list(timestep.__dict__.keys())}" + # ) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 44822cbfa7e..c95b405764d 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -25,6 +25,8 @@ NdUnboundedContinuousTensorSpec, NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, + NdOneHotDiscreteTensorSpec, + CustomNdOneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ec347be62ed..be8cf613d3e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -586,6 +586,97 @@ def __eq__(self, other): ) +@dataclass(repr=False) +class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): + """An N-dimensional One hot discrete tensor spec data class""" + + def __init__( + self, + n: int, + *shape: int, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.shape = shape + total_shape = torch.Size( + ( + *shape, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + total_shape, space, device, dtype, "discrete" + ) + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + return torch.nn.functional.gumbel_softmax( + torch.rand(*shape, self.d, self.space.n, device=self.device), + hard=True, + dim=-1, + ).to(torch.long) + + +@dataclass(repr=False) +class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec): + """A masked N-dimensional One-Hot discrete tensor spec data-class + The aim of this class is to check / project or document a discrete space + when it varies from environment to environment, or from step to step in the + same environment. + """ + + def __init__( + self, + mask: torch.Tensor, + device: Optional[DEVICE_TYPING] = None, + dtype: Optional[Union[str, torch.dtype]] = torch.long, + use_register: bool = False, + ): + if mask.dtype is not torch.bool: + raise RuntimeError( + f"Expected a mask with dtype torch.bool but got {mask.dtype}" + ) + if (mask.sum(-1) == 0).any(): + raise RuntimeError("Got an empty mask for some dimension.") + self.mask = mask + *shape, n = mask.shape + + dtype, device = _default_dtype_and_device(dtype, device) + self.use_register = use_register + space = DiscreteBox( + n, + ) + self.shape = shape + total_shape = torch.Size( + ( + *shape, + space.n, + ) + ) + super(OneHotDiscreteTensorSpec, self).__init__( + total_shape, space, device, dtype, "discrete" + ) + + def to(self, dest): + out = super().to(dest) + out.mask = self.mask.to(dest) + return out + + def rand(self, shape=torch.Size([])) -> torch.Tensor: + mask = self.mask.expand(*shape, *self.mask.shape) + r = torch.rand(mask.shape, device=mask.device).masked_fill_(~mask, 0.0) + return (r == r.max(-1, keepdim=True)[0]).to(torch.long) + + def is_in(self, value): + congruent = self.mask & value.to(torch.bool) + return (congruent.sum(-1) == 1).all() + + @dataclass(repr=False) class UnboundedContinuousTensorSpec(TensorSpec): """An unbounded, unidimensional, continuous tensor spec. diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py new file mode 100644 index 00000000000..c4847337e5a --- /dev/null +++ b/torchrl/envs/libs/smac.py @@ -0,0 +1,170 @@ +import numpy as np +import torch +from tensordict.tensordict import TensorDict, TensorDictBase +from typing import Dict, Optional + +from torchrl.data import ( + CompositeSpec, + DEVICE_TYPING, + DiscreteTensorSpec, + NdBoundedTensorSpec, + CustomNdOneHotDiscreteTensorSpec, + NdUnboundedContinuousTensorSpec, + UnboundedContinuousTensorSpec, + NdUnboundedDiscreteTensorSpec, + OneHotDiscreteTensorSpec, + TensorSpec, +) +from torchrl.envs import GymLikeEnv + +try: + import smac + from smac.env import StarCraft2Env + + _has_smac = True +except ImportError as err: + _has_smac = False + IMPORT_ERR = str(err) + + +# TODO: discuss with Vincent if separation to ..Wrapper and ..Env classes makes sense here. +class SC2Wrapper(GymLikeEnv): + """TODO: comments + """ + git_url = "https://github.com/oxwhirl/smac" + + def __init__(self, map_name: str = None, **kwargs): + if map_name is not None: + kwargs["map_name"] = map_name + # TODO: process seed? + super().__init__(**kwargs) + + def _check_kwargs(self, kwargs: Dict): + pass + + def _init_env(self) -> Optional[int]: + # TODO: verify that isn't required. + pass + + def _build_env(self, env, seed: Optional[int] = None, **kwargs) -> "smac.env.StarCraft2Env": + # TODO: if required + # self.from_pixels = from_pixels + # self.pixels_only = pixels_only + + # if from_pixels: + # raise NotImplementedError("TODO") + return env + + def _make_state_example(self, env): + # TODO + pass + # key = jax.random.PRNGKey(0) + # keys = jax.random.split(key, self.batch_size.numel()) + # state, _ = jax.vmap(env.reset)(jnp.stack(keys)) + # state = self._reshape(state) + # return state + + def _make_state_spec(self, env) -> TensorSpec: + # TODO + pass + # key = jax.random.PRNGKey(0) + # state, _ = env.reset(key) + # state_dict = _object_to_tensordict(state, self.device, batch_size=()) + # state_spec = _torchrl_data_to_spec_transform(state_dict) + # return state_spec + + def _make_input_spec(self, env: StarCraft2Env) -> TensorSpec: + action_spec = CustomNdOneHotDiscreteTensorSpec( + torch.tensor(env.get_avail_actions()), + device=self.device + ) + return CompositeSpec(action=action_spec) + + def _make_observation_spec(self, env: StarCraft2Env) -> TensorSpec: + info = env.get_env_info() + size = torch.Size(info["n_agents"], info["obs_shape"]) + return NdUnboundedContinuousTensorSpec(size, device=self.device) + + def _make_reward_spec(self) -> TensorSpec: + return UnboundedContinuousTensorSpec(device=self.device) + + def _make_specs(self, env: StarCraft2Env) -> None: + # extract specs from definition + self._reward_spec = self._make_reward_spec() + + # extract specs from instance + self._input_spec = self._make_input_spec(env) + self._observation_spec = self._make_observation_spec(env) + self._state_spec = self._make_state_spec(env) + self._input_spec["state"] = self._state_spec + + # TODO: build state example for data conversion + self._state_example = self._make_state_example(env) + + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError("Seed can be set only when creating environment.") + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + + env: smac.env.StarCraft2Env = self._env + obs, state = env.reset() + + # reshape batch size from vector + # TODO + state = self._reshape(state) + obs = self._reshape(obs) + + # collect outputs + state_dict = self.read_state(state) + obs_dict = self.read_obs(obs) + done = torch.zeros(self.batch_size, dtype=torch.bool) + + self._is_done = done + + # build results + tensordict_out = TensorDict( + source=obs_dict, + batch_size=self.batch_size, + device=self.device, + ) + tensordict_out.set("done", done) + tensordict_out["state"] = state_dict + + return tensordict_out + + +class SC2Env(SC2Wrapper): + """TODO: comments + """ + + def __init__(self, map_name: str, seed: Optional[int] = None, **kwargs): + kwargs["map_name"] = map_name + if seed is not None: + kwargs["seed"] = map_name + + super().__init__(**kwargs) + + def _build_env( + self, + map_name: str, + seed: Optional[int] = None, + **kwargs, + ) -> "smac.env.StarCraft2Env": + if not _has_smac: + raise RuntimeError( + f"smac not found, unable to create smac.env.StarCraft2Env. " + f"Consider installing smac. More info:" + f" {self.git_url}. (Original error message during import: {IMPORT_ERR})." + ) + # TODO: check if those are required + # from_pixels = kwargs.pop("from_pixels", False) + # pixels_only = kwargs.pop("pixels_only", True) + + # TODO: check if this is required + # self.wrapper_frame_skip = 1 + env = smac.env.StarCraft2Env(map_name, seed, **kwargs) + + # TODO: return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) + return super()._build_env(env) From 1ac54010365123d8dcc40f1d0d90da7bd3f87ce0 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Mon, 9 Jan 2023 19:20:10 +0100 Subject: [PATCH 02/10] amend --- test/test_libs.py | 71 +---------- torchrl/data/tensor_specs.py | 9 +- torchrl/envs/libs/smac.py | 234 ++++++++++++++++++++++------------- 3 files changed, 158 insertions(+), 156 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index b85d5d523a6..2ac10bcd3e4 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -523,85 +523,20 @@ def test_smac_seeding(self, map_name): final_seed = [] tdreset = [] tdrollout = [] - seed = 0 + seed = 42 for _ in range(2): torch.manual_seed(seed) np.random.seed(seed) env = SC2Env(map_name, seed) - # final_seed.append(env.set_seed(0)) + final_seed.append(env.get_seed()) tdreset.append(env.reset()) tdrollout.append(env.rollout(max_steps=50)) env.close() del env - # TODO: in this case seed is always static? - # assert final_seed[0] == final_seed[1] + assert final_seed[0] == final_seed[1] assert_allclose_td(*tdreset) assert_allclose_td(*tdrollout) - # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) - # def test_jumanji_batch_size(self, envname, batch_size): - # env = JumanjiEnv(envname, batch_size=batch_size) - # env.set_seed(0) - # tdreset = env.reset() - # tdrollout = env.rollout(max_steps=50) - # env.close() - # del env - # assert tdreset.batch_size == batch_size - # assert tdrollout.batch_size[:-1] == batch_size - # - # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) - # def test_jumanji_spec_rollout(self, envname, batch_size): - # env = JumanjiEnv(envname, batch_size=batch_size) - # env.set_seed(0) - # _test_fake_tensordict(env) - # - # @pytest.mark.parametrize("batch_size", [(), (5,), (5, 4)]) - # def test_jumanji_consistency(self, envname, batch_size): - # import jax - # import jax.numpy as jnp - # import numpy as onp - # - # env = JumanjiEnv(envname, batch_size=batch_size) - # obs_keys = list(env.observation_spec.keys(True)) - # env.set_seed(1) - # rollout = env.rollout(10) - # - # env.set_seed(1) - # key = env.key - # base_env = env._env - # key, *keys = jax.random.split(key, np.prod(batch_size) + 1) - # state, timestep = jax.vmap(base_env.reset)(jnp.stack(keys)) - # # state = env._reshape(state) - # # timesteps.append(timestep) - # for i in range(rollout.shape[-1]): - # action = rollout[..., i]["action"] - # # state = env._flatten(state) - # action = env._flatten(env.read_action(action)) - # state, timestep = jax.vmap(base_env.step)(state, action) - # # state = env._reshape(state) - # # timesteps.append(timestep) - # checked = False - # for _key in obs_keys: - # if isinstance(_key, str): - # _key = (_key,) - # try: - # t2 = getattr(timestep, _key[0]) - # except AttributeError: - # try: - # t2 = getattr(timestep.observation, _key[0]) - # except AttributeError: - # continue - # t1 = rollout[..., i][("next", *_key)] - # for __key in _key[1:]: - # t2 = getattr(t2, _key) - # t2 = torch.tensor(onp.asarray(t2)).view_as(t1) - # torch.testing.assert_close(t1, t2) - # checked = True - # if not checked: - # raise AttributeError( - # f"None of the keys matched: {rollout}, {list(timestep.__dict__.keys())}" - # ) - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index e9bbb9eb105..5fd189e64ff 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -609,7 +609,7 @@ def __eq__(self, other): @dataclass(repr=False) class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): - """An N-dimensional One hot discrete tensor spec data class""" + """An N-dimensional One hot discrete tensor spec data class.""" def __init__( self, @@ -635,7 +635,7 @@ def __init__( total_shape, space, device, dtype, "discrete" ) - def rand(self, shape=torch.Size([])) -> torch.Tensor: + def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: return torch.nn.functional.gumbel_softmax( torch.rand(*shape, self.d, self.space.n, device=self.device), hard=True, @@ -645,7 +645,8 @@ def rand(self, shape=torch.Size([])) -> torch.Tensor: @dataclass(repr=False) class CustomNdOneHotDiscreteTensorSpec(NdOneHotDiscreteTensorSpec): - """A masked N-dimensional One-Hot discrete tensor spec data-class + """A masked N-dimensional One-Hot discrete tensor spec data-class. + The aim of this class is to check / project or document a discrete space when it varies from environment to environment, or from step to step in the same environment. @@ -688,7 +689,7 @@ def to(self, dest): out.mask = self.mask.to(dest) return out - def rand(self, shape=torch.Size([])) -> torch.Tensor: + def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: mask = self.mask.expand(*shape, *self.mask.shape) r = torch.rand(mask.shape, device=mask.device).masked_fill_(~mask, 0.0) return (r == r.max(-1, keepdim=True)[0]).to(torch.long) diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 093a2251900..407b0267ec9 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -1,18 +1,12 @@ from typing import Dict, Optional -import numpy as np import torch from tensordict.tensordict import TensorDict, TensorDictBase from torchrl.data import ( CompositeSpec, CustomNdOneHotDiscreteTensorSpec, - DEVICE_TYPING, - DiscreteTensorSpec, - NdBoundedTensorSpec, NdUnboundedContinuousTensorSpec, - NdUnboundedDiscreteTensorSpec, - OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, ) @@ -20,7 +14,8 @@ try: import smac - from smac.env import StarCraft2Env + import smac.env + from smac.env.starcraft2.maps import smac_maps _has_smac = True except ImportError as err: @@ -28,98 +23,108 @@ IMPORT_ERR = str(err) -# TODO: discuss with Vincent if separation to ..Wrapper and ..Env classes makes sense here. +def _get_envs(): + if not _has_smac: + return [] + return [map_name for map_name, _ in smac_maps.get_smac_map_registry().items()] + + class SC2Wrapper(GymLikeEnv): - """TODO: comments""" + """SMAC (StarCraft Multi-Agent Challenge) environment wrapper. + + Examples: + >>> env = smac.env.StarCraft2Env("8m", seed=42) # Seed cannot be changed once environment was created. + >>> env = SC2Wrapper(env) + >>> td = env.reset() + >>> td["action"] = env.action_spec.rand() + >>> td = env.step(td) + >>> print(td) + TensorDict( + fields={ + action: Tensor(torch.Size([8, 14]), dtype=torch.int64), + done: Tensor(torch.Size([1]), dtype=torch.bool), + next: TensorDict( + fields={ + observation: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + observation: Tensor(torch.Size([8, 80]), dtype=torch.float32), + reward: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(env.available_envs) + ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] + """ git_url = "https://github.com/oxwhirl/smac" + available_envs = _get_envs() + libname = "smac" - def __init__(self, map_name: str = None, **kwargs): - if map_name is not None: - kwargs["map_name"] = map_name - # TODO: process seed? + def __init__(self, env: smac.env.StarCraft2Env = None, **kwargs): + if env is not None: + kwargs["env"] = env super().__init__(**kwargs) def _check_kwargs(self, kwargs: Dict): - pass + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, (smac.env.StarCraft2Env,)): + raise TypeError("env is not of type 'smac.env.StarCraft2Env'.") + + def _build_env(self, env, **kwargs) -> smac.env.StarCraft2Env: + # StarCraft2Env must be initialized before _make_specs. + env.reset() + return env - def _init_env(self) -> Optional[int]: - # TODO: verify that isn't required. - pass + def _make_specs(self, env: smac.env.StarCraft2Env) -> None: + # Extract specs from definition. + self.reward_spec = self._make_reward_spec() - def _build_env( - self, env, seed: Optional[int] = None, **kwargs - ) -> "smac.env.StarCraft2Env": - # TODO: if required - # self.from_pixels = from_pixels - # self.pixels_only = pixels_only - - # if from_pixels: - # raise NotImplementedError("TODO") - return env + # Extract specs from instance. + # To extract these specs environment must be fully initialized with env.reset(). + self.input_spec = self._make_input_spec(env) + self.observation_spec = self._make_observation_spec(env) - def _make_state_example(self, env): - # TODO - pass - # key = jax.random.PRNGKey(0) - # keys = jax.random.split(key, self.batch_size.numel()) - # state, _ = jax.vmap(env.reset)(jnp.stack(keys)) - # state = self._reshape(state) - # return state - - def _make_state_spec(self, env) -> TensorSpec: - # TODO + # TODO: add support for the state. + # self.state_spec = self._make_state_spec(env) + # self.input_spec["state"] = self._state_spec + # self._state_example = self._make_state_example(env) + + def _init_env(self) -> None: pass - # key = jax.random.PRNGKey(0) - # state, _ = env.reset(key) - # state_dict = _object_to_tensordict(state, self.device, batch_size=()) - # state_spec = _torchrl_data_to_spec_transform(state_dict) - # return state_spec - def _make_input_spec(self, env: StarCraft2Env) -> TensorSpec: + def _make_reward_spec(self) -> TensorSpec: + return UnboundedContinuousTensorSpec(device=self.device) + + def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: action_spec = CustomNdOneHotDiscreteTensorSpec( - torch.tensor(env.get_avail_actions()), device=self.device + torch.tensor(env.get_avail_actions(), dtype=torch.bool), device=self.device ) return CompositeSpec(action=action_spec) - def _make_observation_spec(self, env: StarCraft2Env) -> TensorSpec: + def _make_observation_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: info = env.get_env_info() - size = torch.Size(info["n_agents"], info["obs_shape"]) - return NdUnboundedContinuousTensorSpec(size, device=self.device) - - def _make_reward_spec(self) -> TensorSpec: - return UnboundedContinuousTensorSpec(device=self.device) - - def _make_specs(self, env: StarCraft2Env) -> None: - # extract specs from definition - self._reward_spec = self._make_reward_spec() - - # extract specs from instance - self._input_spec = self._make_input_spec(env) - self._observation_spec = self._make_observation_spec(env) - self._state_spec = self._make_state_spec(env) - self._input_spec["state"] = self._state_spec - - # TODO: build state example for data conversion - self._state_example = self._make_state_example(env) + size = torch.Size((info["n_agents"], info["obs_shape"])) + obs_spec = NdUnboundedContinuousTensorSpec(size, device=self.device) + return CompositeSpec(observation=obs_spec) def _set_seed(self, seed: Optional[int]): - raise NotImplementedError("Seed can be set only when creating environment.") + raise NotImplementedError( + "Seed cannot be changed once environment was created." + ) def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: - env: smac.env.StarCraft2Env = self._env obs, state = env.reset() - # reshape batch size from vector - # TODO - state = self._reshape(state) - obs = self._reshape(obs) - # collect outputs - state_dict = self.read_state(state) + # TODO: add support for the state. + # state_dict = self.read_state(state) obs_dict = self.read_obs(obs) done = torch.zeros(self.batch_size, dtype=torch.bool) @@ -132,40 +137,101 @@ def _reset( device=self.device, ) tensordict_out.set("done", done) - tensordict_out["state"] = state_dict + # TODO: add support for the state. + # tensordict_out["state"] = state_dict + # TODO: return available actions? return tensordict_out + def _action_transform(self, action): + action_np = self.action_spec.to_numpy(action) + return action_np + + def _step(self, tensordict: TensorDictBase) -> TensorDictBase: + env: smac.env.StarCraft2Env = self._env + + # perform actions + action = tensordict.get("action") # this is a list of actions for each agent + action_np = self._action_transform(action) + + # Actions are validated by the environment. + reward, done, info = env.step(action_np) + + # collect outputs + # state_dict = self.read_state(state) + obs_dict = self.read_obs(env.get_obs()) + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) + done = self._to_tensor(done, dtype=torch.bool) + + # build results + tensordict_out = TensorDict( + source=obs_dict, + batch_size=tensordict.batch_size, + device=self.device, + ) + tensordict_out.set("reward", reward) + tensordict_out.set("done", done) + # TODO: support state. + # tensordict_out["state"] = state_dict + + # Update available actions mask. + self.input_spec = self._make_input_spec(env) + + return tensordict_out + + def get_seed(self) -> Optional[int]: + return self._env.seed() + class SC2Env(SC2Wrapper): - """TODO: comments""" + """SMAC (StarCraft Multi-Agent Challenge) environment wrapper. + + Examples: + >>> env = SC2Env(map_name="8m", seed=42) + >>> td = env.rand_step() + >>> print(td) + TensorDict( + fields={ + action: Tensor(torch.Size([8, 14]), dtype=torch.int64), + done: Tensor(torch.Size([1]), dtype=torch.bool), + next: TensorDict( + fields={ + observation: Tensor(torch.Size([8, 80]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False), + reward: Tensor(torch.Size([1]), dtype=torch.float32)}, + batch_size=torch.Size([]), + device=cpu, + is_shared=False) + >>> print(env.available_envs) + ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] + """ def __init__(self, map_name: str, seed: Optional[int] = None, **kwargs): kwargs["map_name"] = map_name if seed is not None: - kwargs["seed"] = map_name - + kwargs["seed"] = seed super().__init__(**kwargs) + def _check_kwargs(self, kwargs: Dict): + if "map_name" not in kwargs: + raise TypeError("Expected 'map_name' to be part of kwargs") + def _build_env( self, map_name: str, seed: Optional[int] = None, **kwargs, - ) -> "smac.env.StarCraft2Env": + ) -> smac.env.StarCraft2Env: if not _has_smac: raise RuntimeError( f"smac not found, unable to create smac.env.StarCraft2Env. " f"Consider installing smac. More info:" f" {self.git_url}. (Original error message during import: {IMPORT_ERR})." ) - # TODO: check if those are required - # from_pixels = kwargs.pop("from_pixels", False) - # pixels_only = kwargs.pop("pixels_only", True) - # TODO: check if this is required - # self.wrapper_frame_skip = 1 - env = smac.env.StarCraft2Env(map_name, seed, **kwargs) + self.wrapper_frame_skip = 1 + env = smac.env.StarCraft2Env(map_name, seed=seed, **kwargs) - # TODO: return super()._build_env(env, pixels_only=pixels_only, from_pixels=from_pixels) return super()._build_env(env) From 8c2b5f2e38d9cc822104940fe667593c1e5c4871 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Tue, 10 Jan 2023 20:08:55 +0100 Subject: [PATCH 03/10] Added CI pipeline --- .circleci/config.yml | 49 +++ .../linux_libs/scripts_smac/environment.yml | 18 + .../linux_libs/scripts_smac/install.sh | 48 +++ .../linux_libs/scripts_smac/post_process.sh | 6 + .../scripts_smac/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_smac/run_test.sh | 30 ++ .../linux_libs/scripts_smac/setup_env.sh | 59 +++ 7 files changed, 566 insertions(+) create mode 100644 .circleci/unittest/linux_libs/scripts_smac/environment.yml create mode 100755 .circleci/unittest/linux_libs/scripts_smac/install.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smac/post_process.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smac/run-clang-format.py create mode 100755 .circleci/unittest/linux_libs/scripts_smac/run_test.sh create mode 100755 .circleci/unittest/linux_libs/scripts_smac/setup_env.sh diff --git a/.circleci/config.yml b/.circleci/config.yml index cc4bea384a9..bfc603a5a1c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -564,6 +564,51 @@ jobs: - store_test_results: path: test-results + unittest_linux_smac_gpu: + <<: *binary_common + machine: + image: ubuntu-2004-cuda-11.4:202110-01 + resource_class: gpu.nvidia.medium + environment: + image_name: "pytorch/manylinux-cuda113" + TAR_OPTIONS: --no-same-owner + PYTHON_VERSION: << parameters.python_version >> + CU_VERSION: << parameters.cu_version >> + + steps: + - checkout + - designate_upload_channel + - run: + name: Generate cache key + # This will refresh cache on Sundays, nightly build should generate new cache. + command: echo "$(date +"%Y-%U")" > .circleci-weekly + - restore_cache: + keys: + - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_smac/environment.yml" }}-{{ checksum ".circleci-weekly" }} + - run: + name: Setup + command: .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh + - save_cache: + key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_smac/environment.yml" }}-{{ checksum ".circleci-weekly" }} + paths: + - conda + - env + - run: + name: Install torchrl + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD -e UPLOAD_CHANNEL -e CU_VERSION "${image_name}" .circleci/unittest/linux_libs/scripts_smac/install.sh + - run: + name: Run tests + command: bash .circleci/unittest/linux_libs/scripts_smac/run_test.sh + - run: + name: Codecov upload + command: | + bash <(curl -s https://codecov.io/bash) -Z -F linux-smac + - run: + name: Post Process + command: docker run -t --gpus all -v $PWD:$PWD -w $PWD "${image_name}" .circleci/unittest/linux_libs/scripts_smac/post_process.sh + - store_test_results: + path: test-results + unittest_linux_vmas_gpu: <<: *binary_common machine: @@ -1112,6 +1157,10 @@ workflows: cu_version: cu113 name: unittest_linux_gym_gpu_py3.8 python_version: '3.8' + - unittest_linux_smac_gpu: + cu_version: cu113 + name: unittest_linux_smac_gpu_py3.8 + python_version: '3.8' - unittest_macos_cpu: diff --git a/.circleci/unittest/linux_libs/scripts_smac/environment.yml b/.circleci/unittest/linux_libs/scripts_smac/environment.yml new file mode 100644 index 00000000000..3612eb4264a --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/environment.yml @@ -0,0 +1,18 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - expecttest + - pyyaml + - scipy + - hydra-core + - smac diff --git a/.circleci/unittest/linux_libs/scripts_smac/install.sh b/.circleci/unittest/linux_libs/scripts_smac/install.sh new file mode 100755 index 00000000000..91671e8d985 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/install.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with %s\n" "${CU_VERSION}" +if [ "${CU_VERSION:-}" == cpu ] ; then + # conda install -y pytorch torchvision cpuonly -c pytorch-nightly + # use pip to install pytorch as conda can frequently pick older release +# conda install -y pytorch cpuonly -c pytorch-nightly + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall +else + pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu116 --force-reinstall +fi + +# install tensordict +pip install git+https://github.com/pytorch-labs/tensordict.git + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +pip3 install -e . + +# smoke test +python -c "import torchrl" diff --git a/.circleci/unittest/linux_libs/scripts_smac/post_process.sh b/.circleci/unittest/linux_libs/scripts_smac/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.circleci/unittest/linux_libs/scripts_smac/run-clang-format.py b/.circleci/unittest/linux_libs/scripts_smac/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. + dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. + # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: {DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.circleci/unittest/linux_libs/scripts_smac/run_test.sh b/.circleci/unittest/linux_libs/scripts_smac/run_test.sh new file mode 100755 index 00000000000..4f47aab13bd --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/run_test.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" +export SC2PATH="${root_dir}/smac/StarCraftII" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import smac" + +coverage run -m pytest test/test_libs.py --instafail -v --durations 20 --capture no -k TestSmac +coverage xml -i diff --git a/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh new file mode 100755 index 00000000000..234ca2d894f --- /dev/null +++ b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 3. Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune + +# 4. Install StarCraft 2 with SMAC maps +# SC2PATH is set in run_test.sh +printf "* Installing StarCraft 2 and SMAC maps into '${root_dir}/smac/StarCraftII'\n" +mkdir $root_dir/smac +cd $root_dir/smac/ +# TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. +wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip +# The archive contains StarCraftII folder. Password comes from the documentation. +unzip -qo -P iagreetotheeula SC2.4.10.zip +wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip +tar -xf SMAC_Maps.zip --directory ./StarCraftII/Maps From f68355b094f3bc621474485e44744d124fc255c3 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Tue, 10 Jan 2023 20:25:56 +0100 Subject: [PATCH 04/10] amend --- .../linux_libs/scripts_smac/setup_env.sh | 1 + test/test_libs.py | 2 +- torchrl/data/__init__.py | 3 -- torchrl/data/tensor_specs.py | 1 + torchrl/envs/common.py | 19 ++++++- torchrl/envs/libs/smac.py | 51 ++++++++++++------- 6 files changed, 52 insertions(+), 25 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh index 234ca2d894f..ee2f4dcfda4 100755 --- a/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh @@ -52,6 +52,7 @@ printf "* Installing StarCraft 2 and SMAC maps into '${root_dir}/smac/StarCraftI mkdir $root_dir/smac cd $root_dir/smac/ # TODO: discuss how we can cache it to avoid downloading ~4 GB on each run. +# e.g adding this into the image (learn which one is used and how it is maintained) wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip # The archive contains StarCraftII folder. Password comes from the documentation. unzip -qo -P iagreetotheeula SC2.4.10.zip diff --git a/test/test_libs.py b/test/test_libs.py index 179269d2690..7ae84b2bc63 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -687,7 +687,7 @@ def test_smac_seeding(self, map_name): for _ in range(2): torch.manual_seed(seed) np.random.seed(seed) - env = SC2Env(map_name, seed) + env = SC2Env(map_name, seed=seed) final_seed.append(env.get_seed()) tdreset.append(env.reset()) tdrollout.append(env.rollout(max_steps=50)) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 50a90fdfc7e..31e18673bbd 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -24,10 +24,7 @@ DiscreteTensorSpec, MultiDiscreteTensorSpec, MultOneHotDiscreteTensorSpec, - NdBoundedTensorSpec, NdOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, - NdUnboundedDiscreteTensorSpec, OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ba4c312a6a2..f1cc9471c1e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -520,6 +520,7 @@ def to_categorical(self) -> DiscreteTensorSpec: return DiscreteTensorSpec(self.space.n, device=self.device, dtype=self.dtype) +# TODO: ask Vincent if this should replace OneHotDiscreteTensorSpec as in https://github.com/pytorch/rl/issues/771 @dataclass(repr=False) class NdOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): """An N-dimensional One hot discrete tensor spec data class.""" diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index aec32eff395..0dc613d5c2a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -206,6 +206,7 @@ def __init__( dtype: Optional[Union[torch.dtype, np.dtype]] = None, batch_size: Optional[torch.Size] = None, run_type_checks: bool = True, + multi_agent_environment: bool = False ): super().__init__() if device is not None: @@ -228,6 +229,7 @@ def __init__( ): self.batch_size = torch.Size([]) self._run_type_checks = run_type_checks + self._multi_agent_environment = multi_agent_environment @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): @@ -524,7 +526,12 @@ def rand_step(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBa tensordict = TensorDict( {}, device=self.device, batch_size=self.batch_size, _run_checks=False ) - action = self.action_spec.rand(self.batch_size) + if self._multi_agent_environment: + # TODO: change in accordance with https://github.com/pytorch/rl/issues/766 resolution + # In case of multi-agent environment, action spec includes batch_size. + action = self.action_spec.rand(torch.Size([])) + else: + action = self.action_spec.rand(self.batch_size) tensordict.set("action", action) return self.step(tensordict) @@ -594,7 +601,13 @@ def rollout( if policy is None: def policy(td): - return td.set("action", self.action_spec.rand(self.batch_size)) + if self._multi_agent_environment: + # TODO: change in accordance with https://github.com/pytorch/rl/issues/766 resolution + # In case of multi-agent environment, action spec includes batch_size. + action = self.action_spec.rand(torch.Size([])) + else: + action = self.action_spec.rand(self.batch_size) + return td.set("action", action) tensordicts = [] for i in range(max_steps): @@ -744,12 +757,14 @@ def __init__( dtype: Optional[np.dtype] = None, device: DEVICE_TYPING = "cpu", batch_size: Optional[torch.Size] = None, + multi_agent_environment: bool = False, **kwargs, ): super().__init__( device=device, dtype=dtype, batch_size=batch_size, + multi_agent_environment=multi_agent_environment ) if len(args): raise ValueError( diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 407b0267ec9..1d7f3d0dfaa 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -6,7 +6,6 @@ from torchrl.data import ( CompositeSpec, CustomNdOneHotDiscreteTensorSpec, - NdUnboundedContinuousTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, ) @@ -15,7 +14,7 @@ try: import smac import smac.env - from smac.env.starcraft2.maps import smac_maps + from smac.env.starcraft2.maps import smac_maps, get_map_params _has_smac = True except ImportError as err: @@ -62,10 +61,19 @@ class SC2Wrapper(GymLikeEnv): available_envs = _get_envs() libname = "smac" - def __init__(self, env: smac.env.StarCraft2Env = None, **kwargs): + def __init__( + self, + env: smac.env.StarCraft2Env = None, + batch_size: Optional[torch.Size] = None, + **kwargs, + ): if env is not None: kwargs["env"] = env - super().__init__(**kwargs) + if batch_size is None: + batch_size = torch.Size([env.n_agents]) + + kwargs["batch_size"] = batch_size + super().__init__(multi_agent_environment=True, **kwargs) def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: @@ -75,40 +83,36 @@ def _check_kwargs(self, kwargs: Dict): raise TypeError("env is not of type 'smac.env.StarCraft2Env'.") def _build_env(self, env, **kwargs) -> smac.env.StarCraft2Env: - # StarCraft2Env must be initialized before _make_specs. - env.reset() return env def _make_specs(self, env: smac.env.StarCraft2Env) -> None: # Extract specs from definition. self.reward_spec = self._make_reward_spec() - # Extract specs from instance. - # To extract these specs environment must be fully initialized with env.reset(). - self.input_spec = self._make_input_spec(env) - self.observation_spec = self._make_observation_spec(env) - # TODO: add support for the state. # self.state_spec = self._make_state_spec(env) # self.input_spec["state"] = self._state_spec # self._state_example = self._make_state_example(env) def _init_env(self) -> None: - pass + self._env.reset() + + # Before extracting environment specific specs, env.reset() must be executed. + self.input_spec = self._make_input_spec(self._env) + self.observation_spec = self._make_observation_spec(self._env) def _make_reward_spec(self) -> TensorSpec: return UnboundedContinuousTensorSpec(device=self.device) def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: + # TODO: change in accordance with https://github.com/pytorch/rl/issues/766 resolution action_spec = CustomNdOneHotDiscreteTensorSpec( torch.tensor(env.get_avail_actions(), dtype=torch.bool), device=self.device ) return CompositeSpec(action=action_spec) def _make_observation_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: - info = env.get_env_info() - size = torch.Size((info["n_agents"], info["obs_shape"])) - obs_spec = NdUnboundedContinuousTensorSpec(size, device=self.device) + obs_spec = UnboundedContinuousTensorSpec(env.get_obs_size(), device=self.device) return CompositeSpec(observation=obs_spec) def _set_seed(self, seed: Optional[int]): @@ -139,7 +143,8 @@ def _reset( tensordict_out.set("done", done) # TODO: add support for the state. # tensordict_out["state"] = state_dict - # TODO: return available actions? + + self.input_spec = self._make_input_spec(env) return tensordict_out @@ -160,8 +165,9 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # collect outputs # state_dict = self.read_state(state) obs_dict = self.read_obs(env.get_obs()) - reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) - done = self._to_tensor(done, dtype=torch.bool) + # TODO: discuss if this is correct way to handle reward and state in case of MARL with batch_size=n_agents + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand(self.batch_size) + done = self._to_tensor(done, dtype=torch.bool).expand(self.batch_size) # build results tensordict_out = TensorDict( @@ -208,10 +214,17 @@ class SC2Env(SC2Wrapper): ['3m', '8m', '25m', '5m_vs_6m', '8m_vs_9m', ...] """ - def __init__(self, map_name: str, seed: Optional[int] = None, **kwargs): + def __init__(self, map_name: str, batch_size: Optional[torch.Size] = None, seed: Optional[int] = None, **kwargs): kwargs["map_name"] = map_name + + if batch_size is None: + map_info = get_map_params(map_name) + batch_size = torch.Size([map_info["n_agents"]]) + kwargs["batch_size"] = batch_size + if seed is not None: kwargs["seed"] = seed + super().__init__(**kwargs) def _check_kwargs(self, kwargs: Dict): From d168ebc13caf058e5d743c735bbf029743bc5fac Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Fri, 27 Jan 2023 21:05:06 +0100 Subject: [PATCH 05/10] amend --- torchrl/data/tensor_specs.py | 60 ++++++++++++++++++++++++++---------- torchrl/envs/libs/smac.py | 6 ++-- 2 files changed, 48 insertions(+), 18 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 611d75346a2..706ef686751 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1066,6 +1066,8 @@ class MultiOneHotDiscreteTensorSpec(OneHotDiscreteTensorSpec): device (str, int or torch.device, optional): device of the tensors. dtype (str or torch.dtype, optional): dtype of the tensors. + mask (torch.Tensor, optional): mask that defines elements that can be sampled by rand(). + The expected mask shape is the spec shape and mask dtype is torch.bool. Examples: >>> ts = MultiOneHotDiscreteTensorSpec((3,2,3)) @@ -1087,6 +1089,7 @@ def __init__( device=None, dtype=torch.long, use_register=False, + mask: Optional[torch.Tensor] = None ): self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) @@ -1101,6 +1104,23 @@ def __init__( ) space = BoxList([DiscreteBox(n) for n in nvec]) self.use_register = use_register + + if mask is not None: + if mask.shape != shape: + raise ValueError( + f"Expected a mask with the shape of the spec. " + f"Got {mask.shape} but expected {self.shape}." + ) + if mask.dtype is not torch.bool: + raise ValueError( + f"Expected a mask with dtype torch.bool but got {mask.dtype}" + ) + if (mask.sum(-1) == 0).any(): + raise ValueError("Got an empty mask for some dimension.") + if mask.device != self.device: + raise ValueError(f"Expected a mask with the same device {self.device} but got {mask.device}.") + self.mask = mask + super(OneHotDiscreteTensorSpec, self).__init__( shape, space, device, dtype, domain="discrete" ) @@ -1112,11 +1132,14 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: else: dest_dtype = self.dtype dest_device = torch.device(dest) + mask = self.mask.to(dest_device, copy=True) if self.mask else None return self.__class__( nvec=deepcopy(self.nvec), shape=self.shape, device=dest_device, dtype=dest_dtype, + # TODO: is it a bug that use_register is not copied? + mask=mask ) def clone(self) -> CompositeSpec: @@ -1125,6 +1148,7 @@ def clone(self) -> CompositeSpec: shape=self.shape, device=self.device, dtype=self.dtype, + mask=self.mask.clone() ) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: @@ -1133,23 +1157,27 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: else: shape = torch.Size([*shape, *self.shape[:-1]]) - x = torch.cat( - [ - torch.nn.functional.one_hot( - torch.randint( - space.n, - ( - *shape, - 1, + if self.mask is not None: + r = torch.rand(self.mask.shape, device=self.device).masked_fill_(~self.mask, 0.0) + x = (r == r.max(-1, keepdim=True)[0]).to(torch.long) + else: + x = torch.cat( + [ + torch.nn.functional.one_hot( + torch.randint( + space.n, + ( + *shape, + 1, + ), + device=self.device, ), - device=self.device, - ), - space.n, - ).to(torch.long) - for space in self.space - ], - -1, - ).squeeze(-2) + space.n, + ).to(torch.long) + for space in self.space + ], + -1, + ).squeeze(-2) return x def encode(self, val: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 7967acb4e32..731d50ae100 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -115,11 +115,13 @@ def _make_reward_spec(self) -> TensorSpec: ) def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: - # TODO: add mask from env.get_avail_actions() + mask = torch.tensor(env.get_avail_actions(), dtype=torch.bool, device=self.device) + action_spec = MultiOneHotDiscreteTensorSpec( [env.n_actions], shape=torch.Size([env.n_agents, env.n_actions]), device=self.device, + mask=mask ) return CompositeSpec(action=action_spec, shape=self.batch_size) @@ -177,7 +179,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # collect outputs # state_dict = self.read_state(state) obs_dict = self.read_obs(env.get_obs()) - reward = self._to_tensor(reward, dtype=self.reward_spec.dtype) + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand(self.batch_size) done = self._to_tensor(done, dtype=torch.bool).expand(self.batch_size) # build results From 9a95bce394a9129f66dbe264b9f12e7222f7c2d3 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Mon, 30 Jan 2023 20:35:13 +0100 Subject: [PATCH 06/10] amend --- torchrl/data/tensor_specs.py | 14 +++++++---- torchrl/envs/libs/smac.py | 48 ++++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 706ef686751..31ef435492e 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1089,7 +1089,7 @@ def __init__( device=None, dtype=torch.long, use_register=False, - mask: Optional[torch.Tensor] = None + mask: Optional[torch.Tensor] = None, ): self.nvec = nvec dtype, device = _default_dtype_and_device(dtype, device) @@ -1118,7 +1118,9 @@ def __init__( if (mask.sum(-1) == 0).any(): raise ValueError("Got an empty mask for some dimension.") if mask.device != self.device: - raise ValueError(f"Expected a mask with the same device {self.device} but got {mask.device}.") + raise ValueError( + f"Expected a mask with the same device {self.device} but got {mask.device}." + ) self.mask = mask super(OneHotDiscreteTensorSpec, self).__init__( @@ -1139,7 +1141,7 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CompositeSpec: device=dest_device, dtype=dest_dtype, # TODO: is it a bug that use_register is not copied? - mask=mask + mask=mask, ) def clone(self) -> CompositeSpec: @@ -1148,7 +1150,7 @@ def clone(self) -> CompositeSpec: shape=self.shape, device=self.device, dtype=self.dtype, - mask=self.mask.clone() + mask=self.mask.clone(), ) def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: @@ -1158,7 +1160,9 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: shape = torch.Size([*shape, *self.shape[:-1]]) if self.mask is not None: - r = torch.rand(self.mask.shape, device=self.device).masked_fill_(~self.mask, 0.0) + r = torch.rand(self.mask.shape, device=self.device).masked_fill_( + ~self.mask, 0.0 + ) x = (r == r.max(-1, keepdim=True)[0]).to(torch.long) else: x = torch.cat( diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 731d50ae100..473e199d2a0 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -1,5 +1,6 @@ from typing import Dict, Optional +import numpy as np import torch from tensordict.tensordict import TensorDict, TensorDictBase @@ -91,17 +92,13 @@ def _make_specs(self, env: smac.env.StarCraft2Env) -> None: # Specs that require initialized environment are built in _init_env. - # TODO: add support for the state. - # self.state_spec = self._make_state_spec(env) - # self.input_spec["state"] = self._state_spec - # self._state_example = self._make_state_example(env) - def _init_env(self) -> None: self._env.reset() # Before extracting environment specific specs, env.reset() must be executed. self.input_spec = self._make_input_spec(self._env) self.observation_spec = self._make_observation_spec(self._env) + self.state_spec = self._make_state_spec(self._env) def _make_reward_spec(self) -> TensorSpec: return UnboundedContinuousTensorSpec( @@ -115,13 +112,15 @@ def _make_reward_spec(self) -> TensorSpec: ) def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: - mask = torch.tensor(env.get_avail_actions(), dtype=torch.bool, device=self.device) + mask = torch.tensor( + env.get_avail_actions(), dtype=torch.bool, device=self.device + ) action_spec = MultiOneHotDiscreteTensorSpec( [env.n_actions], shape=torch.Size([env.n_agents, env.n_actions]), device=self.device, - mask=mask + mask=mask, ) return CompositeSpec(action=action_spec, shape=self.batch_size) @@ -131,6 +130,20 @@ def _make_observation_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: ) return CompositeSpec(observation=obs_spec, shape=self.batch_size) + def _make_state_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: + return UnboundedContinuousTensorSpec( + torch.Size([env.n_agents, env.get_state_size()]), device=self.device + ) + + def _action_transform(self, action: torch.Tensor): + action_np = self.action_spec.to_numpy(action) + return action_np + + def _read_state(self, state: np.ndarray) -> torch.Tensor: + return self.state_spec.encode( + torch.Tensor(state, device=self.device).expand(*self.state_spec.shape) + ) + def _set_seed(self, seed: Optional[int]): raise NotImplementedError( "Seed cannot be changed once environment was created." @@ -143,9 +156,8 @@ def _reset( obs, state = env.reset() # collect outputs - # TODO: add support for the state. - # state_dict = self.read_state(state) obs_dict = self.read_obs(obs) + state = self._read_state(state) self._is_done = torch.zeros(self.batch_size, dtype=torch.bool) # build results @@ -155,17 +167,12 @@ def _reset( device=self.device, ) tensordict_out.set("done", self._is_done) - # TODO: add support for the state. - # tensordict_out["state"] = state_dict + tensordict_out["state"] = state self.input_spec = self._make_input_spec(env) return tensordict_out - def _action_transform(self, action): - action_np = self.action_spec.to_numpy(action) - return action_np - def _step(self, tensordict: TensorDictBase) -> TensorDictBase: env: smac.env.StarCraft2Env = self._env @@ -177,9 +184,13 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: reward, done, info = env.step(action_np) # collect outputs - # state_dict = self.read_state(state) obs_dict = self.read_obs(env.get_obs()) - reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand(self.batch_size) + # TODO: add centralized flag? + state = self._read_state(env.get_state()) + + reward = self._to_tensor(reward, dtype=self.reward_spec.dtype).expand( + self.batch_size + ) done = self._to_tensor(done, dtype=torch.bool).expand(self.batch_size) # build results @@ -190,8 +201,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: ) tensordict_out.set("reward", reward) tensordict_out.set("done", done) - # TODO: support state. - # tensordict_out["state"] = state_dict + tensordict_out["state"] = state # Update available actions mask. self.input_spec = self._make_input_spec(env) From 69e0208285d002b3b1ea878f4c756ded4d33498b Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Tue, 31 Jan 2023 18:44:03 +0100 Subject: [PATCH 07/10] amend --- test/test_libs.py | 4 ++-- torchrl/envs/libs/smac.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index ab1877adda5..6165efff888 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -51,7 +51,7 @@ IS_OSX = platform == "darwin" -@pytest.mark.skipif(_has_gym, reason="no gym library found") +@pytest.mark.skipif(not _has_gym, reason="no gym library found") @pytest.mark.parametrize( "env_name", [ @@ -156,7 +156,7 @@ def _make_gym_environment(env_name): # noqa: F811 return gym.make(env_name, render_mode="rgb_array") -@pytest.mark.skipif(_has_dmc, reason="no dm_control library found") +@pytest.mark.skipif(not _has_dmc, reason="no dm_control library found") @pytest.mark.parametrize("env_name,task", [["cheetah", "run"]]) @pytest.mark.parametrize("frame_skip", [1, 3]) @pytest.mark.parametrize( diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 473e199d2a0..3a0328a9ea2 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -64,7 +64,7 @@ class SC2Wrapper(GymLikeEnv): def __init__( self, - env: smac.env.StarCraft2Env = None, + env: "smac.env.StarCraft2Env" = None, batch_size: Optional[torch.Size] = None, **kwargs, ): @@ -86,7 +86,7 @@ def _check_kwargs(self, kwargs: Dict): def _build_env(self, env, **kwargs) -> smac.env.StarCraft2Env: return env - def _make_specs(self, env: smac.env.StarCraft2Env) -> None: + def _make_specs(self, env: "smac.env.StarCraft2Env") -> None: # Extract specs from definition. self.reward_spec = self._make_reward_spec() @@ -111,7 +111,7 @@ def _make_reward_spec(self) -> TensorSpec: device=self.device, ) - def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: + def _make_input_spec(self, env: "smac.env.StarCraft2Env") -> TensorSpec: mask = torch.tensor( env.get_avail_actions(), dtype=torch.bool, device=self.device ) @@ -124,13 +124,13 @@ def _make_input_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: ) return CompositeSpec(action=action_spec, shape=self.batch_size) - def _make_observation_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: + def _make_observation_spec(self, env: "smac.env.StarCraft2Env") -> TensorSpec: obs_spec = UnboundedContinuousTensorSpec( torch.Size([env.n_agents, env.get_obs_size()]), device=self.device ) return CompositeSpec(observation=obs_spec, shape=self.batch_size) - def _make_state_spec(self, env: smac.env.StarCraft2Env) -> TensorSpec: + def _make_state_spec(self, env: "smac.env.StarCraft2Env") -> TensorSpec: return UnboundedContinuousTensorSpec( torch.Size([env.n_agents, env.get_state_size()]), device=self.device ) From d629a4fd8b5edf9fbca7f0e602b15f37b23d84b3 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Tue, 31 Jan 2023 19:04:56 +0100 Subject: [PATCH 08/10] amend --- .circleci/unittest/linux_libs/scripts_smac/environment.yml | 2 +- torchrl/envs/libs/smac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smac/environment.yml b/.circleci/unittest/linux_libs/scripts_smac/environment.yml index 3612eb4264a..7a94c25317a 100644 --- a/.circleci/unittest/linux_libs/scripts_smac/environment.yml +++ b/.circleci/unittest/linux_libs/scripts_smac/environment.yml @@ -15,4 +15,4 @@ dependencies: - pyyaml - scipy - hydra-core - - smac + - git+https://github.com/oxwhirl/smac.git diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index 3a0328a9ea2..d1578b06c1e 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -265,7 +265,7 @@ def _build_env( map_name: str, seed: Optional[int] = None, **kwargs, - ) -> smac.env.StarCraft2Env: + ) -> "smac.env.StarCraft2Env": if not _has_smac: raise RuntimeError( f"smac not found, unable to create smac.env.StarCraft2Env. " From 50c0f44ab29496a1396f946cbfc18978a823d190 Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Tue, 31 Jan 2023 20:23:01 +0100 Subject: [PATCH 09/10] amend --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f5970e72133..731e6262067 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -587,7 +587,7 @@ jobs: - env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_smac/environment.yml" }}-{{ checksum ".circleci-weekly" }} - run: name: Setup - command: .circleci/unittest/linux_libs/scripts_jumanji/setup_env.sh + command: .circleci/unittest/linux_libs/scripts_smac/setup_env.sh - save_cache: key: env-v3-linux-{{ arch }}-py<< parameters.python_version >>-{{ checksum ".circleci/unittest/linux_libs/scripts_smac/environment.yml" }}-{{ checksum ".circleci-weekly" }} paths: From 9af08376ad309fa6f83028ee0c594c888509eeaa Mon Sep 17 00:00:00 2001 From: Sergey Ordinskiy Date: Thu, 2 Feb 2023 15:12:19 +0100 Subject: [PATCH 10/10] amend --- .circleci/unittest/linux_libs/scripts_smac/setup_env.sh | 2 +- torchrl/envs/libs/smac.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh index ee2f4dcfda4..e265b61f5f3 100755 --- a/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh +++ b/.circleci/unittest/linux_libs/scripts_smac/setup_env.sh @@ -57,4 +57,4 @@ wget https://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip # The archive contains StarCraftII folder. Password comes from the documentation. unzip -qo -P iagreetotheeula SC2.4.10.zip wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip -tar -xf SMAC_Maps.zip --directory ./StarCraftII/Maps +unzip -qo SMAC_Maps.zip -d ./StarCraftII/Maps diff --git a/torchrl/envs/libs/smac.py b/torchrl/envs/libs/smac.py index d1578b06c1e..ad11d21affb 100644 --- a/torchrl/envs/libs/smac.py +++ b/torchrl/envs/libs/smac.py @@ -83,7 +83,7 @@ def _check_kwargs(self, kwargs: Dict): if not isinstance(env, (smac.env.StarCraft2Env,)): raise TypeError("env is not of type 'smac.env.StarCraft2Env'.") - def _build_env(self, env, **kwargs) -> smac.env.StarCraft2Env: + def _build_env(self, env, **kwargs) -> "smac.env.StarCraft2Env": return env def _make_specs(self, env: "smac.env.StarCraft2Env") -> None: