Skip to content

Commit

Permalink
Save statistics at Monitor wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jan 31, 2021
1 parent 016f075 commit 4e60e03
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions d3rlpy/envs/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
from typing import Any, Callable, Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -280,6 +281,8 @@ class Monitor(gym.Wrapper): # type: ignore
_video_callable: Callable[[int], bool]
_framerate: float
_episode: int
_episode_return: float
_episode_step: int
_buffer: np.ndarray

def __init__(
Expand All @@ -305,6 +308,8 @@ def __init__(
self._framerate = framerate

self._episode = 0
self._episode_return = 0.0
self._episode_step = 0
self._buffer = []

def step(
Expand All @@ -316,13 +321,18 @@ def step(
# store rendering
frame = cv2.cvtColor(super().render("rgb_array"), cv2.COLOR_BGR2RGB)
self._buffer.append(frame)
self._episode_step += 1
self._episode_return += reward
if done:
self._save_video()
self._save_stats()

return obs, reward, done, info

def reset(self, **kwargs: Any) -> np.ndarray:
self._episode += 1
self._episode_return = 0.0
self._episode_step = 0
self._buffer = []
return super().reset(**kwargs)

Expand All @@ -335,3 +345,13 @@ def _save_video(self) -> None:
for frame in self._buffer:
writer.write(frame)
writer.release()

def _save_stats(self) -> None:
path = os.path.join(self._directory, f"stats{self._episode}.json")
stats = {
"episode_step": self._episode_step,
"return": self._episode_return,
}
with open(path, "w") as f:
json_str = json.dumps(stats, indent=2)
f.write(json_str)

0 comments on commit 4e60e03

Please sign in to comment.