diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 6e70119b65e99..2da11d8673ad8 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780)) +- Default to `RichProgressBar` and `RichModelSummary` if the `rich` package is available. Fall back to `TQDMProgressBar` and `ModelSummary` otherwise ([#9580](https://github.com/Lightning-AI/pytorch-lightning/pull/9580)) + ### Changed diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 7bb98e8a9058c..dab436bf569ef 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -17,15 +17,14 @@ from datetime import timedelta from typing import Any, Optional, Union, cast -from lightning_utilities.core.imports import RequirementCache +import torch from typing_extensions import override import lightning.pytorch as pl from lightning.pytorch.callbacks.progress.progress_bar import ProgressBar +from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.types import STEP_OUTPUT -_RICH_AVAILABLE = RequirementCache("rich>=10.2.2") - if _RICH_AVAILABLE: from rich import get_console, reconfigure from rich.console import Console, RenderableType @@ -171,7 +170,7 @@ def render(self, task: "Task") -> Text: return Text() if self._trainer.training and task.id not in self._tasks: self._tasks[task.id] = "None" - if self._renderable_cache: + if self._renderable_cache and self._current_task_id in self._renderable_cache: self._current_task_id = cast(TaskID, self._current_task_id) self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1] self._current_task_id = task.id @@ -185,7 +184,10 @@ def render(self, task: "Task") -> Text: def _generate_metrics_texts(self) -> Generator[str, None, None]: for name, value in self._metrics.items(): if not isinstance(value, str): - value = f"{value:{self._metrics_format}}" + try: + value = f"{value:{self._metrics_format}}" + except (TypeError, ValueError): + value = str(value) yield f"{name}: {value}" @@ -448,17 +450,12 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: ) def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: - if self.progress is not None and self.is_enabled: - assert progress_bar_id is not None + if self.progress is not None and self.is_enabled and progress_bar_id is not None: total = self.progress.tasks[progress_bar_id].total assert total is not None if not self._should_update(current, total): return - - leftover = current % self.refresh_rate - advance = leftover if (current == total and leftover != 0) else self.refresh_rate - self.progress.update(progress_bar_id, advance=advance, visible=visible) - self.refresh() + self.progress.update(progress_bar_id, completed=current, visible=visible) def _should_update(self, current: int, total: Union[int, float]) -> bool: return current % self.refresh_rate == 0 or current == total @@ -552,9 +549,13 @@ def on_validation_batch_end( if self.is_disabled: return if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id, batch_idx + 1) - elif 
self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, batch_idx + 1) + if self.val_sanity_progress_bar_id is not None: + self._update(self.val_sanity_progress_bar_id, batch_idx + 1) + return + + if self.val_progress_bar_id is None: + return + self._update(self.val_progress_bar_id, batch_idx + 1) self.refresh() @override @@ -567,9 +568,8 @@ def on_test_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: - if self.is_disabled: + if self.is_disabled or self.test_progress_bar_id is None: return - assert self.test_progress_bar_id is not None self._update(self.test_progress_bar_id, batch_idx + 1) self.refresh() @@ -583,9 +583,8 @@ def on_predict_batch_end( batch_idx: int, dataloader_idx: int = 0, ) -> None: - if self.is_disabled: + if self.is_disabled or self.predict_progress_bar_id is None: return - assert self.predict_progress_bar_id is not None self._update(self.predict_progress_bar_id, batch_idx + 1) self.refresh() @@ -612,6 +611,17 @@ def _reset_progress_bar_ids(self) -> None: self.test_progress_bar_id = None self.predict_progress_bar_id = None + @override + def get_metrics( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> dict[str, Union[int, str, float, dict[str, float]]]: + items = super().get_metrics(trainer, pl_module) + # convert all metrics to float before sending to rich + for k, v in items.items(): + if isinstance(v, torch.Tensor): + items[k] = v.item() + return items + def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: metrics = self.get_metrics(trainer, pl_module) if self._metric_component: diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 4ef260f00006d..942ba3627efc0 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -274,7 +274,7 @@ def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: n = batch_idx + 1 - if self._should_update(n, self.train_progress_bar.total): + if self.train_progress_bar is not None and self._should_update(n, self.train_progress_bar.total): _update_n(self.train_progress_bar, n) self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) @@ -322,7 +322,7 @@ def on_validation_batch_end( dataloader_idx: int = 0, ) -> None: n = batch_idx + 1 - if self._should_update(n, self.val_progress_bar.total): + if self.val_progress_bar is not None and self._should_update(n, self.val_progress_bar.total): _update_n(self.val_progress_bar, n) @override @@ -363,7 +363,7 @@ def on_test_batch_end( dataloader_idx: int = 0, ) -> None: n = batch_idx + 1 - if self._should_update(n, self.test_progress_bar.total): + if self.test_progress_bar is not None and self._should_update(n, self.test_progress_bar.total): _update_n(self.test_progress_bar, n) @override @@ -402,7 +402,7 @@ def on_predict_batch_end( dataloader_idx: int = 0, ) -> None: n = batch_idx + 1 - if self._should_update(n, self.predict_progress_bar.total): + if self.predict_progress_bar is not None and self._should_update(n, self.predict_progress_bar.total): _update_n(self.predict_progress_bar, n) @override diff --git a/src/lightning/pytorch/callbacks/rich_model_summary.py b/src/lightning/pytorch/callbacks/rich_model_summary.py index 817aeeb655a7a..ce00ae06890f6 100644 --- a/src/lightning/pytorch/callbacks/rich_model_summary.py +++ 
b/src/lightning/pytorch/callbacks/rich_model_summary.py @@ -16,7 +16,7 @@ from typing_extensions import override from lightning.pytorch.callbacks import ModelSummary -from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE +from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.model_summary import get_human_readable_count diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index 7f033dbd8e2c2..b1e9edfaf7220 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -25,7 +25,6 @@ import lightning.pytorch as pl from lightning.fabric.utilities.data import _set_sampler_epoch -from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher from lightning.pytorch.loops.loop import _Loop from lightning.pytorch.loops.progress import _BatchProgress @@ -44,6 +43,7 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import has_len_all_ranks from lightning.pytorch.utilities.exceptions import SIGTERMException +from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.model_helpers import _ModuleMode, is_overridden from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index 5c351aeebc564..e95f196d9ae43 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -37,6 +37,7 @@ from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.trainer import call from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.rank_zero import rank_zero_info @@ -125,14 +126,8 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None: ) return - progress_bar_callback = self.trainer.progress_bar_callback - is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar) - model_summary: ModelSummary - if progress_bar_callback is not None and is_progress_bar_rich: - model_summary = RichModelSummary() - else: - model_summary = ModelSummary() + model_summary = RichModelSummary() if _RICH_AVAILABLE else ModelSummary() self.trainer.callbacks.append(model_summary) def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: @@ -157,7 +152,7 @@ def _configure_progress_bar(self, enable_progress_bar: bool = True) -> None: ) if enable_progress_bar: - progress_bar_callback = TQDMProgressBar() + progress_bar_callback = RichProgressBar() if _RICH_AVAILABLE else TQDMProgressBar() self.trainer.callbacks.append(progress_bar_callback) def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, dict[str, int]]] = None) -> None: diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 6c0815a6af9dc..cb3290b3d6275 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -28,6 +28,7 @@ _OMEGACONF_AVAILABLE = package_available("omegaconf") 
_TORCHVISION_AVAILABLE = RequirementCache("torchvision") +_RICH_AVAILABLE = RequirementCache("rich>=10.2.2") @functools.lru_cache(maxsize=128) diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913681143..b84503365b68b 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -17,9 +17,8 @@ from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE -from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.core.module import _ONNX_AVAILABLE -from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE +from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _RICH_AVAILABLE _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d93bf1cf60e9c..538f1bce57ce0 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -18,7 +18,7 @@ from collections import defaultdict from typing import Union from unittest import mock -from unittest.mock import ANY, Mock, PropertyMock, call +from unittest.mock import ANY, Mock, PropertyMock, call, patch import pytest import torch @@ -109,6 +109,7 @@ def test_tqdm_progress_bar_misconfiguration(): Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False) +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) @pytest.mark.parametrize("num_dl", [1, 2]) def test_tqdm_progress_bar_totals(tmp_path, num_dl): """Test that the progress finishes with the correct total steps processed.""" @@ -203,6 +204,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): assert pbar.predict_progress_bar.leave +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_tqdm_progress_bar_fast_dev_run(tmp_path): model = BoringModel() @@ -323,6 +325,7 @@ def test_tqdm_progress_bar_default_value(tmp_path): @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_tqdm_progress_bar_value_on_colab(tmp_path): """Test that Trainer will override the default in Google COLAB.""" trainer = Trainer(default_root_dir=tmp_path) @@ -411,6 +414,7 @@ def test_test_progress_bar_update_amount(tmp_path, test_batches: int, refresh_ra assert progress_bar.test_progress_bar.n_values == updates +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_tensor_to_float_conversion(tmp_path): """Check tensor gets converted to float.""" @@ -424,7 +428,13 @@ def training_step(self, batch, batch_idx): trainer = Trainer( default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False ) - trainer.fit(TestModel()) + + with mock.patch.object(sys.stdout, "write") as mock_write: + trainer.fit(TestModel()) + bar_updates = "".join(call.args[0] for call in mock_write.call_args_list) + assert "a=0.123" in bar_updates + assert "b=1.000" in bar_updates + assert "c=2.000" in bar_updates torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) assert trainer.progress_bar_metrics["b"] == 1.0 @@ -616,6 +626,7 @@ def 
test_progress_bar_max_val_check_interval( assert pbar_callback.is_enabled +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) @RunIf(min_cuda_gpus=2, standalone=True) @pytest.mark.parametrize("val_check_interval", [0.2, 0.5]) def test_progress_bar_max_val_check_interval_ddp(tmp_path, val_check_interval): @@ -703,7 +714,7 @@ def get_metrics(self, trainer, pl_module): del items["v_num"] # this is equivalent to mocking `set_postfix` as this method gets called every time self.calls[trainer.state.fn].append(( - trainer.state.stage, + trainer.state.stage.value, trainer.current_epoch, trainer.global_step, items, diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 53ea109b6ddf3..34749087bfb97 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -13,7 +13,7 @@ # limitations under the License. from pathlib import Path from re import escape -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from lightning_utilities.test.warning import no_warning_call @@ -119,6 +119,7 @@ def load_state_dict(self, state_dict) -> None: self.state = state_dict["state"] +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_resume_callback_state_saved_by_type_stateful(tmp_path): """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using state_dict/load_state_dict.""" diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index 94b5fcba652be..54fbd065fa919 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -14,7 +14,7 @@ import contextlib import logging from unittest import mock -from unittest.mock import MagicMock, Mock +from unittest.mock import MagicMock, Mock, patch import pytest import torch @@ -35,13 +35,14 @@ from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_checkpoint_callbacks_are_last(tmp_path): - """Test that checkpoint callbacks always get moved to the end of the list, with preserved order.""" - checkpoint1 = ModelCheckpoint(tmp_path, monitor="foo") - checkpoint2 = ModelCheckpoint(tmp_path, monitor="bar") - model_summary = ModelSummary() + """Test that checkpoint callbacks always come last.""" + checkpoint1 = ModelCheckpoint(tmp_path / "path1", filename="ckpt1", monitor="val_loss_c1") + checkpoint2 = ModelCheckpoint(tmp_path / "path2", filename="ckpt2", monitor="val_loss_c2") early_stopping = EarlyStopping(monitor="foo") lr_monitor = LearningRateMonitor() + model_summary = ModelSummary() progress_bar = TQDMProgressBar() # no model reference @@ -71,7 +72,7 @@ def test_checkpoint_callbacks_are_last(tmp_path): # with model-specific callbacks that substitute ones in Trainer model = LightningModule() model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] - trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path)]) + trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmp_path, filename="ckpt_trainer")]) trainer.strategy._lightning_module = model cb_connector = _CallbackConnector(trainer) cb_connector._attach_model_callbacks() @@ 
-161,6 +162,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmp_path): ) +@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) def test_attach_model_callbacks(): """Test that the callbacks defined in the model and through Trainer get merged correctly.""" diff --git a/tests/tests_pytorch/trainer/connectors/test_rich_integration.py b/tests/tests_pytorch/trainer/connectors/test_rich_integration.py new file mode 100644 index 0000000000000..62926d26018ba --- /dev/null +++ b/tests/tests_pytorch/trainer/connectors/test_rich_integration.py @@ -0,0 +1,169 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import pytest +import torch + +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary, ProgressBar, RichModelSummary, RichProgressBar, TQDMProgressBar +from lightning.pytorch.demos.boring_classes import BoringModel + + +class TestRichIntegration: + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_no_rich_defaults_tqdm_and_model_summary(self, tmp_path): + trainer = Trainer(default_root_dir=tmp_path, logger=False, enable_checkpointing=False) + assert any(isinstance(cb, TQDMProgressBar) for cb in trainer.callbacks) + assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + assert not any(isinstance(cb, RichProgressBar) for cb in trainer.callbacks) + assert not any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_no_rich_respects_user_provided_tqdm_progress_bar(self, tmp_path): + user_progress_bar = TQDMProgressBar() + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False + ) + assert user_progress_bar in trainer.callbacks + assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1 + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_no_rich_respects_user_provided_rich_progress_bar(self, tmp_path): + # If user explicitly provides RichProgressBar, it should be used, + # even if _RICH_AVAILABLE is False (simulating our connector logic). + # RequirementCache would normally prevent RichProgressBar instantiation if rich is truly not installed. 
+ user_progress_bar = RichProgressBar() + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False + ) + assert user_progress_bar in trainer.callbacks + assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1 + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_no_rich_respects_user_provided_model_summary(self, tmp_path): + user_model_summary = ModelSummary() + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False + ) + assert user_model_summary in trainer.callbacks + assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1 + # Check that the specific instance is the one from the trainer's list of ModelSummary callbacks + assert trainer.callbacks[trainer.callbacks.index(user_model_summary)] == user_model_summary + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_no_rich_respects_user_provided_rich_model_summary(self, tmp_path): + user_model_summary = RichModelSummary() + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False + ) + assert user_model_summary in trainer.callbacks + assert sum(isinstance(cb, ModelSummary) for cb in trainer.callbacks) == 1 + # Check that the specific instance is the one from the trainer's list of ModelSummary callbacks + model_summary_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)] + assert user_model_summary in model_summary_callbacks + assert isinstance(model_summary_callbacks[0], RichModelSummary) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_rich_available_defaults_rich_progress_and_summary(self, tmp_path): + trainer = Trainer(default_root_dir=tmp_path, logger=False, enable_checkpointing=False) + assert any(isinstance(cb, RichProgressBar) for cb in trainer.callbacks) + assert any(isinstance(cb, RichModelSummary) for cb in trainer.callbacks) + assert not any(isinstance(cb, TQDMProgressBar) for cb in trainer.callbacks) + # Ensure the only ModelSummary is the RichModelSummary + model_summaries = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)] + assert len(model_summaries) == 1 + assert isinstance(model_summaries[0], RichModelSummary) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_rich_available_respects_user_tqdm_progress_bar(self, tmp_path): + user_progress_bar = TQDMProgressBar() + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_progress_bar], logger=False, enable_checkpointing=False + ) + assert user_progress_bar in trainer.callbacks + assert sum(isinstance(cb, ProgressBar) for cb in trainer.callbacks) == 1 + assert isinstance(trainer.progress_bar_callback, TQDMProgressBar) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_rich_available_respects_user_model_summary(self, tmp_path): + user_model_summary = ModelSummary() # Non-rich + trainer = Trainer( + default_root_dir=tmp_path, callbacks=[user_model_summary], logger=False, enable_checkpointing=False + ) + assert user_model_summary in trainer.callbacks + model_summaries = [cb for cb in trainer.callbacks if isinstance(cb, ModelSummary)] + assert len(model_summaries) == 1 + assert 
isinstance(model_summaries[0], ModelSummary) + assert not isinstance(model_summaries[0], RichModelSummary) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_progress_bar_disabled_no_rich(self, tmp_path): + trainer = Trainer( + default_root_dir=tmp_path, enable_progress_bar=False, logger=False, enable_checkpointing=False + ) + assert not any(isinstance(cb, ProgressBar) for cb in trainer.callbacks) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_progress_bar_disabled_with_rich(self, tmp_path): + trainer = Trainer( + default_root_dir=tmp_path, enable_progress_bar=False, logger=False, enable_checkpointing=False + ) + assert not any(isinstance(cb, ProgressBar) for cb in trainer.callbacks) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False) + def test_model_summary_disabled_no_rich(self, tmp_path): + trainer = Trainer( + default_root_dir=tmp_path, enable_model_summary=False, logger=False, enable_checkpointing=False + ) + assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_model_summary_disabled_with_rich(self, tmp_path): + trainer = Trainer( + default_root_dir=tmp_path, enable_model_summary=False, logger=False, enable_checkpointing=False + ) + assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) + + @patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", True) + def test_rich_progress_bar_tensor_metric(self, tmp_path): + """Test that tensor metrics are converted to float for RichProgressBar.""" + + class MyModel(BoringModel): + def training_step(self, batch, batch_idx): + self.log("my_tensor_metric", torch.tensor(1.23), prog_bar=True) + return super().training_step(batch, batch_idx) + + model = MyModel() + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=1, + limit_val_batches=0, + max_epochs=1, + logger=False, + enable_checkpointing=False, + ) + + with patch("lightning.pytorch.callbacks.progress.rich_progress.MetricsTextColumn.update") as mock_update: + trainer.fit(model) + + assert mock_update.call_count > 0 + # The metrics are updated multiple times, check the last call + last_call_metrics = mock_update.call_args[0][0] + assert "my_tensor_metric" in last_call_metrics + metric_val = last_call_metrics["my_tensor_metric"] + assert isinstance(metric_val, float) + assert metric_val == pytest.approx(1.23) diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index be6de37ddff3a..576497ae19f77 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -27,12 +27,12 @@ from torch import Tensor from lightning.pytorch import Trainer, callbacks -from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.imports import _RICH_AVAILABLE from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: @@ -534,7 +534,7 @@ def 
test_step(self, batch, batch_idx): max_epochs=2, ) - # Train the model ⚡ + # Train the model trainer.fit(model) assert set(trainer.callback_metrics) == {