diff --git a/hive/__init__.py b/hive/__init__.py index d37d7c60..4835a666 100644 --- a/hive/__init__.py +++ b/hive/__init__.py @@ -1,5 +1,11 @@ +import logging import os +logging.basicConfig( + format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s", + level=logging.INFO, +) + from hive import agents, envs, replays, runners, utils from hive.utils.registry import Registrable, registry diff --git a/hive/agents/__init__.py b/hive/agents/__init__.py index bffdf9bc..894ee12a 100644 --- a/hive/agents/__init__.py +++ b/hive/agents/__init__.py @@ -1,3 +1,5 @@ +import logging + from hive.agents import qnets from hive.agents.agent import Agent from hive.agents.dqn import DQNAgent @@ -16,4 +18,6 @@ }, ) +logging.info("Registered agents.") + get_agent = getattr(registry, f"get_{Agent.type_name()}") diff --git a/hive/agents/qnets/__init__.py b/hive/agents/qnets/__init__.py index 232c0d50..401112f5 100644 --- a/hive/agents/qnets/__init__.py +++ b/hive/agents/qnets/__init__.py @@ -1,3 +1,5 @@ +import logging + from hive.utils.registry import registry from hive.agents.qnets.atari import NatureAtariDQNModel from hive.agents.qnets.base import FunctionApproximator @@ -13,4 +15,6 @@ }, ) +logging.info("Registered function approximators.") + get_qnet = getattr(registry, f"get_{FunctionApproximator.type_name()}") diff --git a/hive/agents/qnets/utils.py b/hive/agents/qnets/utils.py index 9a42ab6e..ded75e39 100644 --- a/hive/agents/qnets/utils.py +++ b/hive/agents/qnets/utils.py @@ -1,3 +1,4 @@ +import logging import math import torch @@ -138,4 +139,6 @@ def type_name(cls): }, ) +logging.info("Registered PyTorch initialization functions.") + get_optimizer_fn = getattr(registry, f"get_{InitializationFn.type_name()}") diff --git a/hive/configs/gym/dqn.yml b/hive/configs/gym/dqn.yml index acfc1863..5a66e0ca 100644 --- a/hive/configs/gym/dqn.yml +++ b/hive/configs/gym/dqn.yml @@ -1,6 +1,6 @@ run_name: &run_name 'gym-dqn' train_steps: 50000 -test_frequency: 200 +test_frequency: 5000 test_episodes: 10 max_steps_per_episode: 1000 stack_size: &stack_size 1 diff --git a/hive/envs/__init__.py b/hive/envs/__init__.py index 0273f180..aa062a15 100644 --- a/hive/envs/__init__.py +++ b/hive/envs/__init__.py @@ -1,3 +1,5 @@ +import logging + from hive.envs.base import BaseEnv, ParallelEnv from hive.envs.env_spec import EnvSpec from hive.envs.gym_env import GymEnv @@ -41,4 +43,6 @@ }, ) +logging.info('Registered environments.') + get_env = getattr(registry, f"get_{BaseEnv.type_name()}") diff --git a/hive/replays/__init__.py b/hive/replays/__init__.py index 4302d9c9..579e5bb3 100644 --- a/hive/replays/__init__.py +++ b/hive/replays/__init__.py @@ -1,3 +1,4 @@ +import logging from hive.replays.circular_replay import CircularReplayBuffer, SimpleReplayBuffer from hive.replays.legal_moves_replay import LegalMovesBuffer from hive.replays.prioritized_replay import PrioritizedReplayBuffer @@ -14,4 +15,6 @@ }, ) +logging.info("Registered replays.") + get_replay = getattr(registry, f"get_{BaseReplayBuffer.type_name()}") diff --git a/hive/runners/base.py b/hive/runners/base.py index 43664cc1..977ad84d 100644 --- a/hive/runners/base.py +++ b/hive/runners/base.py @@ -1,8 +1,11 @@ +import logging from abc import ABC +from asyncio.log import logger +import pprint from typing import List + from hive.agents.agent import Agent from hive.envs.base import BaseEnv - from hive.runners.utils import Metrics from hive.utils import schedule from hive.utils.experiment import Experiment @@ -124,6 +127,8 @@ def run_episode(self): def run_training(self): """Run the training loop.""" self.train_mode(True) + logging.info("Starting train loop") + while self._train_schedule.get_value(): # Run training episode if not self._training: @@ -135,13 +140,20 @@ def run_training(self): # Run test episodes if self._run_testing: + logging.info( + f"{self._train_schedule._steps}/" + f"{self._train_schedule._flip_step} training steps completed." + ) + logger.info("Running testing.") test_metrics = self.run_testing() self._logger.update_step("test") self._logger.log_metrics(test_metrics, "test") self._run_testing = False + logging.info(f"Testing results: {pprint.pformat(test_metrics)}") # Save experiment state if self._save_experiment: + logger.info("Saving run.") self._experiment_manager.save() self._save_experiment = False diff --git a/hive/runners/multi_agent_loop.py b/hive/runners/multi_agent_loop.py index 36ff6359..c835fe1b 100644 --- a/hive/runners/multi_agent_loop.py +++ b/hive/runners/multi_agent_loop.py @@ -1,5 +1,7 @@ import argparse import copy +import logging +import pprint from hive import agents as agent_lib from hive import envs @@ -274,6 +276,11 @@ def main(): args.logger_config, ) runner = set_up_experiment(config) + logging.info( + f"Using config: \n{pprint.pformat(runner._experiment_manager._config, compact=True)}" + ) + logging.info("Created runner. Starting run!") + runner.run_training() diff --git a/hive/runners/single_agent_loop.py b/hive/runners/single_agent_loop.py index c8636885..6ae62a89 100644 --- a/hive/runners/single_agent_loop.py +++ b/hive/runners/single_agent_loop.py @@ -1,5 +1,7 @@ import argparse import copy +import logging +import pprint from hive import agents as agent_lib from hive import envs @@ -209,6 +211,11 @@ def main(): args.logger_config, ) runner = set_up_experiment(config) + logging.info( + f"Using config: \n{pprint.pformat(runner._experiment_manager._config, compact=False)}" + ) + logging.info("Created runner. Starting run!") + runner.run_training() diff --git a/hive/utils/loggers.py b/hive/utils/loggers.py index 7e32a1a5..649506f7 100644 --- a/hive/utils/loggers.py +++ b/hive/utils/loggers.py @@ -1,5 +1,6 @@ import abc import copy +import logging import os from typing import List @@ -450,4 +451,6 @@ def load(self, dir_name): }, ) +logging.info("Registered loggers.") + get_logger = getattr(registry, f"get_{Logger.type_name()}") diff --git a/hive/utils/schedule.py b/hive/utils/schedule.py index 9e9765aa..52de168b 100644 --- a/hive/utils/schedule.py +++ b/hive/utils/schedule.py @@ -1,4 +1,5 @@ import abc +import logging from hive.utils.registry import Registrable, registry @@ -200,4 +201,6 @@ def __repr__(self): }, ) +logging.info("Registered schedules.") + get_schedule = getattr(registry, f"get_{Schedule.type_name()}") diff --git a/hive/utils/torch_utils.py b/hive/utils/torch_utils.py index b322a917..7ccd8411 100644 --- a/hive/utils/torch_utils.py +++ b/hive/utils/torch_utils.py @@ -1,3 +1,4 @@ +import logging import numpy as np import torch from torch import optim @@ -193,6 +194,8 @@ def step(self, closure=None): }, ) +logging.info("Registered PyTorch optimizers.") + registry.register_all( LossFn, { @@ -218,5 +221,7 @@ def step(self, closure=None): }, ) +logging.info("Registered PyTorch losses.") + get_optimizer_fn = getattr(registry, f"get_{OptimizerFn.type_name()}") get_loss_fn = getattr(registry, f"get_{LossFn.type_name()}")