From 6427429f6953944a265e5d7ed2f731d0714b07fb Mon Sep 17 00:00:00 2001 From: Jonathan Tow <41410219+jon-tow@users.noreply.github.com> Date: Fri, 3 Feb 2023 17:59:34 -0500 Subject: [PATCH] Add multi-process logger utility for status monitoring (#254) * Add logger utility for detailed status monitoring * Fix `flake8` errors * Clean up info messages * Leave PPO rollout progress bar * Update logging doc and add tip * Clarify `README.md` logging docs * Adopt Hugging Face logging API * Run pre-commit * Remove redundant verbosity setters * Toggle rollout bar positioning for suppressed verbosity levels --- README.md | 35 +++ trlx/__init__.py | 1 + trlx/orchestrator/offline_orchestrator.py | 36 ++- trlx/orchestrator/ppo_orchestrator.py | 23 +- trlx/trainer/accelerate_base_trainer.py | 59 +++- trlx/utils/logging.py | 340 ++++++++++++++++++++++ 6 files changed, 466 insertions(+), 28 deletions(-) create mode 100644 trlx/utils/logging.py diff --git a/README.md b/README.md index 5325c83fe..27b8ffa7d 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,41 @@ For more usage see the [NeMo README](./trlx/trainer/nemo) python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py ``` +## Logging + +trlX uses the standard Python `logging` library to log training information to the console. The default log level is `INFO`, which means that `INFO`, `WARNING`, `ERROR`, and `CRITICAL` level messages will be written to standard error. + +To change the log level directly, you can use the verbosity setter. For example, to set the log level to `WARNING`, use: + +```python +import trlx + +trlx.logging.set_verbosity(trlx.logging.WARNING) +``` + +This will suppress `INFO` level messages, but still print `WARNING`, `ERROR`, and `CRITICAL` level messages. + +You can also control logging verbosity by setting the `TRLX_VERBOSITY` environment variable to one of the standard logging [level names](https://docs.python.org/3/library/logging.html#logging-levels): + +* `CRITICAL` (`trlx.logging.CRITICAL`) +* `ERROR` (`trlx.logging.ERROR`) +* `WARNING` (`trlx.logging.WARNING`) +* `INFO` (`trlx.logging.INFO`) +* `DEBUG` (`trlx.logging.DEBUG`) + +```sh +export TRLX_VERBOSITY=WARNING +``` + +By default, [`tqdm`](https://tqdm.github.io/docs/tqdm/) progress bars are used to display training progress. You can disable them by calling `trlx.logging.disable_progress_bar()` and re-enable them with `trlx.logging.enable_progress_bar()`. + +Messages can be formatted with greater detail by calling `trlx.logging.enable_explicit_format()`. This will inject call-site information into each log message, which can be helpful for debugging: + +```sh +[2023-01-01 05:00:00,000] [INFO] [ppo_orchestrator.py:63:make_experience] [RANK 0] Message... +``` + +> 💡 Tip: To reduce the amount of logging output, you might find it helpful to change log levels of third-party libraries used by trlX. For example, try adding `transformers.logging.set_verbosity_error()` to the top of your trlX scripts to silence verbose messages from the `transformers` library (see their [logging docs](https://huggingface.co/docs/transformers/main_classes/logging#logging) for more details).
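For example, a minimal snippet combining the two toggles above (it uses only the `trlx.logging` helpers introduced by this patch):

```python
import trlx

# Silence the per-step tqdm bars and switch to the detailed, call-site-annotated format.
trlx.logging.disable_progress_bar()
trlx.logging.enable_explicit_format()

# Progress bars can be turned back on later if needed.
trlx.logging.enable_progress_bar()
```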
## Contributing diff --git a/trlx/__init__.py b/trlx/__init__.py index e84114dc7..7b26a92a9 100644 --- a/trlx/__init__.py +++ b/trlx/__init__.py @@ -1 +1,2 @@ from .trlx import train +from .utils import logging diff --git a/trlx/orchestrator/offline_orchestrator.py b/trlx/orchestrator/offline_orchestrator.py index 90207b19a..d426e7aad 100644 --- a/trlx/orchestrator/offline_orchestrator.py +++ b/trlx/orchestrator/offline_orchestrator.py @@ -1,11 +1,16 @@ +import os from typing import List, Union import numpy as np import torch +from rich.console import Console +from rich.table import Table +import trlx.utils.logging as logging from trlx.orchestrator import Orchestrator, register_orchestrator from trlx.pipeline.offline_pipeline import ILQLRolloutStorage -from trlx.utils import print_rank_0 + +logger = logging.get_logger(__name__) def tokenize_dialogue( # noqa: C901 @@ -60,6 +65,8 @@ def make_experience(self, samples, rewards, max_length=2048): """ Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the trainer """ + logger.info("Collecting rollouts") + if self.trainer.tokenizer: samples = [tokenize_dialogue(s, self.trainer.tokenizer, max_length) for s in samples] @@ -84,26 +91,29 @@ def make_experience(self, samples, rewards, max_length=2048): all_actions_ixs.append(torch.hstack(actions_ixs)) all_states_ixs.append(states_ixs) - if self.trainer.tokenizer: + if self.trainer.tokenizer and os.environ.get("RANK", "0") == "0": + logger.info("Logging sample example") prompt = self.trainer.tokenizer.decode(all_input_ids[0][: all_states_ixs[0][1]]) response = self.trainer.tokenizer.decode(all_input_ids[0][all_states_ixs[0][1] :]) - print_rank_0("[Sample example]") - print_rank_0("Prompt: ", prompt) - print_rank_0("Response: ", response) - print_rank_0("Reward: ", rewards[0]) + columns = ["Prompt", "Response", "Reward"] + table = Table(*columns, title="Sample Example", show_lines=True) + table.add_row(prompt, response, str(rewards[0])) + Console().print(table) sample_lengths = np.array(list(map(len, all_input_ids))) output_lengths = np.array(list(map(len, all_actions_ixs))) prompt_lengths = sample_lengths - output_lengths returns = torch.tensor(rewards, dtype=float) - def string_stats(name: str, xs: np.array): - return f"[Mean {name}] {xs.mean():.2f} ∈ [{min(xs)}, {max(xs)}]" - - print_rank_0(string_stats("prompt length", prompt_lengths)) - print_rank_0(string_stats("output length", output_lengths)) - print_rank_0(string_stats("sample length", sample_lengths)) - print_rank_0(string_stats("return", returns)) + if os.environ.get("RANK", "0") == "0": + logger.info("Logging experience string statistics") + columns = ["Prompt Length", "Output Length", "Sample Length"] + table = Table(*columns, title="Experience String Stats (mean ∈ \[min, max])", show_lines=True) + row = [] + for lengths in [prompt_lengths, output_lengths, sample_lengths]: + row.append(f"{lengths.mean():.2f} ∈ [{min(lengths)}, {max(lengths)}]") + table.add_row(*row) + Console().print(table) returns = (returns - returns.mean()) / (returns.std() + 1e-30) rewards = [torch.zeros(len(x)) for x in all_actions_ixs] diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py index f7fadea70..4f73f01c3 100644 --- a/trlx/orchestrator/ppo_orchestrator.py +++ b/trlx/orchestrator/ppo_orchestrator.py @@ -1,9 +1,11 @@ +import os from time import time import ray import torch import torch.nn.functional as F +import trlx.utils.logging as logging from 
trlx.data.accelerate_base_datatypes import PromptBatch from trlx.data.ppo_types import PPORLElement from trlx.orchestrator import Orchestrator, register_orchestrator @@ -12,6 +14,8 @@ from trlx.utils import Clock from trlx.utils.modeling import RunningMoments, logprobs_from_logits +logger = logging.get_logger(__name__) + @register_orchestrator class PPOOrchestrator(Orchestrator): @@ -55,9 +59,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq Takes `num_rollouts` prompts from `pipeline`, samples model and computes the KL against a reference model. It then appends PPOElements to trainer's `store` """ + logger.info("Collecting rollouts") + tbar = logging.tqdm( + total=num_rollouts, + disable=os.environ.get("RANK", "0") != "0", + desc=f"[rollout 0 / {num_rollouts}]", + # Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress + # bars (e.g. loss progress in trainers) + position=logging.get_verbosity() >= logging.WARNING, + # Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels + leave=logging.get_verbosity() < logging.WARNING, + ) + ppo_rl_elements = [] stats = {} clock = Clock() + while len(ppo_rl_elements) < num_rollouts: # Get next batch in prompt dataset and refresh if exhausted try: @@ -198,6 +215,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards = -self.trainer.kl_ctl.value * (logprobs - ref_logprobs) rewards = [rs[start : ends[ix]] for ix, rs in enumerate(rewards)] + rollout_count = 0 for ix in range(n): if len(rewards[ix]) == 0 or len(all_logprobs[ix]) == 0: continue @@ -213,8 +231,11 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq rewards=rewards[ix], ) ) - + rollout_count += 1 exp_time = clock.tick() + tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]") + tbar.update(min(rollout_count, num_rollouts)) + tbar.close() stats["kl_ctl_value"] = self.trainer.kl_ctl.value stats["time/exp"] = exp_time diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index a76ee5cbb..11c0ce68d 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -13,9 +13,9 @@ from ray.air.checkpoint import Checkpoint from rich.console import Console from rich.table import Table -from tqdm import tqdm from transformers import AutoTokenizer +import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.trainer import BaseRLTrainer, register_trainer from trlx.utils import ( @@ -24,7 +24,6 @@ get_git_tag, get_optimizer_class, get_scheduler_class, - print_rank_0, significant, ) from trlx.utils.modeling import ( @@ -35,6 +34,8 @@ parse_delta_kwargs, ) +logger = logging.get_logger(__name__) + @register_trainer class AccelerateRLTrainer(BaseRLTrainer): @@ -116,6 +117,8 @@ def setup_model(self): """ Returns a model derived from an instance's TRLConfig """ + logger.info(f"Initializing model: {self.config.model.model_path}") + # Retrieves model equipped for ppo, ilql, etc model = self.get_arch(self.config) if self.config.model.model_arch_type == "seq2seq": @@ -279,8 +282,7 @@ def add_eval_pipeline(self, eval_pipeline): def evaluate(self): # noqa: C901 """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided""" - stats = {} - table = [] + logger.info("Evaluating model") # Do multiple evaluations over a single list in `gen_kwargs` if present if self.generate_sweep_kwarg
is not None: @@ -288,7 +290,22 @@ def evaluate(self): # noqa: C901 else: gen_sweep_values = [None] - for gen_sweep_value in gen_sweep_values: + desc = [ + f"generation sweep 0/{len(gen_sweep_values)}", + f"eval batch 0/{len(self.eval_dataloader)}", + ] + tbar = logging.tqdm( + total=len(self.eval_dataloader) * len(gen_sweep_values), + desc=f"[{' | '.join(desc)}]", + disable=not self.accelerator.is_main_process, + position=0, + leave=True, + ) + + stats = {} + table = [] + + for i_sweep, gen_sweep_value in enumerate(gen_sweep_values): # A dedicated suffix for wandb logging if gen_sweep_value is not None: sweep_suffix = f"@{gen_sweep_arg}={gen_sweep_value}" @@ -299,7 +316,7 @@ def evaluate(self): # noqa: C901 all_prompts = [] prompt_sizes = [] generate_time = time() - for prompts in self.eval_dataloader: + for i_prompt, prompts in enumerate(self.eval_dataloader): if self.generate_sweep_kwarg: samples = self.generate_eval(**prompts, **{gen_sweep_arg: gen_sweep_value}) else: @@ -326,6 +343,14 @@ def evaluate(self): # noqa: C901 torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(len(prompts.input_ids)) ) + desc = [ + f"generation sweep {i_sweep + 1}/{len(gen_sweep_values)}", + f"eval batch {i_prompt + 1}/{len(self.eval_dataloader)}", + ] + tbar.set_description(f"[{' | '.join(desc)}]") + tbar.update() + tbar.close() + stats["time/generate"] = time() - generate_time samples = self.accelerator.gather(torch.vstack(all_samples)) @@ -340,6 +365,7 @@ def evaluate(self): # noqa: C901 # in online setting, compute the reward for validation if self.reward_fn: + logger.info("Computing rewards") rewards = torch.tensor( self.reward_fn( samples=str_samples, @@ -357,6 +383,7 @@ def evaluate(self): # noqa: C901 # additionally log any other metrics if self.metric_fn: + logger.info("Computing metrics") metric_time = time() metrics = self.metric_fn( samples=str_samples, @@ -385,6 +412,7 @@ def evaluate(self): # noqa: C901 table.append(list(zip(*columns_data))) # Log and display evaluation metrics + logger.info("Summarizing evaluation") if self.accelerator.is_main_process: rows = sum(list(map(list, zip(*table))), []) @@ -395,9 +423,9 @@ def evaluate(self): # noqa: C901 table_title += f" {k}: {significant(x)}" rich_table = Table(*columns, title=table_title, show_lines=True) - for ix in range(max(min(3, len(rows)), len(gen_sweep_values))): rich_table.add_row(*[str(significant(x)) for x in rows[ix]]) + Console().print(rich_table) if not ray.is_initialized(): if self.config.train.tracker == "wandb": @@ -405,8 +433,6 @@ def evaluate(self): # noqa: C901 stats["samples"] = wandb.Table(columns, rows) - Console().print(rich_table) - self.nth_evaluation += 1 return stats @@ -414,11 +440,13 @@ def learn(self): # noqa: C901 """ Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader` """ + logger.info("Starting training") + self.generate_sweep_kwarg = None for k, v in self.config.method.gen_kwargs.items(): if isinstance(v, list): if self.generate_sweep_kwarg is not None: - print_rank_0("Only a single sweep is allowed, {k} is going to be set to {v[0]}") + logger.info(f"Only a single sweep is allowed, {k} is going to be set to {v[0]}") self.generate_kwargs[k] = v[0] else: self.generate_sweep_kwarg = (k, v) @@ -440,10 +468,12 @@ def learn(self): # noqa: C901 results = self.evaluate() self.accelerator.log(results, step=self.iter_count) - tbar = tqdm( + tbar = logging.tqdm( initial=self.iter_count, total=self.total_steps, disable=not
self.accelerator.is_local_main_process, + position=0, + leave=True, ) best_reward = -float("inf") @@ -491,7 +521,7 @@ def learn(self): # noqa: C901 torch.distributed.all_reduce(do_save, torch.distributed.ReduceOp.MAX) if do_save: best_path = f"{self.config.train.checkpoint_dir}/best_checkpoint" - print_rank_0(f"saving the best state so far into {best_path}") + logger.info(f"Saving the best state so far into {best_path}") self.save(best_path) # Report the metrics to Ray Tune. @@ -505,8 +535,8 @@ def learn(self): # noqa: C901 if not ray.is_initialized(): self.accelerator.log(stats, step=self.iter_count) - desc = ", ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) - tbar.set_description(desc) + desc = " | ".join(f"{k}: {v:.2f}" for k, v in stats.items() if k.startswith("loss")) + tbar.set_description(f"[{desc}]") tbar.update() if self.iter_count >= self.total_steps: @@ -516,6 +546,7 @@ def learn(self): # noqa: C901 self.post_backward_callback() self.post_epoch_callback() + tbar.close() @abstractmethod def get_arch(self, config: TRLConfig): diff --git a/trlx/utils/logging.py b/trlx/utils/logging.py new file mode 100644 index 000000000..79badb4a3 --- /dev/null +++ b/trlx/utils/logging.py @@ -0,0 +1,340 @@ +# Copyright 2023 Optuna, Hugging Face, CarperAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Logging utilities.""" + +import logging +import os +import sys +import threading +from logging import CRITICAL # NOQA +from logging import DEBUG # NOQA +from logging import ERROR # NOQA +from logging import FATAL # NOQA +from logging import INFO # NOQA +from logging import NOTSET # NOQA +from logging import WARN # NOQA +from logging import WARNING # NOQA +from typing import Optional + +import torch +from tqdm import auto as tqdm_lib + +_lock = threading.Lock() +_default_handler: Optional[logging.Handler] = None + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.INFO + + +def _get_default_logging_level(): + """ + If `TRLX_VERBOSITY` env var is set to one of the valid choices, return that as the new default level. If it is + not - fall back to `_default_log_level` + """ + env_level_str = os.getenv("TRLX_VERBOSITY", None) + if env_level_str: + if env_level_str.lower() in log_levels: + return log_levels[env_level_str.lower()] + else: + logging.getLogger().warning( + f"Unknown option TRLX_VERBOSITY={env_level_str}, " f"has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _configure_library_root_logger() -> None: + global _default_handler + + with _lock: + if _default_handler: + # This library has already configured the library root logger. + return + _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 
+ _default_handler.flush = sys.stderr.flush + + # Apply our default configuration to the library root logger. + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(_default_handler) + library_root_logger.setLevel(_get_default_logging_level()) + library_root_logger.propagate = False + + +def _reset_library_root_logger() -> None: + global _default_handler + + with _lock: + if not _default_handler: + return + + library_root_logger = _get_library_root_logger() + library_root_logger.removeHandler(_default_handler) + library_root_logger.setLevel(logging.NOTSET) + _default_handler = None + + +def get_log_levels_dict(): + return log_levels + + +class MultiProcessAdapter(logging.LoggerAdapter): + """A logger adapter for handling multi-process logging""" + + def log(self, level, msg, *args, **kwargs): + """ + Consumes an additional kwarg called `ranks` to determine which processes should log. + NOTE: To specify all processes, pass in an empty list `ranks=[]` + + Default: ["0"], i.e. only the main process logs + """ + # By default, silence all non-main processes + ranks = kwargs.pop("ranks", ["0"]) + should_log = os.environ.get("RANK", "0") in ranks or len(ranks) == 0 + if self.isEnabledFor(level) and should_log: + msg, kwargs = self.process(msg, kwargs) + self.logger._log(level, msg, args, **kwargs) + + def process(self, msg, kwargs): + this_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + return f"[RANK {this_rank}] {msg}", kwargs + + +def get_logger(name: Optional[str] = None) -> MultiProcessAdapter: + """ + Returns a `MultiProcessAdapter` logger for `name` that can handle multiple processes + + Args: + name: Name of the logger + + Usage: + >>> logger = get_logger(__name__) + >>> logger.debug("Check the...", ranks=["0", "1"]) # Only main and rank 1 log + """ + if name is None: + name = _get_library_name() + _configure_library_root_logger() + logger = logging.getLogger(name) + return MultiProcessAdapter(logger, {}) + + +def get_verbosity() -> int: + """ + Return the current level for trlX's root logger as an int. + Returns: + `int`: The logging level. + + trlX has the following logging levels: + - 50: `trlx.logging.CRITICAL` or `trlx.logging.FATAL` + - 40: `trlx.logging.ERROR` + - 30: `trlx.logging.WARNING` or `trlx.logging.WARN` + - 20: `trlx.logging.INFO` + - 10: `trlx.logging.DEBUG` + + """ + + _configure_library_root_logger() + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Set the verbosity level for trlX's root logger.
+ Args: + verbosity (`int`): + Logging level, e.g., one of: + - `trlx.logging.CRITICAL` or `trlx.logging.FATAL` + - `trlx.logging.ERROR` + - `trlx.logging.WARNING` or `trlx.logging.WARN` + - `trlx.logging.INFO` + - `trlx.logging.DEBUG` + """ + + _configure_library_root_logger() + _get_library_root_logger().setLevel(verbosity) + + +def disable_default_handler() -> None: + """Disable the default handler of trlx's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().removeHandler(_default_handler) + + +def enable_default_handler() -> None: + """Enable the default handler of trlx's root logger.""" + + _configure_library_root_logger() + + assert _default_handler is not None + _get_library_root_logger().addHandler(_default_handler) + + +def add_handler(handler: logging.Handler) -> None: + """Adds a handler to trlx's root logger.""" + + _configure_library_root_logger() + + assert handler is not None + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + """Removes the given handler from trlx's root logger.""" + + _configure_library_root_logger() + + assert handler is not None and handler in _get_library_root_logger().handlers + _get_library_root_logger().removeHandler(handler) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is disabled by default. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. Please disable trlx's default handler to prevent + double logging if the root logger has been configured. + """ + + _configure_library_root_logger() + _get_library_root_logger().propagate = True + + +def enable_explicit_format() -> None: + """ + Enable explicit formatting for every trlx logger. The explicit formatter is as follows: + ``` + [ASCTIME] [LEVELNAME] [FILENAME:LINE NUMBER:FUNCNAME] MESSAGE + ``` + All handlers currently bound to the root logger are affected by this method. + """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + formatter = logging.Formatter( + "[%(asctime)s] [%(levelname)s] [%(filename)s:%(lineno)d:%(funcName)s] %(message)s" + ) + handler.setFormatter(formatter) + + +def reset_format() -> None: + """ + Resets the formatting for trlx's loggers. + All handlers currently bound to the root logger are affected by this method.
+ """ + handlers = _get_library_root_logger().handlers + + for handler in handlers: + handler.setFormatter(None) + + +def warning_advice(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but if env var TRLX_NO_ADVISORY_WARNINGS=1 is set, this + warning will not be printed + """ + no_advisory_warnings = os.getenv("TRLX_NO_ADVISORY_WARNINGS", False) + if no_advisory_warnings: + return + self.warning(*args, **kwargs) + + +logging.Logger.warning_advice = warning_advice + + +class EmptyTqdm: + """Dummy tqdm which doesn't do anything.""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + self._iterator = args[0] if args else None + + def __iter__(self): + return iter(self._iterator) + + def __getattr__(self, _): + """Return empty function.""" + + def empty_fn(*args, **kwargs): # pylint: disable=unused-argument + return + + return empty_fn + + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + return + + +_tqdm_active = True + + +class _tqdm_cls: + def __call__(self, *args, **kwargs): + if _tqdm_active: + return tqdm_lib.tqdm(*args, **kwargs) + else: + return EmptyTqdm(*args, **kwargs) + + def set_lock(self, *args, **kwargs): + self._lock = None + if _tqdm_active: + return tqdm_lib.tqdm.set_lock(*args, **kwargs) + + def get_lock(self): + if _tqdm_active: + return tqdm_lib.tqdm.get_lock() + + +tqdm = _tqdm_cls() + + +def is_progress_bar_enabled() -> bool: + """Return a boolean indicating whether tqdm progress bars are enabled.""" + global _tqdm_active + return bool(_tqdm_active) + + +def enable_progress_bar(): + """Enable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = True + + +def disable_progress_bar(): + """Disable tqdm progress bar.""" + global _tqdm_active + _tqdm_active = False