Skip to content

Commit

Permalink
fix(nyz): fix enable_save_figure env interface
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jul 18, 2023
1 parent 41ca06c commit 49f2e20
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
21 changes: 9 additions & 12 deletions ding/envs/env_manager/base_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ class BaseEnvManager(object):
Overview:
Create a BaseEnvManager to manage multiple environments.
Interfaces:
reset, step, seed, close, enable_save_replay, launch, default_config, env_state_done
reset, step, seed, close, enable_save_replay, launch, default_config, env_state_done, reward_shaping, \
enable_save_figure
Properties:
env_num, ready_obs, done, method_name_list
observation_space, action_space, reward_space
env_num, ready_obs, done, method_name_list, observation_space, action_space, reward_space
"""

@classmethod
Expand Down Expand Up @@ -411,7 +411,7 @@ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool =
def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
"""
Overview:
Set each env's replay save path.
Enable all environments to save replay video after each episode terminates.
Arguments:
- replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
Or one path for all environments.
Expand All @@ -420,18 +420,15 @@ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
replay_path = [replay_path] * self.env_num
self._env_replay_path = replay_path

def enable_save_figure(self, env_id: int, figure_path: Union[List[str], str]) -> None:
def enable_save_figure(self, env_id: int, figure_path: str) -> None:
"""
Overview:
Set each env's replay save path.
Enable a specific env to save figure (e.g. environment statistics or episode return curve).
Arguments:
- replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
Or one path for all environments.
- figure_path (:obj:`str`): The file directory path for all environments to save figures.
"""
if isinstance(figure_path, str):
self._env[env_id].enable_save_figure(figure_path)
else:
raise TypeError("invalid figure_path arguments type: {}".format(type(figure_path)))
assert figure_path is not None
self._env[env_id].enable_save_figure(figure_path)

def close(self) -> None:
"""
Expand Down
11 changes: 6 additions & 5 deletions ding/worker/collector/interaction_serial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ class InteractionSerialEvaluator(ISerialEvaluator):
# Evaluate every "eval_freq" training iterations.
eval_freq=1000,
render=dict(
# tensorboard video render is disabled by default
# Tensorboard video render is disabled by default.
render_freq=-1,
mode='train_iter',
)
),
# File path for visualize environment information.
figure_path=None,
)

def __init__(
Expand Down Expand Up @@ -231,9 +233,8 @@ def eval(
continue
if t.done:
# Env reset is done by env_manager automatically.
if 'figure_path' in self._cfg:
if self._cfg.figure_path is not None:
self._env.enable_save_figure(env_id, self._cfg.figure_path)
if 'figure_path' in self._cfg and self._cfg.figure_path is not None:
self._env.enable_save_figure(env_id, self._cfg.figure_path)
self._policy.reset([env_id])
reward = t.info['eval_episode_return']
if 'episode_info' in t.info:
Expand Down

0 comments on commit 49f2e20

Please sign in to comment.