Add multi-process logger utility for status monitoring (#254)
* Add logger utility for detailed status monitoring

* Fix `flake8` errors

* Clean up info messages

* Leave PPO rollout progress bar

* Update logging doc and add tip

* Clarify `README.md` logging docs

* Adopt Hugging Face logging API

* Run pre-commit

* Remove redundant verbosity setters

* Toggle rollout bar positioning for suppressed verbosity levels
jon-tow authored Feb 3, 2023
1 parent b70bc92 commit 6427429
Showing 6 changed files with 466 additions and 28 deletions.
35 changes: 35 additions & 0 deletions README.md
@@ -73,6 +73,41 @@ For more usage see the [NeMo README](./trlx/trainer/nemo)
python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py
```

## Logging

trlX uses the standard Python `logging` library to log training information to the console. The default logger is set to the `INFO` level, which means that `INFO`, `WARNING`, `ERROR`, and `CRITICAL` level messages will be printed to standard output.

To change the log level directly, you can use the verbosity setter. For example, to set the log level to `WARNING`, use:

```python
import trlx

trlx.logging.set_verbosity(trlx.logging.WARNING)
```

This will suppress `INFO` level messages, but still print `WARNING`, `ERROR`, and `CRITICAL` level messages.

You can also control logging verbosity by setting the `TRLX_VERBOSITY` environment variable to one of the standard logging [level names](https://docs.python.org/3/library/logging.html#logging-levels):

* `CRITICAL` (`trlx.logging.CRITICAL`)
* `ERROR` (`trlx.logging.ERROR`)
* `WARNING` (`trlx.logging.WARNING`)
* `INFO` (`trlx.logging.INFO`)
* `DEBUG` (`trlx.logging.DEBUG`)

```sh
export TRLX_VERBOSITY=WARNING
```

By default, [`tqdm`](https://tqdm.github.io/docs/tqdm/) progress bars are used to display training progress. You can disable them by calling `trlx.logging.disable_progress_bar()` and re-enable them with `trlx.logging.enable_progress_bar()`.
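
For instance, a minimal sketch of turning the bars off around a run and restoring them afterwards:

```python
import trlx

# Hide tqdm progress bars, e.g. for non-interactive or log-file-only jobs
trlx.logging.disable_progress_bar()

# ... run training ...

# Bring the bars back
trlx.logging.enable_progress_bar()
```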

Messages can be formatted with greater detail by calling `trlx.logging.enable_explicit_format()`. This injects call-site information into each log message, which can be helpful for debugging.
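
For example, the following minimal sketch produces messages in the format shown below:

```python
import trlx

trlx.logging.set_verbosity(trlx.logging.INFO)
# Adds timestamp, level, file:line:function, and rank information to each message
trlx.logging.enable_explicit_format()
```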

```sh
[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message...
```

> 💡 Tip: To reduce the amount of logging output, you might find it helpful to change log levels of third-party libraries used by trlX. For example, try adding `transformers.logging.set_verbosity_error()` to the top of your trlX scripts to silence verbose messages from the `transformers` library (see their [logging docs](https://huggingface.co/docs/transformers/main_classes/logging#logging) for more details).
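
A minimal illustration of this tip at the top of a training script:

```python
import transformers

import trlx

# Silence verbose third-party messages while keeping trlX's own INFO logs
transformers.logging.set_verbosity_error()
trlx.logging.set_verbosity(trlx.logging.INFO)
```
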
## Contributing

1 change: 1 addition & 0 deletions trlx/__init__.py
@@ -1 +1,2 @@
from .trlx import train
from .utils import logging
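
With this re-export, `trlx.logging` is importable straight from the package, and the modules below obtain their loggers through it. A minimal sketch of the pattern (assuming the Hugging Face-style logging API referenced in the commit message):

```python
import trlx

# Package-wide verbosity control, same API as documented in the README above
trlx.logging.set_verbosity(trlx.logging.INFO)

# Modules inside trlX obtain namespaced loggers like this
logger = trlx.logging.get_logger(__name__)
```
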
36 changes: 23 additions & 13 deletions trlx/orchestrator/offline_orchestrator.py
@@ -1,11 +1,16 @@
import os
from typing import List, Union

import numpy as np
import torch
from rich.console import Console
from rich.table import Table

import trlx.utils.logging as logging
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline.offline_pipeline import ILQLRolloutStorage
from trlx.utils import print_rank_0

logger = logging.get_logger(__name__)


def tokenize_dialogue( # noqa: C901
@@ -60,6 +65,8 @@ def make_experience(self, samples, rewards, max_length=2048):
"""
Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer
"""
logger.info("Collecting rollouts")

if self.trainer.tokenizer:
samples = [tokenize_dialogue(s, self.trainer.tokenizer, max_length) for s in samples]

@@ -84,26 +91,29 @@ def make_experience(self, samples, rewards, max_length=2048):
all_actions_ixs.append(torch.hstack(actions_ixs))
all_states_ixs.append(states_ixs)

if self.trainer.tokenizer:
if self.trainer.tokenizer and os.environ.get("RANK", "0") == "0":
logger.info("Logging sample example")
prompt = self.trainer.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]])
response = self.trainer.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :])
print_rank_0("[Sample example]")
print_rank_0("Prompt: ", prompt)
print_rank_0("Response: ", response)
print_rank_0("Reward: ", rewards[0])
columns = ["Prompt", "Response", "Reward"]
table = Table(*columns, title="Sample Example", show_lines=True)
table.add_row(prompt, response, str(rewards[0]))
Console().print(table)

sample_lengths = np.array(list(map(len, all_input_ids)))
output_lengths = np.array(list(map(len, all_actions_ixs)))
prompt_lengths = sample_lengths - output_lengths
returns = torch.tensor(rewards, dtype=float)

def string_stats(name: str, xs: np.array):
return f"[Mean {name}] {xs.mean():.2f} ∈ [{min(xs)}, {max(xs)}]"

print_rank_0(string_stats("prompt length", prompt_lengths))
print_rank_0(string_stats("output length", output_lengths))
print_rank_0(string_stats("sample length", sample_lengths))
print_rank_0(string_stats("return", returns))
if os.environ.get("RANK", "0") == "0":
logger.info("Logging experience string statistics")
columns = ["Prompt Length", "Output Length", "Sample Length"]
table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True)
row = []
for lengths in [prompt_lengths, output_lengths, sample_lengths]:
row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]")
table.add_row(*row)
Console().print(table)

returns = (returns - returns.mean()) / (returns.std() + 1e-30)
rewards = [torch.zeros(len(x)) for x in all_actions_ixs]
23 changes: 22 additions & 1 deletion trlx/orchestrator/ppo_orchestrator.py
@@ -1,9 +1,11 @@
import os
from time import time

import ray
import torch
import torch.nn.functional as F

import trlx.utils.logging as logging
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.ppo_types import PPORLElement
from trlx.orchestrator import Orchestrator, register_orchestrator
@@ -12,6 +14,8 @@
from trlx.utils import Clock
from trlx.utils.modeling import RunningMoments, logprobs_from_logits

logger = logging.get_logger(__name__)


@register_orchestrator
class PPOOrchestrator(Orchestrator):
@@ -55,9 +59,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
Takes `num_rollouts` prompts from `pipeline`, samples model and computes the
KL against a reference model. It then appends PPOElements to trainer's `store`
"""
logger.info("Collecting rollouts")
tbar = logging.tqdm(
total=num_rollouts,
disable=os.environ.get("RANK", 0) != "0",
desc=f"[rollout 0 / {num_rollouts}]",
# Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress
# bars (e.g. loss progress in trainers)
position=logging.get_verbosity() >= logging.WARNING,
# Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels
leave=logging.get_verbosity() < logging.WARNING,
)

ppo_rl_elements = []
stats = {}
clock = Clock()

while len(ppo_rl_elements) < num_rollouts:
# Get next batch in prompt dataset and refresh if exhausted
try:
@@ -198,6 +215,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs)
rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)]

rollout_count = 0
for ix in range(n):
if len(rewards[ix]) == 0 or len(all_logprobs[ix]) == 0:
continue
@@ -213,8 +231,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
rewards=rewards[ix],
)
)

rollout_count += 1
exp_time = clock.tick()
tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]")
tbar.update(min(rollout_count, num_rollouts))
tbar.close()

stats["kl_ctl_value"] = self.trainer.kl_ctl.value
stats["time/exp"] = exp_time
59 changes: 45 additions & 14 deletions trlx/trainer/accelerate_base_trainer.py
@@ -13,9 +13,9 @@
from ray.air.checkpoint import Checkpoint
from rich.console import Console
from rich.table import Table
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx.utils.logging as logging
from trlx.data.configs import TRLConfig
from trlx.trainer import BaseRLTrainer, register_trainer
from trlx.utils import (
@@ -24,7 +24,6 @@
get_git_tag,
get_optimizer_class,
get_scheduler_class,
print_rank_0,
significant,
)
from trlx.utils.modeling import (
@@ -35,6 +34,8 @@
parse_delta_kwargs,
)

logger = logging.get_logger(__name__)


@register_trainer
class AccelerateRLTrainer(BaseRLTrainer):
@@ -116,6 +117,8 @@ def setup_model(self):
"""
Returns a model derived from an instance's TRLConfig
"""
logger.info(f"Initializing model: {self.config.model.model_path}")

# Retrieves model equipped for ppo, ilql, etc
model = self.get_arch(self.config)
if self.config.model.model_arch_type == "seq2seq":
@@ -279,16 +282,30 @@ def add_eval_pipeline(self, eval_pipeline):

def evaluate(self): # noqa: C901
"""Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
stats = {}
table = []
logger.info("Evaluating model")

# Do multiple evaluations over a single list in `gen_kwargs` if present
if self.generate_sweep_kwarg is not None:
gen_sweep_arg, gen_sweep_values = self.generate_sweep_kwarg
else:
gen_sweep_values = [None]

for gen_sweep_value in gen_sweep_values:
desc = [
f"generation sweep 0/{len(gen_sweep_values)}",
f"eval batch 0/{len(self.eval_dataloader)}",
]
tbar = logging.tqdm(
total=len(self.eval_dataloader) * len(gen_sweep_values),
desc=f"[{' | '.join(desc)}]",
disable=not self.accelerator.is_main_process,
position=0,
leave=True,
)

stats = {}
table = []

for i_sweep, gen_sweep_value in enumerate(gen_sweep_values):
# A dedicated suffix for wandb logging
if gen_sweep_value is not None:
sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}"
@@ -299,7 +316,7 @@ def evaluate(self): # noqa: C901
all_prompts = []
prompt_sizes = []
generate_time = time()
for prompts in self.eval_dataloader:
for i_prompt, prompts in enumerate(self.eval_dataloader):
if self.generate_sweep_kwarg:
samples = self.generate_eval(**prompts, **{gen_sweep_arg: gen_sweep_value})
else:
@@ -326,6 +343,14 @@ def evaluate(self): # noqa: C901
torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(len(prompts.input_ids))
)

desc = [
f"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}",
f"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}",
]
tbar.set_description(f"[{' | '.join(desc)}]")
tbar.update()
tbar.close()

stats["time/generate"] = time() - generate_time

samples = self.accelerator.gather(torch.vstack(all_samples))
@@ -340,6 +365,7 @@ def evaluate(self): # noqa: C901

# in online setting, compute the reward for validation
if self.reward_fn:
logger.info("Computing rewards")
rewards = torch.tensor(
self.reward_fn(
samples=str_samples,
@@ -357,6 +383,7 @@ def evaluate(self): # noqa: C901

# additionally log any other metrics
if self.metric_fn:
logger.info("Computing metrics")
metric_time = time()
metrics = self.metric_fn(
samples=str_samples,
@@ -385,6 +412,7 @@ def evaluate(self): # noqa: C901
table.append(list(zip(*columns_data)))

# Log and display evaluation metrics
logger.info("Summarizing evaluation")
if self.accelerator.is_main_process:
rows = sum(list(map(list, zip(*table))), [])

@@ -395,30 +423,30 @@ def evaluate(self): # noqa: C901
table_title += f" {k}: {significant(x)}"

rich_table = Table(*columns, title=table_title, show_lines=True)

for ix in range(max(min(3, len(rows)), len(gen_sweep_values))):
rich_table.add_row(*[str(significant(x)) for x in rows[ix]])
Console().print(rich_table)

if not ray.is_initialized():
if self.config.train.tracker == "wandb":
import wandb

stats["samples"] = wandb.Table(columns, rows)

Console().print(rich_table)

self.nth_evaluation += 1
return stats

def learn(self): # noqa: C901
"""
Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
"""
logger.info("Starting training")

self.generate_sweep_kwarg = None
for k, v in self.config.method.gen_kwargs.items():
if isinstance(v, list):
if self.generate_sweep_kwarg is not None:
print_rank_0("Only a single sweep is allowed, {k} is going to be set to {v[0]}")
logger.info("Only a single sweep is allowed, {k} is going to be set to {v[0]}")
self.generate_kwargs[k] = v[0]
else:
self.generate_sweep_kwarg = (k, v)
@@ -440,10 +468,12 @@ def learn(self): # noqa: C901
results = self.evaluate()
self.accelerator.log(results, step=self.iter_count)

tbar = tqdm(
tbar = logging.tqdm(
initial=self.iter_count,
total=self.total_steps,
disable=not self.accelerator.is_local_main_process,
position=0,
leave=True,
)

best_reward = -float("inf")
@@ -491,7 +521,7 @@ def learn(self): # noqa: C901
torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX)
if do_save:
best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint"
print_rank_0(f"saving the best state so far into {best_path}")
logger.info(f"Saving the best state so far into {best_path}")
self.save(best_path)

# Report the metrics to Ray Tune.
@@ -505,8 +535,8 @@ def learn(self): # noqa: C901
if not ray.is_initialized():
self.accelerator.log(stats, step=self.iter_count)

desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
tbar.set_description(desc)
desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss"))
tbar.set_description(f"[{desc}]")
tbar.update()

if self.iter_count >= self.total_steps:
@@ -516,6 +546,7 @@ def learn(self): # noqa: C901
self.post_backward_callback()

self.post_epoch_callback()
tbar.close()

@abstractmethod
def get_arch(self, config: TRLConfig):