Skip to content

Commit

Permalink
fix conda support and keep API compatibility (#536)
Browse files Browse the repository at this point in the history
* loosen constraints

* fix nni issue (#478)

* fix coverage
  • Loading branch information
Trinkle23897 authored Feb 25, 2022
1 parent 97df511 commit c248b4f
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
pytest test --ignore-glob='*profile.py' --ignore="test/3rd_party" --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ def get_version() -> str:

def get_install_requires() -> str:
return [
"gym>=0.21",
"gym>=0.15.4",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard>=2.5.0",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"pettingzoo>=1.15",
]


Expand All @@ -46,8 +45,10 @@ def get_extras_require() -> str:
"doc8",
"scipy",
"pillow",
"pettingzoo>=1.12",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
"nni>=2.3",
],
"atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"],
Expand Down
126 changes: 126 additions & 0 deletions test/3rd_party/test_nni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# https://github.com/microsoft/nni/blob/master/test/ut/retiarii/test_strategy.py

import random
import threading
import time
from typing import List, Union

import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import torch
import torch.nn.functional as F
from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import (
AbstractExecutionEngine,
AbstractGraphListener,
MetricData,
WorkerInfo,
)
from nni.retiarii.graph import DebugEvaluator, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation


class MockExecutionEngine(AbstractExecutionEngine):
    """In-memory stand-in for nni's execution engine used by the strategy test.

    Every submitted model is "trained" on a background thread: after a short
    random delay it is either marked failed (with probability ``failure_prob``)
    or assigned a random metric in [0, 1) and marked trained.
    """

    def __init__(self, failure_prob=0.):
        # Probability that a submitted model ends up in the Failed state.
        self.failure_prob = failure_prob
        # All models ever submitted, in submission order.
        self.models = []
        # Crude resource counter; decremented on submit, restored on completion.
        self._resource_left = 4

    def _model_complete(self, model: Model):
        # Simulate a training run of up to one second.
        time.sleep(random.uniform(0, 1))
        failed = random.uniform(0, 1) < self.failure_prob
        if failed:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        """Record each model and kick off its fake training thread."""
        for submitted in models:
            self.models.append(submitted)
            self._resource_left -= 1
            threading.Thread(
                target=self._model_complete, args=(submitted, )
            ).start()

    def list_models(self) -> List[Model]:
        """Return every model submitted so far."""
        return self.models

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        """Expose the remaining resource count to the strategy."""
        return self._resource_left

    def budget_exhausted(self) -> bool:
        # Not exercised by the test; abstract method stubbed out.
        pass

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        # Listeners are irrelevant for this mock.
        pass

    def trial_execute_graph(cls) -> MetricData:
        # NOTE(review): parameter is named ``cls`` in the upstream nni test
        # this file mirrors, though it is an ordinary instance method slot.
        pass


def _reset_execution_engine(engine=None):
    """Install ``engine`` as nni's global execution engine; ``None`` clears it."""
    api_module = nni.retiarii.execution.api
    api_module._execution_engine = engine


class Net(nn.Module):
    """Small LeNet-style search space with two mutable ``LayerChoice`` layers.

    ``fc1`` chooses between a biased and an unbiased linear layer; ``fc2``
    gets a third candidate when ``diff_size`` is true, so the two choices
    have different numbers of candidates.
    """

    def __init__(self, hidden_size=32, diff_size=False):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice(
            [
                nn.Linear(4 * 4 * 50, hidden_size, bias=True),
                nn.Linear(4 * 4 * 50, hidden_size, bias=False),
            ],
            label='fc1'
        )
        fc2_candidates = [
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True),
        ]
        if diff_size:
            # Extra candidate makes fc2's choice space differ from fc1's.
            fc2_candidates.append(nn.Linear(hidden_size, 10, bias=False))
        self.fc2 = nn.LayerChoice(fc2_candidates, label='fc2')

    def forward(self, x):
        # Two conv + max-pool stages, then flatten into the mutable head.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


def _get_model_and_mutators(**kwargs):
    """Build a ``Net``, convert it to a retiarii graph IR, and collect mutators.

    ``kwargs`` are forwarded to ``Net`` (e.g. ``diff_size``). Returns the
    graph-IR model (with a ``DebugEvaluator`` attached) and its inline mutators.
    """
    net = Net(**kwargs)
    scripted = torch.jit.script(net)
    model_ir = convert_to_graph(scripted, net)
    model_ir.evaluator = DebugEvaluator()
    return model_ir, process_inline_mutation(model_ir)


