Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Feb 1, 2025
1 parent 4277cef commit feeeefc
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions benchmarl/experiment/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import os
import warnings
from collections.abc import MutableMapping, Sequence
from pathlib import Path

from typing import Dict, List, Optional
Expand All @@ -17,6 +18,8 @@

from tensordict import TensorDictBase
from torch import Tensor

from torchrl.record import TensorboardLogger
from torchrl.record.loggers import get_logger
from torchrl.record.loggers.wandb import WandbLogger

Expand Down Expand Up @@ -73,17 +76,41 @@ def __init__(
)

def log_hparams(self, **kwargs):
kwargs.update(
{
"algorithm_name": self.algorithm_name,
"model_name": self.model_name,
"task_name": self.task_name,
"environment_name": self.environment_name,
"seed": self.seed,
}
)
for logger in self.loggers:
kwargs.update(
{
"algorithm_name": self.algorithm_name,
"model_name": self.model_name,
"task_name": self.task_name,
"environment_name": self.environment_name,
"seed": self.seed,
}
)
logger.log_hparams(kwargs)
if isinstance(logger, TensorboardLogger):
# Tensorboard does not like nested dictionaries -> flatten them
def flatten(dictionary, parent_key="", separator="_"):
items = []
for key, value in dictionary.items():
new_key = parent_key + separator + key if parent_key else key
if isinstance(value, MutableMapping):
items.extend(
flatten(value, new_key, separator=separator).items()
)
elif isinstance(value, Sequence):
for i, v in enumerate(value):
items.append((new_key + separator + str(i), v))
else:
items.append((new_key, value))
return dict(items)

# Convert any non-supported values
for key, value in kwargs.items():
if not isinstance(value, (int, float, str, Tensor)):
kwargs[key] = str(value)

logger.log_hparams(flatten(kwargs))
else:
logger.log_hparams(kwargs)

def log_collection(
self,
Expand Down

0 comments on commit feeeefc

Please sign in to comment.