
Commit

changed all instances of self.artifact_manager.get_logger().{} to logging.{}
wangpatrick57 committed Sep 6, 2024
1 parent 0d6b37c commit e2ca4e7
Showing 15 changed files with 143 additions and 150 deletions.
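
The change itself is mechanical: every message that previously went through the artifact manager's wrapped logger is now sent through the standard logging module directly. A minimal before/after sketch of the pattern (the seed value is only a placeholder for illustration):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    seed = 42  # placeholder value, stands in for the HPO seed

    # Before this commit:
    #   self.artifact_manager.get_logger().info(f"Seed: {seed}")
    # After this commit:
    logging.info(f"Seed: {seed}")
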
13 changes: 7 additions & 6 deletions tune/protox/agent/hpo.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 import random
 import shutil
@@ -577,8 +578,8 @@ def setup(self, hpo_params: dict[str, Any]) -> None:
             hpo_params=hpo_params,
             ray_trial_id=self.ray_trial_id,
         )
-        self.artifact_manager.get_logger().info("%s", hpo_params)
-        self.artifact_manager.get_logger().info(f"Seed: {seed}")
+        logging.info("%s", hpo_params)
+        logging.info(f"Seed: {seed}")

         # Attach the timeout checker and loggers.
         self.agent.set_timeout_checker(self.timeout_checker)
@@ -595,7 +596,7 @@ def step(self) -> dict[Any, Any]:

         episode = self.agent._episode_num
         it = self.agent.num_timesteps
-        self.artifact_manager.get_logger().info(
+        logging.info(
             f"Starting episode: {episode+1}, iteration: {it+1}"
         )

@@ -605,9 +606,9 @@ def step(self) -> dict[Any, Any]:
                 infos["baseline_reward"],
                 infos["baseline_metric"],
             )
-            self.artifact_manager.get_logger().info(
-                f"Baseline Metric: {baseline_metric}. Baseline Reward: {baseline_reward}"
-            )
+            metric_reward_message = f"Baseline Metric: {baseline_metric}. Baseline Reward: {baseline_reward}"
+            logging.info(metric_reward_message)
+            self.artifact_manager.log_to_replay_info(metric_reward_message)
             self.env_init = True

             assert (
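
Note the new pattern in hpo.py: the baseline message is composed once and written to two destinations, the general log via logging.info and the replay log via the artifact manager's new log_to_replay_info helper. A small sketch of that pattern, assuming artifact_manager is an already-constructed ArtifactManager:

    import logging

    def report_baseline(artifact_manager, baseline_metric: float, baseline_reward: float) -> None:
        # Compose the message once so both destinations stay in sync.
        message = f"Baseline Metric: {baseline_metric}. Baseline Reward: {baseline_reward}"
        logging.info(message)                         # general-purpose log
        artifact_manager.log_to_replay_info(message)  # replay-specific log
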
15 changes: 7 additions & 8 deletions tune/protox/agent/wolp/policies.py
@@ -1,3 +1,4 @@
+import logging
 import time
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

@@ -71,9 +72,8 @@ def __init__(
         self.gamma = gamma

         # Log all the networks.
-        if self.artifact_manager:
-            self.artifact_manager.get_logger(__name__).info("Actor: %s", self.actor)
-            self.artifact_manager.get_logger(__name__).info("Critic: %s", self.critic)
+        logging.info("Actor: %s", self.actor)
+        logging.info("Critic: %s", self.critic)

    def forward(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
        raise NotImplementedError()
@@ -171,8 +171,8 @@ def wolp_act(
             # Insert a dimension.
             noise = noise.view(-1, *noise.shape)

-        if noise is not None and self.artifact_manager is not None:
-            self.artifact_manager.get_logger(__name__).debug(
+        if noise is not None:
+            logging.debug(
                 f"Perturbing with noise class {action_noise}"
             )

@@ -186,9 +186,8 @@ def wolp_act(
             raw_action, neighbor_parameters
         )

-        if self.artifact_manager is not None:
-            # Log the neighborhood we are observing.
-            self.artifact_manager.get_logger(__name__).debug(f"Neighborhood Sizes {actions_dim}")
+        # Log the neighborhood we are observing.
+        logging.debug(f"Neighborhood Sizes {actions_dim}")

         if random_act:
             # If we want a random action, don't use Q-value estimate.
8 changes: 4 additions & 4 deletions tune/protox/agent/wolp/wolp.py
@@ -1,4 +1,5 @@
 from copy import deepcopy
+import logging
 from typing import Any, Dict, Optional, Tuple

 import numpy as np
@@ -173,10 +174,9 @@ def train(self, env: AgentEnv, gradient_steps: int, batch_size: int) -> None:

         actor_losses, critic_losses = [], []
         for gs in range(gradient_steps):
-            if self.artifact_manager:
-                self.artifact_manager.get_logger(__name__).debug(
-                    f"Training agent gradient step {gs}"
-                )
+            logging.debug(
+                f"Training agent gradient step {gs}"
+            )
             self._n_updates += 1
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size)
25 changes: 18 additions & 7 deletions tune/protox/env/artifact_manager.py
@@ -61,6 +61,11 @@ class ArtifactManager(object):
     Importantly, this class should *not* be used for general-purpose logging. You should directly
     use the logging library to do that.
     """
+    # The output log is the file that the root logger writes to
+    OUTPUT_LOG_FNAME = "output.log"
+    REPLAY_INFO_LOG_FNAME = "replay_info.log"
+    REPLAY_LOGGER_NAME = "replay_logger"
+
     def __init__(
         self,
         dbgym_cfg: DBGymConfig,
@@ -72,15 +77,18 @@ def __init__(
         self.tuning_steps_dpath = self.log_dpath / "tuning_steps"
         self.tuning_steps_dpath.mkdir(parents=True, exist_ok=True)

-        level = logging.DEBUG
+        # Setup the root and replay loggers
         formatter = "%(levelname)s:%(asctime)s [%(filename)s:%(lineno)s] %(message)s"
-        logging.basicConfig(format=formatter, level=level, force=True)
+        logging.basicConfig(format=formatter, level=logging.DEBUG, force=True)
+        output_log_handler = logging.FileHandler(self.log_dpath / ArtifactManager.OUTPUT_LOG_FNAME)
+        output_log_handler.setFormatter(logging.Formatter(formatter))
+        output_log_handler.setLevel(logging.DEBUG)
+        logging.getLogger().addHandler(output_log_handler)

-        # Setup the file artifact_manager.
-        file_handler = logging.FileHandler(self.tuning_steps_dpath / "output.log")
-        file_handler.setFormatter(logging.Formatter(formatter))
-        file_handler.setLevel(level)
-        logging.getLogger().addHandler(file_handler)
+        replay_info_log_handler = logging.FileHandler(self.tuning_steps_dpath / ArtifactManager.REPLAY_INFO_LOG_FNAME)
+        replay_info_log_handler.setFormatter(logging.Formatter(formatter))
+        replay_info_log_handler.setLevel(logging.INFO)
+        logging.getLogger(ArtifactManager.REPLAY_LOGGER_NAME).addHandler(replay_info_log_handler)

         # Setup the writer.
         self.writer: Union[SummaryWriter, None] = None
@@ -93,6 +101,9 @@ def __init__(

     def get_logger(self, name: Optional[str]=None) -> logging.Logger:
         return logging.getLogger(name)
+
+    def log_to_replay_info(self, message: str) -> None:
+        logging.getLogger(ArtifactManager.REPLAY_LOGGER_NAME).info(message)

     def stash_results(
         self,
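
The net effect of the artifact_manager.py changes is a split between two logs: the root logger gains a file handler for output.log and captures everything, while a dedicated named logger records replay-relevant messages in replay_info.log. A minimal sketch of the intended wiring, using placeholder paths and assuming the replay handler is attached to the named logger:

    import logging

    formatter = logging.Formatter("%(levelname)s:%(asctime)s [%(filename)s:%(lineno)s] %(message)s")

    # Root logger: everything it sees is written to output.log.
    output_handler = logging.FileHandler("output.log")  # placeholder path
    output_handler.setFormatter(formatter)
    output_handler.setLevel(logging.DEBUG)
    logging.getLogger().addHandler(output_handler)

    # Named replay logger: replay-relevant messages also land in replay_info.log.
    replay_handler = logging.FileHandler("replay_info.log")  # placeholder path
    replay_handler.setFormatter(formatter)
    replay_handler.setLevel(logging.INFO)
    logging.getLogger("replay_logger").addHandler(replay_handler)

    # The named logger propagates to the root logger by default, so a replay
    # message shows up in both replay_info.log and output.log.
    logging.getLogger("replay_logger").info("Baseline Metric: 1000.0")
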
19 changes: 9 additions & 10 deletions tune/protox/env/lsc/lsc.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Optional, TypeVar, cast

 import numpy as np
@@ -52,12 +53,11 @@ def __init__(
         self.shift_after = lsc_parameters["shift_after"]
         self.artifact_manager = artifact_manager

-        if self.artifact_manager:
-            self.artifact_manager.get_logger(__name__).info("LSC Shift: %s", self.lsc_shift)
-            self.artifact_manager.get_logger(__name__).info(
-                "LSC Shift Increment: %s", self.increment
-            )
-            self.artifact_manager.get_logger(__name__).info("LSC Shift Max: %s", self.max)
+        logging.info("LSC Shift: %s", self.lsc_shift)
+        logging.info(
+            "LSC Shift Increment: %s", self.increment
+        )
+        logging.info("LSC Shift Max: %s", self.max)

    def apply_bias(self, action: ProtoAction) -> ProtoAction:
        if not self.enabled:
@@ -130,7 +130,6 @@ def reset(self) -> None:
             # Increment the current bias with the increment.
             self.lsc_shift[:bound] += self.increment[:bound]
             self.lsc_shift = self.lsc_shift % self.max
-            if self.artifact_manager:
-                self.artifact_manager.get_logger(__name__).info(
-                    "LSC Bias Update: %s", self.lsc_shift
-                )
+            logging.info(
+                "LSC Bias Update: %s", self.lsc_shift
+            )
15 changes: 7 additions & 8 deletions tune/protox/env/lsc/lsc_wrapper.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Optional, Tuple

 import gymnasium as gym
@@ -19,9 +20,8 @@ def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, dict[str, Any]]:
             self.lsc.reset()

         state["lsc"] = self.lsc.current_scale()
-        if self.artifact_manager:
-            lsc = state["lsc"]
-            self.artifact_manager.get_logger(__name__).debug(f"Attaching LSC: {lsc}")
+        lsc = state["lsc"]
+        logging.debug(f"Attaching LSC: {lsc}")

         return state, info

@@ -40,10 +40,9 @@ def step(
         state["lsc"] = self.lsc.current_scale()
         new_bias = self.lsc.current_bias()

-        if self.artifact_manager:
-            lsc = state["lsc"]
-            self.artifact_manager.get_logger(__name__).debug(
-                f"Shifting LSC: {old_lsc} ({old_bias}) -> {lsc} ({new_bias})"
-            )
+        lsc = state["lsc"]
+        logging.debug(
+            f"Shifting LSC: {old_lsc} ({old_bias}) -> {lsc} ({new_bias})"
+        )

         return state, float(reward), term, trunc, info
30 changes: 13 additions & 17 deletions tune/protox/env/mqo/mqo_wrapper.py
@@ -1,5 +1,6 @@
 import copy
 import json
+import logging
 from typing import Any, Optional, Tuple, Union

 import gymnasium as gym
@@ -109,10 +110,9 @@ def _regress_query_knobs(
                 value = 1.0 if "Index" in ams[qid_prefix][alias] else 0.0
             else:
                 # Log out the missing alias for debugging reference.
-                if artifact_manager:
-                    artifact_manager.get_logger(__name__).debug(
-                        f"Found missing {alias} in the parsed {ams}."
-                    )
+                logging.debug(
+                    f"Found missing {alias} in the parsed {ams}."
+                )
                 value = 0.0
             global_qknobs[knob] = value
         elif knob.knob_type == SettingType.BOOLEAN:
@@ -175,11 +175,10 @@ def _update_best_observed(
                 None,
                 None,
             )
-            if self.artifact_manager:
-                assert best_run.runtime is not None
-                self.artifact_manager.get_logger(__name__).debug(
-                    f"[best_observe] {qid}: {best_run.runtime/1e6} (force: {force_overwrite})"
-                )
+            assert best_run.runtime is not None
+            logging.debug(
+                f"[best_observe] {qid}: {best_run.runtime/1e6} (force: {force_overwrite})"
+            )
         elif not best_run.timed_out:
             qobs = self.best_observed[qid]
             assert qobs.runtime and best_run.runtime
@@ -191,10 +190,9 @@ def _update_best_observed(
                 None,
                 None,
             )
-            if self.artifact_manager:
-                self.artifact_manager.get_logger(__name__).debug(
-                    f"[best_observe] {qid}: {best_run.runtime/1e6}"
-                )
+            logging.debug(
+                f"[best_observe] {qid}: {best_run.runtime/1e6}"
+            )

    def step( # type: ignore
        self,
@@ -308,8 +306,7 @@ def transmute(
         )

         # Execute.
-        assert self.artifact_manager is not None
-        self.artifact_manager.get_logger(__name__).info("MQOWrapper called step_execute()")
+        logging.info("MQOWrapper called step_execute()")
         success, info = self.unwrapped.step_execute(success, runs, info)
         if info["query_metric_data"]:
             self._update_best_observed(info["query_metric_data"])
@@ -424,7 +421,6 @@ def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, EnvInfoDict]: # type: ignore
             prev_result=metric,
         )

-        if self.artifact_manager:
-            self.artifact_manager.get_logger(__name__).debug("Maximized on reset.")
+        logging.debug("Maximized on reset.")

         return state, info