diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 4bb88a855..cd44457cd 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -27,7 +27,7 @@ jobs: - name: Test with pytest # ignore test/throughput which only profiles the code run: | - pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v + pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: diff --git a/setup.py b/setup.py index 73ce9ea94..f462f5572 100644 --- a/setup.py +++ b/setup.py @@ -15,14 +15,13 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.21", + "gym>=0.15.4", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", "torch>=1.4.0", "numba>=0.51.0", "h5py>=2.10.0", # to match tensorflow's minimal requirements - "pettingzoo>=1.15", ] @@ -46,8 +45,10 @@ def get_extras_require() -> str: "doc8", "scipy", "pillow", + "pettingzoo>=1.12", "pygame>=2.1.0", # pettingzoo test cases pistonball "pymunk>=6.2.1", # pettingzoo test cases pistonball + "nni>=2.3", ], "atari": ["atari_py", "opencv-python"], "mujoco": ["mujoco_py"], diff --git a/test/3rd_party/test_nni.py b/test/3rd_party/test_nni.py new file mode 100644 index 000000000..23d714b17 --- /dev/null +++ b/test/3rd_party/test_nni.py @@ -0,0 +1,126 @@ +# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py + +import random +import threading +import time +from typing import List, Union + +import nni.retiarii.execution.api +import nni.retiarii.nn.pytorch as nn +import nni.retiarii.strategy as strategy +import torch +import torch.nn.functional as F +from nni.retiarii import Model +from nni.retiarii.converter import convert_to_graph +from nni.retiarii.execution import wait_models +from nni.retiarii.execution.interface import ( + AbstractExecutionEngine, + AbstractGraphListener, + MetricData, + WorkerInfo, +) +from nni.retiarii.graph import DebugEvaluator, ModelStatus +from nni.retiarii.nn.pytorch.mutator import process_inline_mutation + + +class MockExecutionEngine(AbstractExecutionEngine): + + def __init__(self, failure_prob=0.): + self.models = [] + self.failure_prob = failure_prob + self._resource_left = 4 + + def _model_complete(self, model: Model): + time.sleep(random.uniform(0, 1)) + if random.uniform(0, 1) < self.failure_prob: + model.status = ModelStatus.Failed + else: + model.metric = random.uniform(0, 1) + model.status = ModelStatus.Trained + self._resource_left += 1 + + def submit_models(self, *models: Model) -> None: + for model in models: + self.models.append(model) + self._resource_left -= 1 + threading.Thread(target=self._model_complete, args=(model, )).start() + + def list_models(self) -> List[Model]: + return self.models + + def query_available_resource(self) -> Union[List[WorkerInfo], int]: + return self._resource_left + + def budget_exhausted(self) -> bool: + pass + + def register_graph_listener(self, listener: AbstractGraphListener) -> None: + pass + + def trial_execute_graph(cls) -> MetricData: + pass + + +def _reset_execution_engine(engine=None): + nni.retiarii.execution.api._execution_engine = engine + + +class Net(nn.Module): + + def __init__(self, hidden_size=32, diff_size=False): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 20, 5, 1) + self.conv2 = nn.Conv2d(20, 50, 5, 1) + self.fc1 = nn.LayerChoice( + [ + nn.Linear(4 * 4 * 50, hidden_size, bias=True), + nn.Linear(4 * 4 * 50, hidden_size, bias=False) + ], + label='fc1' + ) + self.fc2 = nn.LayerChoice( + [ + nn.Linear(hidden_size, 10, bias=False), + nn.Linear(hidden_size, 10, bias=True) + ] + ([] if not diff_size else [nn.Linear(hidden_size, 10, bias=False)]), + label='fc2' + ) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.max_pool2d(x, 2, 2) + x = F.relu(self.conv2(x)) + x = F.max_pool2d(x, 2, 2) + x = x.view(-1, 4 * 4 * 50) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +def _get_model_and_mutators(**kwargs): + base_model = Net(**kwargs) + script_module = torch.jit.script(base_model) + base_model_ir = convert_to_graph(script_module, base_model) + base_model_ir.evaluator = DebugEvaluator() + mutators = process_inline_mutation(base_model_ir) + return base_model_ir, mutators + + +def test_rl(): + rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10) + engine = MockExecutionEngine(failure_prob=0.2) + _reset_execution_engine(engine) + rl.run(*_get_model_and_mutators(diff_size=True)) + wait_models(*engine.models) + _reset_execution_engine() + + rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10) + engine = MockExecutionEngine(failure_prob=0.2) + _reset_execution_engine(engine) + rl.run(*_get_model_and_mutators()) + wait_models(*engine.models) + _reset_execution_engine() + + +if __name__ == '__main__': + test_rl() diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 55b8a0ea9..cd337d188 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,6 +1,6 @@ from tianshou import data, env, exploration, policy, trainer, utils -__version__ = "0.4.6" +__version__ = "0.4.6.post1" __all__ = [ "env", diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index f32e0cff0..3845f2691 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,5 @@ """Env package.""" -from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.env.venvs import ( BaseVectorEnv, DummyVectorEnv, @@ -9,6 +8,11 @@ SubprocVectorEnv, ) +try: + from tianshou.env.pettingzoo_env import PettingZooEnv +except ImportError: + pass + __all__ = [ "BaseVectorEnv", "DummyVectorEnv", diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 02b50b183..c668109b6 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -2,7 +2,6 @@ import gym import numpy as np -import pettingzoo from tianshou.env.worker import ( DummyEnvWorker, @@ -365,10 +364,7 @@ class DummyVectorEnv(BaseVectorEnv): Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__( - self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]], - **kwargs: Any - ) -> None: + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: super().__init__(env_fns, DummyEnvWorker, **kwargs) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index f23f8969a..3ae1b1618 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,3 +1,4 @@ +import warnings from abc import ABC, abstractmethod from typing import Any, Callable, List, Optional, Tuple, Union @@ -11,8 +12,10 @@ class EnvWorker(ABC): def __init__(self, env_fn: Callable[[], gym.Env]) -> None: self._env_fn = env_fn self.is_closed = False - self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] + self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + np.ndarray] self.action_space = self.get_env_attr("action_space") # noqa: B009 + self.is_reset = False @abstractmethod def get_env_attr(self, key: str) -> Any: @@ -22,7 +25,6 @@ def get_env_attr(self, key: str) -> Any: def set_env_attr(self, key: str, value: Any) -> None: pass - @abstractmethod def send(self, action: Optional[np.ndarray]) -> None: """Send action signal to low-level worker. @@ -30,7 +32,17 @@ def send(self, action: Optional[np.ndarray]) -> None: it indicates "step" signal. The paired return value from "recv" function is determined by such kind of different signal. """ - pass + if hasattr(self, "send_action"): + warnings.warn( + "send_action will soon be deprecated. " + "Please use send and recv for your own EnvWorker." + ) + if action is None: + self.is_reset = True + self.result = self.reset() + else: + self.is_reset = False + self.send_action(action) # type: ignore def recv( self @@ -41,6 +53,13 @@ def recv( single observation; otherwise it returns a tuple of (obs, rew, done, info). """ + if hasattr(self, "get_result"): + warnings.warn( + "get_result will soon be deprecated. " + "Please use send and recv for your own EnvWorker." + ) + if not self.is_reset: + self.result = self.get_result() # type: ignore return self.result def reset(self) -> np.ndarray: diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 6c69a8d36..bcbba9795 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -3,9 +3,13 @@ import numpy as np from tianshou.data import Batch, ReplayBuffer -from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy +try: + from tianshou.env.pettingzoo_env import PettingZooEnv +except ImportError: + PettingZooEnv = None # type: ignore + class MultiAgentPolicyManager(BasePolicy): """Multi-agent policy manager for MARL.