feat: Default to RichProgressBar and RichModelSummary if rich is avai… #20896

Open · wants to merge 10 commits into base: master
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))

- Default to `RichProgressBar` and `RichModelSummary` if the `rich` package is available, and fall back to `TQDMProgressBar` and `ModelSummary` otherwise ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580))


### Changed

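In practice the change boils down to which callbacks a bare Trainer() registers. A minimal usage sketch of the new default and of how to opt back into the previous behavior (assuming only the public Trainer and callback APIs touched by this diff; explicitly passed callbacks take precedence over the automatic choice):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary, TQDMProgressBar

# With this change: RichProgressBar/RichModelSummary when `rich` is importable,
# TQDMProgressBar/ModelSummary otherwise.
trainer = Trainer()

# To keep the tqdm bar and plain-text summary regardless of whether `rich`
# is installed, pass the callbacks explicitly.
trainer = Trainer(callbacks=[TQDMProgressBar(), ModelSummary()])
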
48 changes: 29 additions & 19 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -17,15 +17,14 @@
from datetime import timedelta
from typing import Any, Optional, Union, cast

from lightning_utilities.core.imports import RequirementCache
import torch
from typing_extensions import override

import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.types import STEP_OUTPUT

_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")

if _RICH_AVAILABLE:
from rich import get_console, reconfigure
from rich.console import Console, RenderableType
@@ -171,7 +170,7 @@ def render(self, task: "Task") -> Text:
return Text()
if self._trainer.training and task.id not in self._tasks:
self._tasks[task.id] = "None"
if self._renderable_cache:
if self._renderable_cache and self._current_task_id in self._renderable_cache:
self._current_task_id = cast(TaskID, self._current_task_id)
self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
self._current_task_id = task.id
@@ -185,7 +184,10 @@
def _generate_metrics_texts(self) -> Generator[str, None, None]:
for name, value in self._metrics.items():
if not isinstance(value, str):
value = f"{value:{self._metrics_format}}"
try:
value = f"{value:{self._metrics_format}}"
except (TypeError, ValueError):
value = str(value)
yield f"{name}: {value}"


@@ -448,17 +450,12 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible:
)

def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None:
if self.progress is not None and self.is_enabled:
assert progress_bar_id is not None
if self.progress is not None and self.is_enabled and progress_bar_id is not None:
total = self.progress.tasks[progress_bar_id].total
assert total is not None
if not self._should_update(current, total):
return

leftover = current % self.refresh_rate
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
self.progress.update(progress_bar_id, advance=advance, visible=visible)
self.refresh()
self.progress.update(progress_bar_id, completed=current, visible=visible)

def _should_update(self, current: int, total: Union[int, float]) -> bool:
return current % self.refresh_rate == 0 or current == total
@@ -552,9 +549,13 @@ def on_validation_batch_end(
if self.is_disabled:
return
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
elif self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, batch_idx + 1)
if self.val_sanity_progress_bar_id is not None:
self._update(self.val_sanity_progress_bar_id, batch_idx + 1)
return

if self.val_progress_bar_id is None:
return
self._update(self.val_progress_bar_id, batch_idx + 1)
self.refresh()

@override
@@ -567,9 +568,8 @@ def on_test_batch_end(
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
if self.is_disabled:
if self.is_disabled or self.test_progress_bar_id is None:
return
assert self.test_progress_bar_id is not None
self._update(self.test_progress_bar_id, batch_idx + 1)
self.refresh()

@@ -583,9 +583,8 @@ def on_predict_batch_end(
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
if self.is_disabled:
if self.is_disabled or self.predict_progress_bar_id is None:
return
assert self.predict_progress_bar_id is not None
self._update(self.predict_progress_bar_id, batch_idx + 1)
self.refresh()

@@ -612,6 +611,17 @@ def _reset_progress_bar_ids(self) -> None:
self.test_progress_bar_id = None
self.predict_progress_bar_id = None

@override
def get_metrics(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> dict[str, Union[int, str, float, dict[str, float]]]:
items = super().get_metrics(trainer, pl_module)
# convert all metrics to float before sending to rich
for k, v in items.items():
if isinstance(v, torch.Tensor):
items[k] = v.item()
return items

def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
metrics = self.get_metrics(trainer, pl_module)
if self._metric_component:
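Two of the changes above are easy to miss: metric values that do not accept a float format spec now fall back to str() instead of raising, and tensor metrics are converted to plain Python numbers before rich renders them. A standalone sketch of the formatting fallback (illustrative values, not part of the diff):

metrics = {"loss": 0.1234, "lr": {"group0": 0.01}, "stage": "train"}
metrics_format = ".3f"
for name, value in metrics.items():
    if not isinstance(value, str):
        try:
            value = f"{value:{metrics_format}}"  # works for ints and floats
        except (TypeError, ValueError):
            value = str(value)  # dicts and other objects are shown verbatim
    print(f"{name}: {value}")
# output:
#   loss: 0.123
#   lr: {'group0': 0.01}
#   stage: train
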
8 changes: 4 additions & 4 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -274,7 +274,7 @@ def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
n = batch_idx + 1
if self._should_update(n, self.train_progress_bar.total):
if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total):
_update_n(self.train_progress_bar, n)
self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

@@ -322,7 +322,7 @@ def on_validation_batch_end(
dataloader_idx: int = 0,
) -> None:
n = batch_idx + 1
if self._should_update(n, self.val_progress_bar.total):
if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total):
_update_n(self.val_progress_bar, n)

@override
@@ -363,7 +363,7 @@ def on_test_batch_end(
dataloader_idx: int = 0,
) -> None:
n = batch_idx + 1
if self._should_update(n, self.test_progress_bar.total):
if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total):
_update_n(self.test_progress_bar, n)

@override
@@ -402,7 +402,7 @@ def on_predict_batch_end(
dataloader_idx: int = 0,
) -> None:
n = batch_idx + 1
if self._should_update(n, self.predict_progress_bar.total):
if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total):
_update_n(self.predict_progress_bar, n)

@override
2 changes: 1 addition & 1 deletion src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -16,7 +16,7 @@
from typing_extensions import override

from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_summary import get_human_readable_count


2 changes: 1 addition & 1 deletion src/lightning/pytorch/loops/evaluation_loop.py
@@ -25,7 +25,6 @@

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import _BatchProgress
@@ -44,6 +43,7 @@
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning.pytorch.utilities.data import has_len_all_ranks
from lightning.pytorch.utilities.exceptions import SIGTERMException
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_helpers import _ModuleMode, is_overridden
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature

11 changes: 3 additions & 8 deletions src/lightning/pytorch/trainer/connectors/callback_connector.py
@@ -37,6 +37,7 @@
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info

@@ -125,14 +126,8 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
)
return

progress_bar_callback = self.trainer.progress_bar_callback
is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)

model_summary: ModelSummary
if progress_bar_callback is not None and is_progress_bar_rich:
model_summary = RichModelSummary()
else:
model_summary = ModelSummary()
model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary()
self.trainer.callbacks.append(model_summary)

def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
@@ -157,7 +152,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None:
)

if enable_progress_bar:
progress_bar_callback = TQDMProgressBar()
progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar()
self.trainer.callbacks.append(progress_bar_callback)

def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None:
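Because the connector now keys off the module-level _RICH_AVAILABLE flag instead of inspecting the configured progress bar, the non-rich defaults can be forced by patching that flag at its use site. A hedged sketch of the pattern the test changes below rely on (assumes rich may or may not be installed in the environment):

from unittest.mock import patch

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar

# Patch the flag where the connector imported it, not in
# lightning.pytorch.utilities.imports, so the connector sees the override.
with patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False):
    trainer = Trainer()
    assert isinstance(trainer.progress_bar_callback, TQDMProgressBar)
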
1 change: 1 addition & 0 deletions src/lightning/pytorch/utilities/imports.py
@@ -28,6 +28,7 @@

_OMEGACONF_AVAILABLE = package_available("omegaconf")
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")


@functools.lru_cache(maxsize=128)
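Moving the flag here gives every consumer (rich progress bar, model summary, callback connector, evaluation loop, test markers) a single definition. A small sketch of how such a RequirementCache flag is typically consumed, assuming lightning_utilities' documented behavior that the object is truthy only when the requirement is importable and satisfies the version specifier:

from lightning_utilities.core.imports import RequirementCache

_RICH_AVAILABLE = RequirementCache("rich>=10.2.2")

if _RICH_AVAILABLE:
    from rich.console import Console  # safe: rich >= 10.2.2 is present

    console = Console()
else:
    console = None  # callers take the non-rich fallback path
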
3 changes: 1 addition & 2 deletions src/lightning/pytorch/utilities/testing/_runif.py
@@ -17,9 +17,8 @@

from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.core.module import _ONNX_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _RICH_AVAILABLE

_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -18,7 +18,7 @@
from collections import defaultdict
from typing import Union
from unittest import mock
from unittest.mock import ANY, Mock, PropertyMock, call
from unittest.mock import ANY, Mock, PropertyMock, call, patch

import pytest
import torch
@@ -109,6 +109,7 @@ def test_tqdm_progress_bar_misconfiguration():
Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
@pytest.mark.parametrize("num_dl", [1, 2])
def test_tqdm_progress_bar_totals(tmp_path, num_dl):
"""Test that the progress finishes with the correct total steps processed."""
@@ -203,6 +204,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
assert pbar.predict_progress_bar.leave


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_fast_dev_run(tmp_path):
model = BoringModel()

@@ -323,6 +325,7 @@ def test_tqdm_progress_bar_default_value(tmp_path):


@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tqdm_progress_bar_value_on_colab(tmp_path):
"""Test that Trainer will override the default in Google COLAB."""
trainer = Trainer(default_root_dir=tmp_path)
@@ -411,6 +414,7 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra
assert progress_bar.test_progress_bar.n_values == updates


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_tensor_to_float_conversion(tmp_path):
"""Check tensor gets converted to float."""

@@ -424,7 +428,13 @@ def training_step(self, batch, batch_idx):
trainer = Trainer(
default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
)
trainer.fit(TestModel())

with mock.patch.object(sys.stdout, "write") as mock_write:
trainer.fit(TestModel())
bar_updates = "".join(call.args[0] for call in mock_write.call_args_list)
assert "a=0.123" in bar_updates
assert "b=1.000" in bar_updates
assert "c=2.000" in bar_updates

torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
assert trainer.progress_bar_metrics["b"] == 1.0
@@ -616,6 +626,7 @@ def test_progress_bar_max_val_check_interval(
assert pbar_callback.is_enabled


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize("val_check_interval", [0.2, 0.5])
def test_progress_bar_max_val_check_interval_ddp(tmp_path, val_check_interval):
@@ -703,7 +714,7 @@ def get_metrics(self, trainer, pl_module):
del items["v_num"]
# this is equivalent to mocking `set_postfix` as this method gets called every time
self.calls[trainer.state.fn].append((
trainer.state.stage,
trainer.state.stage.value,
trainer.current_epoch,
trainer.global_step,
items,
3 changes: 2 additions & 1 deletion tests/tests_pytorch/callbacks/test_callbacks.py
@@ -13,7 +13,7 @@
# limitations under the License.
from pathlib import Path
from re import escape
from unittest.mock import Mock
from unittest.mock import Mock, patch

import pytest
from lightning_utilities.test.warning import no_warning_call
@@ -119,6 +119,7 @@ def load_state_dict(self, state_dict) -> None:
self.state = state_dict["state"]


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_resume_callback_state_saved_by_type_stateful(tmp_path):
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded, using
state_dict/load_state_dict."""
tests/tests_pytorch/trainer/connectors/test_callback_connector.py
@@ -14,7 +14,7 @@
import contextlib
import logging
from unittest import mock
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import pytest
import torch
@@ -35,13 +35,14 @@
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_checkpoint_callbacks_are_last(tmp_path):
"""Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
checkpoint1 = ModelCheckpoint(tmp_path, monitor="foo")
checkpoint2 = ModelCheckpoint(tmp_path, monitor="bar")
model_summary = ModelSummary()
"""Test that checkpoint callbacks always come last."""
checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1")
checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2")
early_stopping = EarlyStopping(monitor="foo")
lr_monitor = LearningRateMonitor()
model_summary = ModelSummary()
progress_bar = TQDMProgressBar()

# no model reference
@@ -71,7 +72,7 @@ def test_checkpoint_callbacks_are_last(tmp_path):
# with model-specific callbacks that substitute ones in Trainer
model = LightningModule()
model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2]
trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path)])
trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")])
trainer.strategy._lightning_module = model
cb_connector = _CallbackConnector(trainer)
cb_connector._attach_model_callbacks()
@@ -161,6 +162,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path):
)


@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
def test_attach_model_callbacks():
"""Test that the callbacks defined in the model and through Trainer get merged correctly."""

Expand Down