fix(nyz): fix logger middleware problems (#715)
PaParaZz1 committed Aug 31, 2023
1 parent b06ce44 commit 23fac67
Showing 2 changed files with 88 additions and 49 deletions.
120 changes: 80 additions & 40 deletions ding/framework/middleware/functional/logger.py
@@ -11,7 +11,6 @@
import wandb
import pickle
import treetensor.numpy as tnp
from tensorboardX import SummaryWriter
from ding.framework import task
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
@@ -22,37 +21,27 @@
from ding.framework import OnlineRLContext, OfflineRLContext


def softmax(logit):
v = np.exp(logit)
return v / v.sum(axis=-1, keepdims=True)


def action_prob(num, action_prob, ln):
ax = plt.gca()
ax.set_ylim([0, 1])
for rect, x in zip(ln, action_prob[num]):
rect.set_height(x)
return ln


def return_prob(num, return_prob, ln):
return ln


def return_distribution(episode_return):
num = len(episode_return)
max_return = max(episode_return)
min_return = min(episode_return)
hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
gap = (max_return - min_return + 100) / 5
x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
return hist / num, x_dim


def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
"""
Overview:
Create an online RL tensorboard logger for recording training and evaluation metrics.
Arguments:
- record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
- train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
Returns:
- _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If the key of train_output is not supported, such as "scalars".
Examples:
>>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
writer = DistributedWriter.get_instance()
if writer is None:
raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
last_train_show_iter = -1

def _logger(ctx: "OnlineRLContext"):
@@ -69,7 +58,7 @@ def _logger(ctx: "OnlineRLContext"):
if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
last_train_show_iter = ctx.train_iter
if isinstance(ctx.train_output, List):
output = ctx.train_output.pop() # only use latest output
output = ctx.train_output.pop() # only use latest output for some algorithms, like PPO
else:
output = ctx.train_output
for k, v in output.items():
@@ -93,17 +82,36 @@ def _logger(ctx: "OnlineRLContext"):
return _logger
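
As a usage sketch (not part of this diff): the middleware above requires that `ding_init(cfg)` has already created the shared `DistributedWriter`, otherwise the `RuntimeError` named in the docstring is raised. The `ding_init` import path, the `task.start` arguments, and the `cfg` object are assumptions here, not taken from this commit.

```python
# Hypothetical wiring of online_logger into a training pipeline.
# Assumed: `cfg` is a compiled experiment config and `ding_init` is importable
# from ding.utils; both are illustrative, not taken from this commit.
from ding.framework import task, OnlineRLContext
from ding.framework.middleware.functional.logger import online_logger
from ding.utils import ding_init  # assumed import path

ding_init(cfg)  # creates the DistributedWriter instance the logger looks up
with task.start(ctx=OnlineRLContext()):
    # ... collector / trainer / evaluator middleware would be registered here ...
    task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    task.run()
```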


def offline_logger(exp_name: str = None) -> Callable:
def offline_logger(train_show_freq: int = 100) -> Callable:
"""
Overview:
Create an offline RL tensorboard logger for recording training and evaluation metrics.
Arguments:
- train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100.
Returns:
- _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
Raises:
- RuntimeError: If writer is None.
- NotImplementedError: If the key of train_output is not supported, such as "scalars".
Examples:
>>> task.use(offline_logger(train_show_freq=1000))
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
writer = SummaryWriter(logdir=exp_name)
writer = DistributedWriter.get_instance()
if writer is None:
raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
last_train_show_iter = -1

def _logger(ctx: "OfflineRLContext"):
nonlocal last_train_show_iter
if task.finish:
writer.close()
if not np.isinf(ctx.eval_value):
writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
if ctx.train_output is not None:
if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
last_train_show_iter = ctx.train_iter
output = ctx.train_output
for k, v in output.items():
if k in ['priority']:
@@ -120,6 +128,34 @@ def _logger(ctx: "OfflineRLContext"):
return _logger
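
A quick illustration of the `train_show_freq` throttle used above (the iteration values are made up): with the default of 100, training scalars are written only once the gap from the last logged iteration reaches the threshold.

```python
# Standalone illustration of the throttling logic above (made-up iterations).
last_train_show_iter = -1
train_show_freq = 100
for train_iter in (0, 50, 100, 199, 200, 300):
    if train_iter - last_train_show_iter >= train_show_freq:
        last_train_show_iter = train_iter
        print(f"write scalars at train_iter={train_iter}")
# prints for train_iter = 100, 200 and 300 only
```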


# four utility functions for wandb logger
def softmax(logit: np.ndarray) -> np.ndarray:
v = np.exp(logit)
return v / v.sum(axis=-1, keepdims=True)


def action_prob(num, action_prob, ln):
ax = plt.gca()
ax.set_ylim([0, 1])
for rect, x in zip(ln, action_prob[num]):
rect.set_height(x)
return ln


def return_prob(num, return_prob, ln):
return ln


def return_distribution(episode_return):
num = len(episode_return)
max_return = max(episode_return)
min_return = min(episode_return)
hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
gap = (max_return - min_return + 100) / 5
x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
return hist / num, x_dim
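
A small self-contained check of these helpers (the episode returns and logits below are made up, not taken from this commit):

```python
# Made-up data to exercise the helper functions defined above.
import numpy as np

episode_return = [90.0, 100.0, 110.0, 120.0, 150.0]
hist, x_dim = return_distribution(episode_return)
# All returns fall inside [min - 50, max + 50], so the normalized histogram
# sums to 1, and x_dim holds the five formatted left bin edges used as labels.
assert abs(hist.sum() - 1.0) < 1e-6
assert len(x_dim) == 5

probs = softmax(np.array([1.0, 2.0, 3.0]))
assert abs(probs.sum() - 1.0) < 1e-6  # softmax output is a valid distribution
```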


def wandb_online_logger(
record_path: str = None,
cfg: Union[dict, EasyDict] = None,
@@ -129,7 +165,7 @@ def wandb_online_logger(
anonymous: bool = False,
project_name: str = 'default-project',
) -> Callable:
'''
"""
Overview:
Wandb visualizer to track the experiment.
Arguments:
@@ -143,10 +179,12 @@ def wandb_online_logger(
- metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Policy neural network model.
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not.
The anonymous mode allows visualization of data without wandb count.
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
of data without wandb count.
- project_name (:obj:`str`): The name of wandb project.
'''
Returns:
- _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
@@ -294,7 +332,7 @@ def wandb_offline_logger(
anonymous: bool = False,
project_name: str = 'default-project',
) -> Callable:
'''
"""
Overview:
Wandb visualizer to track the experiment.
Arguments:
@@ -309,10 +347,12 @@ def wandb_offline_logger(
- metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
- env (:obj:`BaseEnvManagerV2`): Evaluator environment.
- model (:obj:`nn.Module`): Policy neural network model.
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not.
The anonymous mode allows visualization of data without wandb count.
- anonymous (:obj:`bool`): Open the anonymous mode of wandb or not. The anonymous mode allows visualization \
of data without wandb count.
- project_name (:obj:`str`): The name of wandb project.
'''
Returns:
- _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
"""
if task.router.is_active and not task.has_role(task.role.LEARNER):
return task.void()
color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
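
For reference, the call pattern the updated tests below use for `wandb_offline_logger`; the dataset path, record path, and config are placeholders, and only the positional order plus the `env=`, `model=`, and `anonymous=` keywords are taken from this commit.

```python
# Placeholder values; only the argument order and the env=/model=/anonymous=
# keywords mirror the updated tests below.
from easydict import EasyDict

cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=False))
record_path = './video_pendulum_cql'
plot = wandb_offline_logger(
    'dataset.h5',     # offline dataset file, passed positionally first
    record_path,      # replay/video output directory (from the test below)
    cfg,
    env=env,          # evaluator environment (BaseEnvManagerV2)
    model=model,      # policy network (nn.Module)
    anonymous=True,   # anonymous wandb mode
)
task.use(plot)        # or invoke plot(ctx) directly, as the tests do
```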
17 changes: 8 additions & 9 deletions ding/framework/middleware/tests/test_logger.py
@@ -208,7 +208,7 @@ def test_wandb_online_logger():
model = TheModelClass()
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(metric_dict, step):
def mock_metric_logger(data, step):
metric_list = [
"q_value",
"target q_value",
@@ -229,7 +229,7 @@ def mock_metric_logger(metric_dict, step):
"actions_of_trajectory_3",
"return distribution",
]
assert set(metric_dict.keys()) <= set(metric_list)
assert set(data.keys()) <= set(metric_list)

def mock_gradient_logger(input_model):
assert input_model == model
@@ -246,9 +246,7 @@ def test_wandb_online_logger_gradient():
test_wandb_online_logger_gradient()


# @pytest.mark.unittest
# TODO(nyz): fix CI bug when py=3.8.15
@pytest.mark.tmp
@pytest.mark.unittest
def test_wandb_offline_logger(mocker):
record_path = './video_pendulum_cql'
cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True))
@@ -258,12 +256,12 @@ def test_wandb_offline_logger(mocker):
model = TheModelClass()
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(metric_dict):
def mock_metric_logger(data, step=None):
metric_list = [
"q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution",
"train iter", 'dataset'
]
assert set(metric_dict.keys()) < set(metric_list)
assert set(data.keys()) < set(metric_list)

def mock_gradient_logger(input_model):
assert input_model == model
@@ -273,8 +271,9 @@ def mock_image_logger(imagepath):

def test_wandb_offline_logger_gradient():
cfg.vis_dataset = False
print(cfg)
with patch.object(wandb, 'watch', new=mock_gradient_logger):
wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx)
wandb_offline_logger('dataset.h5', record_path, cfg, env=env, model=model, anonymous=True)(ctx)

def test_wandb_offline_logger_dataset():
cfg.vis_dataset = True
@@ -283,7 +282,7 @@ def test_wandb_offline_logger_dataset():
with patch.object(wandb, 'log', new=mock_metric_logger):
with patch.object(wandb, 'Image', new=mock_image_logger):
mocker.patch('h5py.File', return_value=m)
wandb_offline_logger(cfg, env, model, 'dataset.h5', anonymous=True)(ctx)
wandb_offline_logger('dataset.h5', record_path, cfg, env=env, model=model, anonymous=True)(ctx)

test_wandb_offline_logger_gradient()
test_wandb_offline_logger_dataset()