def test_rl():
    """Smoke-test the PolicyBasedRL strategy against the mock engine.

    Runs the strategy twice — once on a search space whose two LayerChoice
    layers have different candidate counts (``diff_size=True``) and once on
    the symmetric space — with a 20% simulated trial-failure rate, then waits
    for all submitted models and clears the global engine.
    """
    # The original repeated this block verbatim; loop over the two configs.
    for model_kwargs in ({"diff_size": True}, {}):
        rl = strategy.PolicyBasedRL(max_collect=2, trial_per_collect=10)
        engine = MockExecutionEngine(failure_prob=0.2)
        _reset_execution_engine(engine)
        rl.run(*_get_model_and_mutators(**model_kwargs))
        wait_models(*engine.models)
        # Always detach the mock so later tests see a clean global state.
        _reset_execution_engine()


# Allow running this file directly, outside of pytest.
if __name__ == '__main__':
    test_rl()
2 changes: 1 addition & 1 deletion tianshou/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tianshou import data, env, exploration, policy, trainer, utils

__version__ = "0.4.6"
__version__ = "0.4.6.post1"

__all__ = [
"env",
Expand Down
6 changes: 5 additions & 1 deletion tianshou/env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Env package."""

from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.env.venvs import (
BaseVectorEnv,
DummyVectorEnv,
Expand All @@ -9,6 +8,11 @@
SubprocVectorEnv,
)

try:
from tianshou.env.pettingzoo_env import PettingZooEnv
except ImportError:
pass

__all__ = [
"BaseVectorEnv",
"DummyVectorEnv",
Expand Down
6 changes: 1 addition & 5 deletions tianshou/env/venvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import gym
import numpy as np
import pettingzoo

from tianshou.env.worker import (
DummyEnvWorker,
Expand Down Expand Up @@ -365,10 +364,7 @@ class DummyVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""

def __init__(
self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
**kwargs: Any
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)


Expand Down
25 changes: 22 additions & 3 deletions tianshou/env/worker/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional, Tuple, Union

Expand All @@ -11,8 +12,10 @@ class EnvWorker(ABC):
def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
self._env_fn = env_fn
self.is_closed = False
self.result: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
np.ndarray]
self.action_space = self.get_env_attr("action_space") # noqa: B009
self.is_reset = False

@abstractmethod
def get_env_attr(self, key: str) -> Any:
Expand All @@ -22,15 +25,24 @@ def get_env_attr(self, key: str) -> Any:
def set_env_attr(self, key: str, value: Any) -> None:
pass

@abstractmethod
def send(self, action: Optional[np.ndarray]) -> None:
"""Send action signal to low-level worker.
When action is None, it indicates sending "reset" signal; otherwise
it indicates "step" signal. The paired return value from "recv"
function is determined by such kind of different signal.
"""
pass
if hasattr(self, "send_action"):
warnings.warn(
"send_action will soon be deprecated. "
"Please use send and recv for your own EnvWorker."
)
if action is None:
self.is_reset = True
self.result = self.reset()
else:
self.is_reset = False
self.send_action(action) # type: ignore

def recv(
self
Expand All @@ -41,6 +53,13 @@ def recv(
single observation; otherwise it returns a tuple of (obs, rew, done,
info).
"""
if hasattr(self, "get_result"):
warnings.warn(
"get_result will soon be deprecated. "
"Please use send and recv for your own EnvWorker."
)
if not self.is_reset:
self.result = self.get_result() # type: ignore
return self.result

def reset(self) -> np.ndarray:
Expand Down
6 changes: 5 additions & 1 deletion tianshou/policy/multiagent/mapolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
import numpy as np

from tianshou.data import Batch, ReplayBuffer
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy

try:
from tianshou.env.pettingzoo_env import PettingZooEnv
except ImportError:
PettingZooEnv = None # type: ignore


class MultiAgentPolicyManager(BasePolicy):
"""Multi-agent policy manager for MARL.
Expand Down

0 comments on commit c248b4f

Please sign in to comment.