Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logging #260

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions hive/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import logging
import os

logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
level=logging.INFO,
)

from hive import agents, envs, replays, runners, utils
from hive.utils.registry import Registrable, registry

Expand Down
4 changes: 4 additions & 0 deletions hive/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from hive.agents import qnets
from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
Expand All @@ -16,4 +18,6 @@
},
)

logging.info("Registered agents.")

get_agent = getattr(registry, f"get_{Agent.type_name()}")
4 changes: 4 additions & 0 deletions hive/agents/qnets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from hive.utils.registry import registry
from hive.agents.qnets.atari import NatureAtariDQNModel
from hive.agents.qnets.base import FunctionApproximator
Expand All @@ -13,4 +15,6 @@
},
)

logging.info("Registered function approximators.")

get_qnet = getattr(registry, f"get_{FunctionApproximator.type_name()}")
3 changes: 3 additions & 0 deletions hive/agents/qnets/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import math

import torch
Expand Down Expand Up @@ -138,4 +139,6 @@ def type_name(cls):
},
)

logging.info("Registered PyTorch initialization functions.")

get_optimizer_fn = getattr(registry, f"get_{InitializationFn.type_name()}")
2 changes: 1 addition & 1 deletion hive/configs/gym/dqn.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
run_name: &run_name 'gym-dqn'
train_steps: 50000
test_frequency: 200
test_frequency: 5000
test_episodes: 10
max_steps_per_episode: 1000
stack_size: &stack_size 1
Expand Down
4 changes: 4 additions & 0 deletions hive/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from hive.envs.base import BaseEnv, ParallelEnv
from hive.envs.env_spec import EnvSpec
from hive.envs.gym_env import GymEnv
Expand Down Expand Up @@ -41,4 +43,6 @@
},
)

logging.info('Registered environments.')

get_env = getattr(registry, f"get_{BaseEnv.type_name()}")
3 changes: 3 additions & 0 deletions hive/replays/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer
from hive.replays.legal_moves_replay import LegalMovesBuffer
from hive.replays.prioritized_replay import PrioritizedReplayBuffer
Expand All @@ -14,4 +15,6 @@
},
)

logging.info("Registered replays.")

get_replay = getattr(registry, f"get_{BaseReplayBuffer.type_name()}")
14 changes: 13 additions & 1 deletion hive/runners/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging
from abc import ABC
from asyncio.log import logger
import pprint
from typing import List

from hive.agents.agent import Agent
from hive.envs.base import BaseEnv

from hive.runners.utils import Metrics
from hive.utils import schedule
from hive.utils.experiment import Experiment
Expand Down Expand Up @@ -124,6 +127,8 @@ def run_episode(self):
def run_training(self):
"""Run the training loop."""
self.train_mode(True)
logging.info("Starting train loop")

while self._train_schedule.get_value():
# Run training episode
if not self._training:
Expand All @@ -135,13 +140,20 @@ def run_training(self):

# Run test episodes
if self._run_testing:
logging.info(
f"{self._train_schedule._steps}/"
f"{self._train_schedule._flip_step} training steps completed."
)
logger.info("Running testing.")
test_metrics = self.run_testing()
self._logger.update_step("test")
self._logger.log_metrics(test_metrics, "test")
self._run_testing = False
logging.info(f"Testing results: {pprint.pformat(test_metrics)}")

# Save experiment state
if self._save_experiment:
logger.info("Saving run.")
self._experiment_manager.save()
self._save_experiment = False

Expand Down
7 changes: 7 additions & 0 deletions hive/runners/multi_agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import copy
import logging
import pprint

from hive import agents as agent_lib
from hive import envs
Expand Down Expand Up @@ -274,6 +276,11 @@ def main():
args.logger_config,
)
runner = set_up_experiment(config)
logging.info(
f"Using config: \n{pprint.pformat(runner._experiment_manager._config, compact=True)}"
)
logging.info("Created runner. Starting run!")

runner.run_training()


Expand Down
7 changes: 7 additions & 0 deletions hive/runners/single_agent_loop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import copy
import logging
import pprint

from hive import agents as agent_lib
from hive import envs
Expand Down Expand Up @@ -209,6 +211,11 @@ def main():
args.logger_config,
)
runner = set_up_experiment(config)
logging.info(
f"Using config: \n{pprint.pformat(runner._experiment_manager._config, compact=False)}"
)
logging.info("Created runner. Starting run!")

runner.run_training()


Expand Down
3 changes: 3 additions & 0 deletions hive/utils/loggers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import copy
import logging
import os
from typing import List

Expand Down Expand Up @@ -450,4 +451,6 @@ def load(self, dir_name):
},
)

logging.info("Registered loggers.")

get_logger = getattr(registry, f"get_{Logger.type_name()}")
3 changes: 3 additions & 0 deletions hive/utils/schedule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging

from hive.utils.registry import Registrable, registry

Expand Down Expand Up @@ -200,4 +201,6 @@ def __repr__(self):
},
)

logging.info("Registered schedules.")

get_schedule = getattr(registry, f"get_{Schedule.type_name()}")
5 changes: 5 additions & 0 deletions hive/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import numpy as np
import torch
from torch import optim
Expand Down Expand Up @@ -193,6 +194,8 @@ def step(self, closure=None):
},
)

logging.info("Registered PyTorch optimizers.")

registry.register_all(
LossFn,
{
Expand All @@ -218,5 +221,7 @@ def step(self, closure=None):
},
)

logging.info("Registered PyTorch losses.")

get_optimizer_fn = getattr(registry, f"get_{OptimizerFn.type_name()}")
get_loss_fn = getattr(registry, f"get_{LossFn.type_name()}")