From 4e60e033ca9aabe996b3755d8a0be2cec9e787c8 Mon Sep 17 00:00:00 2001
From: takuseno
Date: Sun, 31 Jan 2021 13:34:29 +0900
Subject: [PATCH] Save statistics at Monitor wrapper

Track the number of steps and the accumulated reward of the current
episode in Monitor and write them to stats<episode>.json under
self._directory when step() reports done. Both counters are cleared
in reset().

The reward is converted with float() before it is accumulated so that
the return stays serializable by the json module even when the wrapped
environment hands back a NumPy scalar.
---
 d3rlpy/envs/wrappers.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/d3rlpy/envs/wrappers.py b/d3rlpy/envs/wrappers.py
index bcb61546..983109d2 100644
--- a/d3rlpy/envs/wrappers.py
+++ b/d3rlpy/envs/wrappers.py
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 
@@ -280,6 +281,8 @@ class Monitor(gym.Wrapper): # type: ignore
     _video_callable: Callable[[int], bool]
     _framerate: float
     _episode: int
+    _episode_return: float
+    _episode_step: int
     _buffer: np.ndarray
 
     def __init__(
@@ -305,6 +308,8 @@ def __init__(
         self._framerate = framerate
 
         self._episode = 0
+        self._episode_return = 0.0
+        self._episode_step = 0
         self._buffer = []
 
     def step(
@@ -316,13 +321,18 @@ def step(
             # store rendering
             frame = cv2.cvtColor(super().render("rgb_array"), cv2.COLOR_BGR2RGB)
             self._buffer.append(frame)
+            self._episode_step += 1
+            self._episode_return += float(reward)  # keep JSON-serializable
             if done:
                 self._save_video()
+                self._save_stats()
 
         return obs, reward, done, info
 
     def reset(self, **kwargs: Any) -> np.ndarray:
         self._episode += 1
+        self._episode_return = 0.0
+        self._episode_step = 0
         self._buffer = []
         return super().reset(**kwargs)
 
@@ -335,3 +345,13 @@ def _save_video(self) -> None:
         for frame in self._buffer:
             writer.write(frame)
         writer.release()
+
+    def _save_stats(self) -> None:
+        """Write the finished episode's step count and return as JSON."""
+        path = os.path.join(self._directory, f"stats{self._episode}.json")
+        stats = {
+            "episode_step": self._episode_step,
+            "return": self._episode_return,
+        }
+        with open(path, "w") as f:
+            json.dump(stats, f, indent=2)