diff --git a/VERSION b/VERSION index 8294c18..7693c96 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.2 \ No newline at end of file +0.1.3 \ No newline at end of file diff --git a/deepracer_env/__init__.py b/deepracer_env/__init__.py index 4fd11e0..0a61e7d 100644 --- a/deepracer_env/__init__.py +++ b/deepracer_env/__init__.py @@ -14,7 +14,7 @@ # limitations under the License. # ################################################################################# """DeepRacerEnv modules""" -from .deepracer_env import DeepRacerEnv +from .deepracer_env import DeepRacerEnv, DeepRacerEnvObserverInterface """DeepRacer Environment Config modules""" from deepracer_env_config import TrackDirection diff --git a/deepracer_env/deepracer_env.py b/deepracer_env/deepracer_env.py index 1ace985..13cc2fa 100644 --- a/deepracer_env/deepracer_env.py +++ b/deepracer_env/deepracer_env.py @@ -16,6 +16,7 @@ """A class for DeepRacerEnv environment.""" from typing import Dict, Optional, List, Tuple, Any, FrozenSet, Union import math +from threading import RLock from gym import Space @@ -32,6 +33,43 @@ ) +class DeepRacerEnvObserverInterface(object): + """ + DeepRacerEnv Observer Interface + """ + def on_step(self, env: 'DeepRacerEnv', step_result: UDEStepResult) -> None: + """ + On Step callback. + - Called after step completed. + + Args: + env (DeepRacerEnv): DeepRacer environment. + step_result (UDEStepResult): step result (obs, reward, done, last action, info) + """ + pass + + def on_reset(self, env: 'DeepRacerEnv', reset_result: UDEResetResult) -> None: + """ + On Reset callback. + - Called after reset completed. + + Args: + env (DeepRacerEnv): DeepRacer environment. + reset_result (UDEResetResult): reset result (obs, info) + """ + pass + + def on_close(self, env: 'DeepRacerEnv') -> None: + """ + On Close callback. + - Called after close completed. + + Args: + env (DeepRacerEnv): DeepRacer environment. + """ + pass + + class DeepRacerEnv(UDEEnvironmentInterface): """ DeepRacerEnv Class. @@ -76,6 +114,28 @@ def __init__(self, area_config = self._deepracer_config.get_area() self._track_names = area_config.track_names self._shell_names = area_config.shell_names + self._observer_lock = RLock() + self._observers = set() + + def register(self, observer: DeepRacerEnvObserverInterface) -> None: + """ + Register given observer. + + Args: + observer (DeepRacerEnvObserverInterface): observer + """ + with self._observer_lock: + self._observers.add(observer) + + def unregister(self, observer: DeepRacerEnvObserverInterface) -> None: + """ + Unregister given observer. + + Args: + observer (DeepRacerEnvObserverInterface): observer to discard + """ + with self._observer_lock: + self._observers.discard(observer) def step(self, action_dict: MultiAgentDict) -> UDEStepResult: """ @@ -101,7 +161,13 @@ def step(self, action_dict: MultiAgentDict) -> UDEStepResult: math.isnan(speed) or math.isinf(speed): raise ValueError("Agent's action value cannot contain nan or inf: {{}: {}}".format(agent_id, action)) - return self._env.step(action_dict=action_dict) + step_result = self._env.step(action_dict=action_dict) + + with self._observer_lock: + observers = self._observers.copy() + for observer in observers: + observer.on_step(env=self, step_result=step_result) + return step_result def reset(self) -> UDEResetResult: """ @@ -111,13 +177,23 @@ def reset(self) -> UDEResetResult: Returns: UDEResetResult: first observation and info in new episode. """ - return self._env.reset() + reset_result = self._env.reset() + + with self._observer_lock: + observers = self._observers.copy() + for observer in observers: + observer.on_reset(env=self, reset_result=reset_result) + return reset_result def close(self) -> None: """ Close the environment, and environment will be no longer available to be used. """ - return self._env.close() + self._env.close() + with self._observer_lock: + observers = self._observers.copy() + for observer in observers: + observer.on_close(env=self) @property def observation_space(self) -> Dict[AgentID, Space]: diff --git a/test/deepracer_env/test_deepracer_env.py b/test/deepracer_env/test_deepracer_env.py index e8eecfd..13aaa6c 100644 --- a/test/deepracer_env/test_deepracer_env.py +++ b/test/deepracer_env/test_deepracer_env.py @@ -20,13 +20,26 @@ import math -from deepracer_env import DeepRacerEnv -from ude import Compression - +from deepracer_env import DeepRacerEnv, DeepRacerEnvObserverInterface +from ude import Compression, UDEResetResult, UDEStepResult myself: Callable[[], Any] = lambda: inspect.stack()[1][3] +class DummyObserver(DeepRacerEnvObserverInterface): + def __init__(self): + self.mock = MagicMock() + + def on_step(self, env: 'DeepRacerEnv', step_result: UDEStepResult) -> None: + self.mock.on_step(env=env, step_result=step_result) + + def on_reset(self, env: 'DeepRacerEnv', reset_result: UDEResetResult) -> None: + self.mock.on_reset(env=env, reset_result=reset_result) + + def on_close(self, env: 'DeepRacerEnv') -> None: + self.mock.on_close(env=env) + + @patch("deepracer_env.deepracer_env.Client") @patch("deepracer_env.deepracer_env.UDEEnvironment") @patch("deepracer_env.deepracer_env.RemoteEnvironmentAdapter") @@ -94,6 +107,31 @@ def test_initialize_with_param(self, assert env.track_names == deepracer_config_mock.return_value.get_area.return_value.track_names assert env.shell_names == deepracer_config_mock.return_value.get_area.return_value.shell_names + def test_register(self, + remote_env_adapter_mock, + ude_env_mock, + deepracer_config_mock): + address = "test_ip" + env = DeepRacerEnv(address=address) + observer_mock = DummyObserver() + env.register(observer=observer_mock) + + assert observer_mock in env._observers + + def test_unregister(self, + remote_env_adapter_mock, + ude_env_mock, + deepracer_config_mock): + address = "test_ip" + env = DeepRacerEnv(address=address) + observer_mock = DummyObserver() + env.register(observer=observer_mock) + + assert observer_mock in env._observers + + env.unregister(observer=observer_mock) + assert observer_mock not in env._observers + def test_step(self, remote_env_adapter_mock, ude_env_mock, @@ -136,7 +174,6 @@ def test_step_ignore_more_than_two_value(self, ude_env_mock.return_value.step.assert_called_once_with(action_dict=expected_action_dict) assert step_result == ude_env_mock.return_value.step.return_value - def test_step_nan_or_inf(self, remote_env_adapter_mock, ude_env_mock, @@ -168,6 +205,23 @@ def test_step_nan_or_inf(self, with self.assertRaises(ValueError): _ = env.step(action_dict=action_dict) + def test_step_with_observer(self, + remote_env_adapter_mock, + ude_env_mock, + deepracer_config_mock): + address = "test_ip" + + env = DeepRacerEnv(address=address) + observer_mock = DummyObserver() + env.register(observer=observer_mock) + + action_dict = {"agent1": (1.0, 2.0)} + + step_result = env.step(action_dict=action_dict) + ude_env_mock.return_value.step.assert_called_once_with(action_dict=action_dict) + assert step_result == ude_env_mock.return_value.step.return_value + observer_mock.mock.on_step.assert_called_once_with(env=env, + step_result=step_result) def test_reset(self, remote_env_adapter_mock, @@ -179,6 +233,21 @@ def test_reset(self, ude_env_mock.return_value.reset.assert_called_once() assert reset_result == ude_env_mock.return_value.reset.return_value + def test_reset_with_observer(self, + remote_env_adapter_mock, + ude_env_mock, + deepracer_config_mock): + address = "test_ip" + env = DeepRacerEnv(address=address) + observer_mock = DummyObserver() + env.register(observer=observer_mock) + + reset_result = env.reset() + ude_env_mock.return_value.reset.assert_called_once() + assert reset_result == ude_env_mock.return_value.reset.return_value + observer_mock.mock.on_reset.assert_called_once_with(env=env, + reset_result=reset_result) + def test_close(self, remote_env_adapter_mock, ude_env_mock, @@ -188,6 +257,18 @@ def test_close(self, env.close() ude_env_mock.return_value.close.assert_called_once() + def test_close_with_observer(self, + remote_env_adapter_mock, + ude_env_mock, + deepracer_config_mock): + address = "test_ip" + env = DeepRacerEnv(address=address) + observer_mock = DummyObserver() + env.register(observer=observer_mock) + env.close() + ude_env_mock.return_value.close.assert_called_once() + observer_mock.mock.on_close.assert_called_once_with(env=env) + def test_observation_space(self, remote_env_adapter_mock, ude_env_mock,