From 6e124e7207f6459cb43f540cfb5a1c6cc9b00f7a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 6 Sep 2021 14:49:09 +0200 Subject: [PATCH] CI: precommit - docformatter (#8584) * CI: precommit - docformatter * fix deprecated Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 6 + _notebooks | 2 +- benchmarks/test_basic_parity.py | 8 +- benchmarks/test_sharded_parity.py | 13 +- docs/source/api_references.rst | 2 + docs/source/conf.py | 2 +- docs/source/extensions/plugins.rst | 2 + docs/source/links.rst | 2 + pl_examples/basic_examples/autoencoder.py | 6 +- .../backbone_image_classifier.py | 6 +- .../basic_examples/dali_image_classifier.py | 14 +- .../basic_examples/mnist_datamodule.py | 11 +- .../basic_examples/profiler_example.py | 6 +- .../basic_examples/simple_image_classifier.py | 6 +- .../computer_vision_fine_tuning.py | 19 +- .../generative_adversarial_net.py | 4 +- pl_examples/domain_templates/imagenet.py | 6 +- .../domain_templates/reinforce_learn_Qnet.py | 54 ++--- .../domain_templates/reinforce_learn_ppo.py | 60 +++--- .../domain_templates/semantic_segmentation.py | 18 +- pl_examples/domain_templates/unet.py | 17 +- pytorch_lightning/accelerators/accelerator.py | 101 +++++----- pytorch_lightning/callbacks/base.py | 30 ++- pytorch_lightning/callbacks/early_stopping.py | 5 +- pytorch_lightning/callbacks/finetuning.py | 26 +-- .../callbacks/gpu_stats_monitor.py | 8 +- pytorch_lightning/callbacks/lr_monitor.py | 10 +- .../callbacks/model_checkpoint.py | 30 +-- .../callbacks/prediction_writer.py | 3 +- .../callbacks/progress/__init__.py | 2 +- pytorch_lightning/callbacks/progress/base.py | 60 +++--- .../callbacks/progress/rich_progress.py | 3 +- .../{progress.py => tqdm_progress.py} | 92 ++++----- pytorch_lightning/callbacks/pruning.py | 25 +-- pytorch_lightning/callbacks/quantization.py | 28 ++- .../callbacks/stochastic_weight_avg.py | 16 +- pytorch_lightning/callbacks/timer.py | 5 +- .../callbacks/xla_stats_monitor.py | 6 +- pytorch_lightning/core/datamodule.py | 37 ++-- pytorch_lightning/core/decorators.py | 5 +- pytorch_lightning/core/hooks.py | 188 ++++++------------ pytorch_lightning/core/lightning.py | 155 ++++++--------- .../core/mixins/device_dtype_mixin.py | 7 +- .../core/mixins/hparams_mixin.py | 6 +- pytorch_lightning/core/optimizer.py | 16 +- pytorch_lightning/core/saving.py | 10 +- pytorch_lightning/loggers/base.py | 74 +++---- pytorch_lightning/loggers/comet.py | 9 +- pytorch_lightning/loggers/csv_logs.py | 31 ++- pytorch_lightning/loggers/mlflow.py | 18 +- pytorch_lightning/loggers/neptune.py | 28 +-- pytorch_lightning/loggers/tensorboard.py | 35 ++-- pytorch_lightning/loggers/test_tube.py | 9 +- pytorch_lightning/loggers/wandb.py | 9 +- pytorch_lightning/loops/base.py | 42 ++-- .../loops/batch/training_batch_loop.py | 25 +-- pytorch_lightning/loops/closure.py | 13 +- .../loops/dataloader/dataloader_loop.py | 14 +- .../loops/dataloader/evaluation_loop.py | 31 +-- .../loops/dataloader/prediction_loop.py | 26 +-- .../loops/epoch/evaluation_epoch_loop.py | 17 +- .../loops/epoch/prediction_epoch_loop.py | 21 +- .../loops/epoch/training_epoch_loop.py | 24 +-- pytorch_lightning/loops/fit_loop.py | 34 ++-- .../loops/optimizer/optimizer_loop.py | 23 ++- pytorch_lightning/loops/utilities.py | 11 +- pytorch_lightning/overrides/base.py | 5 +- pytorch_lightning/overrides/data_parallel.py | 21 +- pytorch_lightning/overrides/distributed.py | 25 +-- .../overrides/torch_distributed.py | 4 +- .../environments/kubeflow_environment.py | 8 +- .../environments/lightning_environment.py | 15 +- .../plugins/environments/lsf_environment.py | 13 +- .../plugins/io/checkpoint_plugin.py | 7 +- pytorch_lightning/plugins/io/torch_plugin.py | 10 +- pytorch_lightning/plugins/plugins_registry.py | 15 +- .../plugins/precision/apex_amp.py | 4 +- .../plugins/precision/deepspeed_precision.py | 4 +- pytorch_lightning/plugins/precision/double.py | 24 +-- .../precision/fully_sharded_native_amp.py | 2 +- .../plugins/precision/ipu_precision.py | 2 +- pytorch_lightning/plugins/precision/mixed.py | 2 +- .../plugins/precision/native_amp.py | 11 +- .../plugins/precision/precision_plugin.py | 35 ++-- .../plugins/precision/sharded_native_amp.py | 2 +- .../plugins/precision/tpu_bfloat.py | 2 +- .../plugins/training_type/ddp.py | 17 +- .../plugins/training_type/ddp2.py | 5 +- .../plugins/training_type/ddp_spawn.py | 11 +- .../plugins/training_type/deepspeed.py | 24 +-- pytorch_lightning/plugins/training_type/dp.py | 9 +- .../plugins/training_type/fully_sharded.py | 3 +- .../plugins/training_type/horovod.py | 3 +- .../plugins/training_type/ipu.py | 21 +- .../plugins/training_type/parallel.py | 13 +- .../plugins/training_type/single_device.py | 7 +- .../training_type/training_type_plugin.py | 83 ++++---- pytorch_lightning/profiler/__init__.py | 5 +- pytorch_lightning/profiler/advanced.py | 8 +- pytorch_lightning/profiler/base.py | 14 +- pytorch_lightning/profiler/pytorch.py | 14 +- pytorch_lightning/profiler/simple.py | 6 +- pytorch_lightning/profiler/xla.py | 12 +- pytorch_lightning/setup_tools.py | 4 +- pytorch_lightning/trainer/__init__.py | 4 +- pytorch_lightning/trainer/callback_hook.py | 12 +- .../trainer/configuration_validator.py | 2 +- .../connectors/accelerator_connector.py | 6 +- .../trainer/connectors/callback_connector.py | 7 +- .../connectors/checkpoint_connector.py | 36 ++-- .../trainer/connectors/debugging_connector.py | 2 +- .../trainer/connectors/env_vars_connector.py | 6 +- .../logger_connector/fx_validator.py | 2 +- .../logger_connector/logger_connector.py | 5 +- .../connectors/logger_connector/result.py | 17 +- pytorch_lightning/trainer/data_loading.py | 12 +- pytorch_lightning/trainer/optimizers.py | 7 +- pytorch_lightning/trainer/progress.py | 39 ++-- pytorch_lightning/trainer/properties.py | 63 +++--- pytorch_lightning/trainer/states.py | 8 +- pytorch_lightning/trainer/supporters.py | 63 ++---- pytorch_lightning/tuner/batch_size_scaling.py | 8 +- pytorch_lightning/tuner/lr_finder.py | 27 +-- pytorch_lightning/tuner/tuning.py | 14 +- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/apply_func.py | 11 +- pytorch_lightning/utilities/argparse.py | 4 +- pytorch_lightning/utilities/auto_restart.py | 77 ++++--- pytorch_lightning/utilities/cli.py | 35 ++-- pytorch_lightning/utilities/data.py | 13 +- pytorch_lightning/utilities/deepspeed.py | 3 +- pytorch_lightning/utilities/device_parser.py | 13 +- pytorch_lightning/utilities/distributed.py | 19 +- pytorch_lightning/utilities/enums.py | 6 +- pytorch_lightning/utilities/exceptions.py | 8 +- pytorch_lightning/utilities/fetching.py | 35 ++-- pytorch_lightning/utilities/finite_checks.py | 5 +- pytorch_lightning/utilities/grads.py | 4 +- pytorch_lightning/utilities/imports.py | 8 +- pytorch_lightning/utilities/memory.py | 6 +- pytorch_lightning/utilities/metrics.py | 5 +- pytorch_lightning/utilities/model_summary.py | 40 ++-- pytorch_lightning/utilities/parsing.py | 47 ++--- pytorch_lightning/utilities/seed.py | 18 +- pytorch_lightning/utilities/warnings.py | 2 +- pytorch_lightning/utilities/xla_device.py | 11 +- requirements/collect_env_details.py | 3 +- tests/accelerators/ddp_model.py | 4 +- .../test_accelerator_connector.py | 2 +- tests/accelerators/test_common.py | 7 +- tests/accelerators/test_cpu.py | 16 +- tests/accelerators/test_ddp.py | 12 +- tests/accelerators/test_dp.py | 11 +- tests/accelerators/test_ipu.py | 26 +-- tests/accelerators/test_multi_nodes_gpu.py | 8 +- tests/accelerators/test_tpu_backend.py | 11 +- tests/base/model_optimizers.py | 8 +- tests/base/model_template.py | 3 +- tests/base/model_test_dataloaders.py | 2 +- tests/base/model_test_epoch_ends.py | 8 +- tests/base/model_test_steps.py | 12 +- tests/base/model_train_dataloaders.py | 4 +- tests/base/model_train_steps.py | 8 +- tests/base/model_valid_dataloaders.py | 2 +- tests/base/model_valid_epoch_ends.py | 10 +- tests/base/model_valid_steps.py | 12 +- tests/callbacks/test_callback_hook_outputs.py | 4 +- tests/callbacks/test_early_stopping.py | 10 +- tests/callbacks/test_finetuning_callback.py | 35 ++-- tests/callbacks/test_gpu_stats_monitor.py | 20 +- tests/callbacks/test_lr_monitor.py | 4 +- tests/callbacks/test_progress_bar.py | 29 ++- tests/callbacks/test_pruning.py | 7 +- tests/callbacks/test_quantization.py | 8 +- tests/callbacks/test_stochastic_weight_avg.py | 4 +- tests/callbacks/test_timer.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 56 ++---- .../checkpointing/test_trainer_checkpoint.py | 8 +- tests/conftest.py | 7 +- tests/core/test_datamodules.py | 6 +- tests/core/test_lightning_optimizer.py | 36 ++-- tests/core/test_metric_result_integration.py | 14 +- tests/core/test_results.py | 2 +- tests/deprecated_api/__init__.py | 2 +- tests/deprecated_api/test_remove_1-5.py | 2 +- tests/deprecated_api/test_remove_1-6.py | 2 +- tests/deprecated_api/test_remove_1-7.py | 2 +- tests/deprecated_api/test_remove_2-0.py | 2 +- tests/helpers/boring_model.py | 4 +- tests/helpers/dataloaders.py | 4 +- tests/helpers/datasets.py | 7 +- tests/helpers/runif.py | 11 +- tests/helpers/test_models.py | 2 +- tests/loggers/test_all.py | 13 +- tests/loggers/test_base.py | 6 +- tests/loggers/test_csv.py | 8 +- tests/loggers/test_mlflow.py | 16 +- tests/loggers/test_neptune.py | 2 +- tests/loggers/test_tensorboard.py | 24 +-- tests/loggers/test_wandb.py | 16 +- .../data/horovod/train_default_model.py | 3 +- tests/models/test_cpu.py | 11 +- tests/models/test_gpu.py | 6 +- tests/models/test_hooks.py | 8 +- tests/models/test_hparams.py | 34 +--- tests/models/test_onnx.py | 16 +- tests/models/test_restore.py | 5 +- tests/models/test_torchscript.py | 8 +- tests/models/test_tpu.py | 26 +-- tests/plugins/test_amp_plugins.py | 12 +- tests/plugins/test_checkpoint_io_plugin.py | 4 +- ..._ddp_fully_sharded_with_full_state_dict.py | 20 +- tests/plugins/test_deepspeed_plugin.py | 107 ++++------ tests/plugins/test_sharded_plugin.py | 44 +--- tests/profiler/test_profiler.py | 41 ++-- .../connectors/test_callback_connector.py | 6 +- .../test_multiple_eval_dataloaders.py | 4 +- tests/trainer/flags/test_env_vars.py | 8 +- tests/trainer/flags/test_fast_dev_run.py | 6 +- tests/trainer/flags/test_min_max_epochs.py | 4 +- tests/trainer/flags/test_overfit_batches.py | 8 +- .../logging_/test_distributed_logging.py | 22 +- .../logging_/test_eval_loop_logging.py | 28 +-- .../trainer/logging_/test_logger_connector.py | 9 +- .../logging_/test_train_loop_logging.py | 36 +--- tests/trainer/loops/test_evaluation_loop.py | 10 +- .../loops/test_evaluation_loop_flow.py | 20 +- tests/trainer/loops/test_flow_warnings.py | 4 +- tests/trainer/loops/test_training_loop.py | 8 +- .../loops/test_training_loop_flow_dict.py | 20 +- .../loops/test_training_loop_flow_scalar.py | 3 +- .../optimization/test_manual_optimization.py | 60 ++---- .../optimization/test_multiple_optimizers.py | 18 +- tests/trainer/optimization/test_optimizers.py | 58 ++---- tests/trainer/properties/test_get_model.py | 12 +- tests/trainer/properties/test_log_dir.py | 26 +-- tests/trainer/test_config_validator.py | 16 +- tests/trainer/test_data_loading.py | 4 +- tests/trainer/test_dataloaders.py | 48 +++-- tests/trainer/test_progress.py | 6 +- tests/trainer/test_states.py | 4 +- tests/trainer/test_supporters.py | 26 ++- tests/trainer/test_trainer.py | 72 +++---- tests/trainer/test_trainer_cli.py | 6 +- tests/trainer/test_trainer_tricks.py | 4 +- tests/tuner/test_lr_finder.py | 21 +- tests/tuner/test_scale_batch_size.py | 6 +- tests/utilities/test_all_gather_grad.py | 2 +- tests/utilities/test_argparse.py | 15 +- tests/utilities/test_auto_restart.py | 35 ++-- tests/utilities/test_cli.py | 9 +- .../test_deepspeed_collate_checkpoint.py | 4 +- tests/utilities/test_dtype_device_mixin.py | 6 +- tests/utilities/test_fetching.py | 38 ++-- tests/utilities/test_imports.py | 2 +- tests/utilities/test_model_summary.py | 8 +- tests/utilities/test_parsing.py | 6 +- tests/utilities/test_seed.py | 17 +- tests/utilities/test_warnings.py | 4 +- tests/utilities/test_xla_device_utils.py | 6 +- 260 files changed, 1746 insertions(+), 2692 deletions(-) create mode 100644 docs/source/links.rst rename pytorch_lightning/callbacks/progress/{progress.py => tqdm_progress.py} (83%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e91c0c4033c34..50a3e771816f4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,6 +52,12 @@ repos: args: [--py36-plus] name: Upgrade code + - repo: https://github.com/myint/docformatter + rev: v1.4 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120] + - repo: https://github.com/asottile/yesqa rev: v1.2.3 hooks: diff --git a/_notebooks b/_notebooks index 6100885854c80..4fe3370eac9c4 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit 6100885854c803458886c731cddd6bd67498c0a1 +Subproject commit 4fe3370eac9c448eceb36b835ff49ca30de7d404 diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py index ab3b6ebfb84f4..9c2c3fb72e80e 100644 --- a/benchmarks/test_basic_parity.py +++ b/benchmarks/test_basic_parity.py @@ -59,9 +59,7 @@ def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: f def test_pytorch_parity( tmpdir, cls_model: LightningModule, max_diff_speed: float, max_diff_memory: float, num_epochs: int, num_runs: int ): - """ - Verify that the same pytorch and lightning models achieve the same results - """ + """Verify that the same pytorch and lightning models achieve the same results.""" lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs) vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs) @@ -88,9 +86,7 @@ def _hook_memory(): def measure_loops(cls_model, kind, num_runs=10, num_epochs=10): - """ - Returns an array with the last loss from each epoch for each run - """ + """Returns an array with the last loss from each epoch for each run.""" hist_losses = [] hist_durations = [] hist_memory = [] diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py index b86c2d4800d5e..b6bcb658dcde9 100644 --- a/benchmarks/test_sharded_parity.py +++ b/benchmarks/test_sharded_parity.py @@ -25,9 +25,7 @@ class SeedTrainLoaderModel(BoringModel): - """ - Overrides training loader to ensure we enforce the same seed for all DDP processes. - """ + """Overrides training loader to ensure we enforce the same seed for all DDP processes.""" def train_dataloader(self): seed_everything(42) @@ -87,8 +85,7 @@ def configure_optimizers(self): def record_ddp_fit_model_stats(trainer, model, use_cuda): - """ - Helper to calculate wall clock time for fit + max allocated memory. + """Helper to calculate wall clock time for fit + max allocated memory. Args: trainer: The trainer object. @@ -123,9 +120,8 @@ def plugin_parity_test( precision: int = 32, max_percent_speed_diff: float = 0.1, ): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. + """Ensures that the trained model is identical to the standard DDP implementation. Also checks for speed/memory + regressions, we should expect always less memory but performance to fluctuate. Args: model_cls: Model class to use for test. @@ -134,7 +130,6 @@ def plugin_parity_test( precision: Whether to use AMP or normal FP32 training. max_percent_speed_diff: The maximum speed difference compared to normal DDP training. This is more a safety net for variability in CI which can vary in speed, not for benchmarking. - """ # Train normal DDP diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 49b5556d7a922..df70b2b0a3944 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -1,6 +1,8 @@ API References ============== +.. include:: links.rst + Accelerator API --------------- diff --git a/docs/source/conf.py b/docs/source/conf.py index 88c22059b3fe1..f68338803ce53 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -293,7 +293,7 @@ def setup(app): # Ignoring Third-party packages # https://stackoverflow.com/questions/15889621/sphinx-how-to-exclude-imports-in-automodule def package_list_from_file(file): - """List up package name (not containing version and extras) from a package list file""" + """List up package name (not containing version and extras) from a package list file.""" mocked_packages = [] with open(file) as fp: for ln in fp.readlines(): diff --git a/docs/source/extensions/plugins.rst b/docs/source/extensions/plugins.rst index f2ad1ac8e8b92..a7d88505202e2 100644 --- a/docs/source/extensions/plugins.rst +++ b/docs/source/extensions/plugins.rst @@ -4,6 +4,8 @@ Plugins ####### +.. include:: ../links.rst + Plugins allow custom integrations to the internals of the Trainer such as a custom precision or distributed implementation. diff --git a/docs/source/links.rst b/docs/source/links.rst new file mode 100644 index 0000000000000..64ec918bf8e25 --- /dev/null +++ b/docs/source/links.rst @@ -0,0 +1,2 @@ +.. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ +.. _Kubeflow: https://www.kubeflow.org diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 179eafdc139a1..a78e728af1bf1 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -MNIST autoencoder example. +"""MNIST autoencoder example. -To run: -python autoencoder.py --trainer.max_epochs=50 +To run: python autoencoder.py --trainer.max_epochs=50 """ import torch diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 26d83522ce6f3..f3e0297d0ed15 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -MNIST backbone image classifier example. +"""MNIST backbone image classifier example. -To run: -python backbone_image_classifier.py --trainer.max_epochs=50 +To run: python backbone_image_classifier.py --trainer.max_epochs=50 """ from typing import Optional diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 8eda150cbb620..2cbc35f6b4805 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -45,9 +45,7 @@ class ExternalMNISTInputIterator: - """ - This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches - """ + """This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches.""" def __init__(self, mnist_ds, batch_size): self.batch_size = batch_size @@ -73,9 +71,7 @@ def __next__(self): class ExternalSourcePipeline(Pipeline): - """ - This DALI pipeline class just contains the MNIST iterator - """ + """This DALI pipeline class just contains the MNIST iterator.""" def __init__(self, batch_size, eii, num_threads, device_id): super().__init__(batch_size, num_threads, device_id, seed=12) @@ -88,10 +84,8 @@ def define_graph(self): class DALIClassificationLoader(DALIClassificationIterator): - """ - This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function - so that we can call `len()` on it - """ + """This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function so that we can + call `len()` on it.""" def __init__( self, diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 693514e0a3620..68823eeac7bba 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -41,8 +41,7 @@ class MNISTDataModule(LightningDataModule): - """ - Standard MNIST, train, val, test splits and transforms + """Standard MNIST, train, val, test splits and transforms. >>> MNISTDataModule() # doctest: +ELLIPSIS <...mnist_datamodule.MNISTDataModule object at ...> @@ -100,14 +99,14 @@ def prepare_data(self): MNIST(self.data_dir, train=False, download=True) def setup(self, stage: Optional[str] = None): - """Split the train and valid dataset""" + """Split the train and valid dataset.""" extra = dict(transform=self.default_transforms) if self.default_transforms else {} dataset = MNIST(self.data_dir, train=True, download=False, **extra) train_length = len(dataset) self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split]) def train_dataloader(self): - """MNIST train set removes a subset to use for validation""" + """MNIST train set removes a subset to use for validation.""" loader = DataLoader( self.dataset_train, batch_size=self.batch_size, @@ -119,7 +118,7 @@ def train_dataloader(self): return loader def val_dataloader(self): - """MNIST val set uses a subset of the training set for validation""" + """MNIST val set uses a subset of the training set for validation.""" loader = DataLoader( self.dataset_val, batch_size=self.batch_size, @@ -131,7 +130,7 @@ def val_dataloader(self): return loader def test_dataloader(self): - """MNIST test set uses the test split""" + """MNIST test set uses the test split.""" extra = dict(transform=self.test_transforms) if self.test_transforms else {} dataset = MNIST(self.data_dir, train=False, download=False, **extra) loader = DataLoader( diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py index ab12fb623ecc3..598287ea640b0 100644 --- a/pl_examples/basic_examples/profiler_example.py +++ b/pl_examples/basic_examples/profiler_example.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This script will generate 2 traces: one for `training_step` and one for `validation_step`. -The traces can be visualized in 2 ways: +"""This script will generate 2 traces: one for `training_step` and one for `validation_step`. The traces can be +visualized in 2 ways: + * With Chrome: 1. Open Chrome and copy/paste this url: `chrome://tracing/`. 2. Once tracing opens, click on `Load` at the top-right and load one of the generated traces. diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 5fdcb8d8c3bb2..8e2850e17cd8a 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -MNIST simple image classifier example. +"""MNIST simple image classifier example. -To run: -python simple_image_classifier.py --trainer.max_epochs=50 +To run: python simple_image_classifier.py --trainer.max_epochs=50 """ import torch diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 6a73dc5ee3b91..2fba1d8ad1759 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Computer vision example on Transfer Learning. -This computer vision example illustrates how one could fine-tune a pre-trained -network (by default, a ResNet50 is used) using pytorch-lightning. For the sake -of this example, the 'cats and dogs dataset' (~60MB, see `DATA_URL` below) and -the proposed network (denoted by `TransferLearningModel`, see below) is -trained for 15 epochs. +"""Computer vision example on Transfer Learning. This computer vision example illustrates how one could fine-tune a +pre-trained network (by default, a ResNet50 is used) using pytorch-lightning. For the sake of this example, the +'cats and dogs dataset' (~60MB, see `DATA_URL` below) and the proposed network (denoted by `TransferLearningModel`, +see below) is trained for 15 epochs. The training consists of three stages. @@ -91,7 +89,7 @@ def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer class CatDogImageDataModule(LightningDataModule): def __init__(self, dl_path: Union[str, Path] = "data", num_workers: int = 0, batch_size: int = 8): - """CatDogImageDataModule + """CatDogImageDataModule. Args: dl_path: root directory where to download the data @@ -166,7 +164,7 @@ def __init__( num_workers: int = 6, **kwargs, ) -> None: - """TransferLearningModel + """TransferLearningModel. Args: backbone: Name (as in ``torchvision.models``) of the feature extractor @@ -208,7 +206,10 @@ def __build_model(self): self.loss_func = F.binary_cross_entropy_with_logits def forward(self, x): - """Forward pass. Returns logits.""" + """Forward pass. + + Returns logits. + """ # 1. Feature extraction: x = self.feature_extractor(x) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 2ac51befec29a..48492c8ce7f04 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -To run this template just do: -python generative_adversarial_net.py +"""To run this template just do: python generative_adversarial_net.py. After a few epochs, launch TensorBoard to see the images being generated at every batch: diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index eb3dbe13132f6..baefc7c9440f9 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py +"""This example is largely adapted from https://github.com/pytorch/examples/blob/master/imagenet/main.py. Before you can run this example, you will need to download the ImageNet dataset manually from the `official website `_ and place it into a folder `path/to/imagenet`. @@ -28,7 +27,6 @@ .. code-block: bash python imagenet.py --help - """ import os from argparse import ArgumentParser, Namespace @@ -112,7 +110,7 @@ def validation_step(self, batch, batch_idx): @staticmethod def __accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" + """Computes the accuracy over the k top predictions for the specified values of k.""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 06ddb646e1db4..123f89cc641d7 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Deep Reinforcement Learning: Deep Q-network (DQN) +"""Deep Reinforcement Learning: Deep Q-network (DQN) The template illustrates using Lightning for Reinforcement Learning. The example builds a basic DQN using the classic CartPole environment. @@ -50,8 +49,7 @@ class DQN(nn.Module): - """ - Simple MLP network + """Simple MLP network. >>> DQN(10, 5) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE DQN( @@ -78,8 +76,7 @@ def forward(self, x): class ReplayBuffer: - """ - Replay Buffer for storing past experiences allowing the agent to learn from them + """Replay Buffer for storing past experiences allowing the agent to learn from them. >>> ReplayBuffer(5) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.ReplayBuffer object at ...> @@ -96,8 +93,7 @@ def __len__(self) -> int: return len(self.buffer) def append(self, experience: Experience) -> None: - """ - Add experience to the buffer + """Add experience to the buffer. Args: experience: tuple (state, action, reward, done, new_state) @@ -118,9 +114,7 @@ def sample(self, batch_size: int) -> Tuple: class RLDataset(IterableDataset): - """ - Iterable Dataset containing the ExperienceBuffer - which will be updated with new experiences during training + """Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training. >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS <...reinforce_learn_Qnet.RLDataset object at ...> @@ -142,8 +136,7 @@ def __iter__(self) -> Iterator: class Agent: - """ - Base Agent class handling the interaction with the environment + """Base Agent class handling the interaction with the environment. >>> env = gym.make("CartPole-v1") >>> buffer = ReplayBuffer(10) @@ -163,13 +156,11 @@ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: self.state = self.env.reset() def reset(self) -> None: - """Resets the environment and updates the state""" + """Resets the environment and updates the state.""" self.state = self.env.reset() def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: - """ - Using the given network, decide what action to carry out - using an epsilon-greedy policy + """Using the given network, decide what action to carry out using an epsilon-greedy policy. Args: net: DQN network @@ -195,8 +186,7 @@ def get_action(self, net: nn.Module, epsilon: float, device: str) -> int: @torch.no_grad() def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") -> Tuple[float, bool]: - """ - Carries out a single interaction step between the agent and the environment + """Carries out a single interaction step between the agent and the environment. Args: net: DQN network @@ -223,7 +213,7 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = "cpu") - class DQNLightning(pl.LightningModule): - """Basic DQN Model + """Basic DQN Model. >>> DQNLightning(env="CartPole-v1") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE DQNLightning( @@ -277,9 +267,8 @@ def __init__( self.populate(self.warm_start_steps) def populate(self, steps: int = 1000) -> None: - """ - Carries out several random steps through the environment to initially fill - up the replay buffer with experiences + """Carries out several random steps through the environment to initially fill up the replay buffer with + experiences. Args: steps: number of random steps to populate the buffer with @@ -288,8 +277,7 @@ def populate(self, steps: int = 1000) -> None: self.agent.play_step(self.net, epsilon=1.0) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Passes in a state `x` through the network and gets the `q_values` of each action as an output + """Passes in a state `x` through the network and gets the `q_values` of each action as an output. Args: x: environment state @@ -301,8 +289,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return output def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """ - Calculates the mse loss using a mini batch from the replay buffer + """Calculates the mse loss using a mini batch from the replay buffer. Args: batch: current mini batch of replay data @@ -324,9 +311,8 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor return nn.MSELoss()(state_action_values, expected_state_action_values) def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> OrderedDict: - """ - Carries out a single step through the environment to update the replay buffer. - Then calculates loss based on the minibatch received + """Carries out a single step through the environment to update the replay buffer. Then calculates loss + based on the minibatch received. Args: batch: current mini batch of replay data @@ -362,22 +348,22 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O return OrderedDict({"loss": loss, "log": log, "progress_bar": log}) def configure_optimizers(self) -> List[Optimizer]: - """Initialize Adam optimizer""" + """Initialize Adam optimizer.""" optimizer = optim.Adam(self.net.parameters(), lr=self.lr) return [optimizer] def __dataloader(self) -> DataLoader: - """Initialize the Replay Buffer dataset used for retrieving experiences""" + """Initialize the Replay Buffer dataset used for retrieving experiences.""" dataset = RLDataset(self.buffer, self.episode_length) dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size, sampler=None) return dataloader def train_dataloader(self) -> DataLoader: - """Get train loader""" + """Get train loader.""" return self.__dataloader() def get_device(self, batch) -> str: - """Retrieve device currently being used by minibatch""" + """Retrieve device currently being used by minibatch.""" return batch[0].device.index if self.on_gpu else "cpu" @staticmethod diff --git a/pl_examples/domain_templates/reinforce_learn_ppo.py b/pl_examples/domain_templates/reinforce_learn_ppo.py index 2eefcb5abd6d6..e9d5378d0ad8b 100644 --- a/pl_examples/domain_templates/reinforce_learn_ppo.py +++ b/pl_examples/domain_templates/reinforce_learn_ppo.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -PyTorch Lightning implementation of Proximal Policy Optimization (PPO) +"""PyTorch Lightning implementation of Proximal Policy Optimization (PPO) + Paper authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov @@ -42,9 +42,7 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): - """ - Simple Multi-Layer Perceptron network - """ + """Simple Multi-Layer Perceptron network.""" network = nn.Sequential( nn.Linear(input_shape[0], hidden_size), nn.ReLU(), @@ -57,10 +55,8 @@ def create_mlp(input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): class ActorCategorical(nn.Module): - """ - Policy network, for discrete action spaces, which returns a distribution - and an action given an observation - """ + """Policy network, for discrete action spaces, which returns a distribution and an action given an + observation.""" def __init__(self, actor_net): """ @@ -80,8 +76,7 @@ def forward(self, states): return pi, actions def get_log_prob(self, pi: Categorical, actions: torch.Tensor): - """ - Takes in a distribution and actions and returns log prob of actions under the distribution + """Takes in a distribution and actions and returns log prob of actions under the distribution. Args: pi: torch distribution @@ -94,10 +89,8 @@ def get_log_prob(self, pi: Categorical, actions: torch.Tensor): class ActorContinous(nn.Module): - """ - Policy network, for continous action spaces, which returns a distribution - and an action given an observation - """ + """Policy network, for continous action spaces, which returns a distribution and an action given an + observation.""" def __init__(self, actor_net, act_dim): """ @@ -119,8 +112,7 @@ def forward(self, states): return pi, actions def get_log_prob(self, pi: Normal, actions: torch.Tensor): - """ - Takes in a distribution and actions and returns log prob of actions under the distribution + """Takes in a distribution and actions and returns log prob of actions under the distribution. Args: pi: torch distribution @@ -133,12 +125,11 @@ def get_log_prob(self, pi: Normal, actions: torch.Tensor): class ExperienceSourceDataset(IterableDataset): - """ - Implementation from PyTorch Lightning Bolts: - https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/experience_source.py + """Implementation from PyTorch Lightning Bolts: https://github.com/PyTorchLightning/lightning- + bolts/blob/master/pl_bolts/datamodules/experience_source.py. - Basic experience source dataset. Takes a generate_batch function that returns an iterator. - The logic for the experience source and how the batch is generated is defined the Lightning model itself + Basic experience source dataset. Takes a generate_batch function that returns an iterator. The logic for the + experience source and how the batch is generated is defined the Lightning model itself """ def __init__(self, generate_batch: Callable): @@ -150,8 +141,7 @@ def __iter__(self) -> Iterator: class PPOLightning(pl.LightningModule): - """ - PyTorch Lightning implementation of PPO. + """PyTorch Lightning implementation of PPO. Example: model = PPOLightning("CartPole-v0") @@ -236,8 +226,7 @@ def __init__( self.state = torch.FloatTensor(self.env.reset()) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Passes in a state x through the network and returns the policy and a sampled action + """Passes in a state x through the network and returns the policy and a sampled action. Args: x: environment state @@ -251,7 +240,7 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Te return pi, action, value def discount_rewards(self, rewards: List[float], discount: float) -> List[float]: - """Calculate the discounted rewards of all rewards in list + """Calculate the discounted rewards of all rewards in list. Args: rewards: list of rewards/advantages @@ -271,7 +260,7 @@ def discount_rewards(self, rewards: List[float], discount: float) -> List[float] return list(reversed(cumul_reward)) def calc_advantage(self, rewards: List[float], values: List[float], last_value: float) -> List[float]: - """Calculate the advantage given rewards, state values, and the last value of episode + """Calculate the advantage given rewards, state values, and the last value of episode. Args: rewards: list of episode rewards @@ -387,8 +376,7 @@ def critic_loss(self, state, action, logp_old, qval, adv) -> torch.Tensor: return loss_critic def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx, optimizer_idx): - """ - Carries out a single update to actor and critic network from a batch of replay buffer. + """Carries out a single update to actor and critic network from a batch of replay buffer. Args: batch: batch of replay buffer/trajectory data @@ -420,28 +408,26 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx, opt return loss_critic def configure_optimizers(self) -> List[Optimizer]: - """Initialize Adam optimizer""" + """Initialize Adam optimizer.""" optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=self.lr_actor) optimizer_critic = torch.optim.Adam(self.critic.parameters(), lr=self.lr_critic) return optimizer_actor, optimizer_critic def optimizer_step(self, *args, **kwargs): - """ - Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic - for each data sample. - """ + """Run 'nb_optim_iters' number of iterations of gradient descent on actor and critic for each data + sample.""" for _ in range(self.nb_optim_iters): super().optimizer_step(*args, **kwargs) def _dataloader(self) -> DataLoader: - """Initialize the Replay Buffer dataset used for retrieving experiences""" + """Initialize the Replay Buffer dataset used for retrieving experiences.""" dataset = ExperienceSourceDataset(self.generate_trajectory_samples) dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size) return dataloader def train_dataloader(self) -> DataLoader: - """Get train loader""" + """Get train loader.""" return self._dataloader() @staticmethod diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index d0ffb5a757155..d5a10c4faa6c5 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -33,7 +33,8 @@ def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)): - """Create synthetic dataset with random images, just to simulate that the dataset have been already downloaded.""" + """Create synthetic dataset with random images, just to simulate that the dataset have been already + downloaded.""" path_dir_images = os.path.join(path_dir, KITTI.IMAGE_PATH) path_dir_masks = os.path.join(path_dir, KITTI.MASK_PATH) for p_dir in (path_dir_images, path_dir_masks): @@ -46,8 +47,8 @@ def _create_synth_kitti_dataset(path_dir: str, image_dims: tuple = (1024, 512)): class KITTI(Dataset): - """ - Class for KITTI Semantic Segmentation Benchmark dataset + """Class for KITTI Semantic Segmentation Benchmark dataset. + Dataset link - http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 There are 34 classes in the given labels. However, not all of them are useful for training @@ -128,9 +129,7 @@ def __getitem__(self, idx): return img, mask def encode_segmap(self, mask): - """ - Sets void classes to zero so they won't be considered for training - """ + """Sets void classes to zero so they won't be considered for training.""" for voidc in self.void_labels: mask[mask == voidc] = self.ignore_index for validc in self.valid_labels: @@ -140,9 +139,7 @@ def encode_segmap(self, mask): return mask def get_filenames(self, path): - """ - Returns a list of absolute paths to images inside given `path` - """ + """Returns a list of absolute paths to images inside given `path`""" files_list = [] for filename in os.listdir(path): files_list.append(os.path.join(path, filename)) @@ -150,8 +147,7 @@ def get_filenames(self, path): class SegModel(pl.LightningModule): - """ - Semantic Segmentation Module + """Semantic Segmentation Module. This is a basic semantic segmentation module implemented with Lightning. It uses CrossEntropyLoss as the default loss function. May be replaced with diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index eec9589065f4c..6714699883ab4 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -18,8 +18,8 @@ class UNet(nn.Module): - """ - Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation + """Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation. + Link - https://arxiv.org/abs/1505.04597 >>> UNet(num_classes=2, num_layers=3) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE @@ -73,9 +73,7 @@ def forward(self, x): class DoubleConv(nn.Module): - """ - Double Convolution and BN and ReLU - (3x3 conv -> BN -> ReLU) ** 2 + """Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2. >>> DoubleConv(4, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE DoubleConv( @@ -99,8 +97,7 @@ def forward(self, x): class Down(nn.Module): - """ - Combination of MaxPool2d and DoubleConv in series + """Combination of MaxPool2d and DoubleConv in series. >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Down( @@ -122,10 +119,8 @@ def forward(self, x): class Up(nn.Module): - """ - Upsampling (by either bilinear interpolation or transpose convolutions) - followed by concatenation of feature map from contracting path, - followed by double 3x3 convolution. + """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature + map from contracting path, followed by double 3x3 convolution. >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Up( diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 6038a8abc8f5c..f40dc9e1576cf 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -35,9 +35,7 @@ class Accelerator: - """ - The Accelerator Base Class. - An Accelerator is meant to deal with one type of Hardware. + """The Accelerator Base Class. An Accelerator is meant to deal with one type of Hardware. Currently there are accelerators for: @@ -47,7 +45,6 @@ class Accelerator: Each Accelerator gets two plugins upon initialization: One to handle differences from the training routine and one to handle different precisions. - """ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None: @@ -64,20 +61,19 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai self.optimizer_frequencies: List = [] def connect(self, model: "pl.LightningModule") -> None: - """Transfers ownership of the model to this plugin""" + """Transfers ownership of the model to this plugin.""" self.training_type_plugin.connect(model) def setup_environment(self) -> None: - """ - Setup any processes or distributed connections. - This is called before the LightningModule/DataModule setup hook - which allows the user to access the accelerator environment before setup is complete. + """Setup any processes or distributed connections. + + This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator + environment before setup is complete. """ self.training_type_plugin.setup_environment() def setup(self, trainer: "pl.Trainer") -> None: - """ - Setup plugins for the trainer fit and creates optimizers. + """Setup plugins for the trainer fit and creates optimizers. Args: trainer: the trainer instance @@ -125,9 +121,10 @@ def post_dispatch(self, trainer: "pl.Trainer") -> None: @property def model(self) -> Module: - """ - Returns the model. This can also be a wrapped LightningModule. - For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module` + """Returns the model. + + This can also be a wrapped LightningModule. For retrieving the pure LightningModule use + :attr:`Accelerator.lightning_module` """ return self.training_type_plugin.model @@ -137,27 +134,27 @@ def model(self, new_model: Module) -> None: @property def lightning_module(self) -> "pl.LightningModule": - """ - Returns the pure LightningModule. + """Returns the pure LightningModule. + To get the potentially wrapped model use :attr:`Accelerator.model` """ return self.training_type_plugin.lightning_module @property def root_device(self) -> torch.device: - """Returns the root device""" + """Returns the root device.""" return self.training_type_plugin.root_device def teardown(self) -> None: - """ - This method is called to teardown the training process. + """This method is called to teardown the training process. + It is the right place to release memory and free other resources. """ self.training_type_plugin.teardown() def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any: - """Moves the batch to the correct device. - The returned batch is of the same type as the input batch, just having all tensors on the correct device. + """Moves the batch to the correct device. The returned batch is of the same type as the input batch, just + having all tensors on the correct device. Args: batch: The batch of samples to move to the correct device @@ -238,7 +235,7 @@ def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: return self.training_type_plugin.predict_step(*step_kwargs.values()) def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: - """A hook to do something at the end of the training step + """A hook to do something at the end of the training step. Args: output: the output of the training step @@ -246,7 +243,7 @@ def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT: return self.training_type_plugin.training_step_end(output) def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: - """A hook to do something at the end of the test step + """A hook to do something at the end of the test step. Args: output: the output of the test step @@ -254,7 +251,7 @@ def test_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: return self.training_type_plugin.test_step_end(output) def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]: - """A hook to do something at the end of the validation step + """A hook to do something at the end of the validation step. Args: output: the output of the validation step @@ -284,7 +281,6 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal optimizer: the optimizer performing the step opt_idx: index of the current optimizer lambda_closure: closure calculating the loss value - """ make_optimizer_step = self.precision_plugin.pre_optimizer_step( self.lightning_module, optimizer, opt_idx, lambda_closure, **kwargs @@ -300,7 +296,7 @@ def run_optimizer_step( self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs) def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: - """Zeros all model parameter's gradients""" + """Zeros all model parameter's gradients.""" model_ref = self.lightning_module model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) @@ -310,14 +306,13 @@ def clip_gradients( clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - """clips all the optimizer parameters to the given value""" + """clips all the optimizer parameters to the given value.""" self.precision_plugin.clip_gradients( optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=self.model ) def setup_optimizers(self, trainer: "pl.Trainer") -> None: - """ - Creates optimizers and schedulers + """Creates optimizers and schedulers. Args: trainer: the Trainer, these optimizers should be connected to @@ -336,7 +331,7 @@ def setup_training_type_plugin(self) -> None: self.training_type_plugin.setup() def setup_precision_plugin(self) -> None: - """Attaches the precision plugin to the accelerator""" + """Attaches the precision plugin to the accelerator.""" model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) self.model = model self.optimizers = optimizers @@ -359,15 +354,16 @@ def scaler(self) -> Optional["GradScaler"]: return getattr(self.precision_plugin, "scaler", None) def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: - """ - Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom - plugins. + """Returns state of an optimizer. + + Allows for syncing/collating optimizer state from processes in custom plugins. """ return getattr(self.training_type_plugin, "optimizer_state", lambda x: x.state_dict())(optimizer) def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - """ - Returns state of model. Allows for syncing/collating model state from processes in custom plugins. + """Returns state of model. + + Allows for syncing/collating model state from processes in custom plugins. """ return self.training_type_plugin.lightning_module_state_dict() @@ -375,7 +371,8 @@ def barrier(self, name: Optional[str] = None) -> None: self.training_type_plugin.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: - """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed. + """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if + needed. Args: obj: Object to broadcast to all process, usually a tensor or collection of tensors. @@ -384,8 +381,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: return self.training_type_plugin.broadcast(obj, src) def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: - """ - Function to gather a tensor from several distributed processes. + """Function to gather a tensor from several distributed processes. Args: tensor: tensor of shape (batch, ...) @@ -398,7 +394,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo return self.training_type_plugin.all_gather(tensor, group=group, sync_grads=sync_grads) def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Wraps the dataloader if necessary + """Wraps the dataloader if necessary. Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` @@ -407,16 +403,16 @@ def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[I @property def results(self) -> Any: - """ - The results of the last run will be cached within the training type plugin. + """The results of the last run will be cached within the training type plugin. + In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: - """ - Provide hook to create modules in a distributed aware context. This is useful for when we'd like to + """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to. + shard the model instantly - useful for extremely large models. Can save memory and initialization time. @@ -437,8 +433,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: @property def call_configure_sharded_model_hook(self) -> bool: - """ - Allow model parallel hook to be called in suitable environments determined by the training type plugin. + """Allow model parallel hook to be called in suitable environments determined by the training type plugin. This is useful for when we want to shard the model once within fit. Returns: @@ -452,10 +447,9 @@ def call_configure_sharded_model_hook(self, mode: bool) -> None: @property def setup_optimizers_in_pre_dispatch(self) -> bool: - """ - Override to delay setting optimizers and schedulers till after dispatch. - This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. - However this may break certain precision plugins such as APEX which require optimizers to be set. + """Override to delay setting optimizers and schedulers till after dispatch. This is useful when the + `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain + precision plugins such as APEX which require optimizers to be set. Returns: If True, delay setup optimizers until `pre_dispatch`, else call within `setup`. @@ -464,9 +458,8 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: @property def restore_checkpoint_after_pre_dispatch(self) -> bool: - """ - Override to delay restoring from checkpoint till after pre-dispatch. - This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin + requires all the setup hooks to run before loading checkpoint. Returns: If true, restore checkpoint after pre_dispatch. @@ -509,7 +502,5 @@ def on_train_end(self) -> None: return self.training_type_plugin.on_train_end() def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the training loop before anything happens for that batch. - """ + """Called in the training loop before anything happens for that batch.""" return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index b67b304424d8f..97cf4a5ddb849 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -35,11 +35,11 @@ class Callback(abc.ABC): @property def state_key(self) -> str: - """ - Identifier for the state of the callback. Used to store and retrieve a callback's state from the - checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to - provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of - multiple instances of that callback. + """Identifier for the state of the callback. + + Used to store and retrieve a callback's state from the checkpoint dictionary by + ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to provide a unique state key if 1) + the callback has state and 2) it is desired to maintain the state of multiple instances of that callback. """ return self.__class__.__qualname__ @@ -49,9 +49,8 @@ def _legacy_state_key(self) -> Type["Callback"]: return type(self) def _generate_state_key(self, **kwargs: Any) -> str: - """ - Formats a set of key-value pairs into a state key string with the callback class name prefixed. - Useful for defining a :attr:`state_key`. + """Formats a set of key-value pairs into a state key string with the callback class name prefixed. Useful + for defining a :attr:`state_key`. Args: **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. @@ -59,18 +58,18 @@ def _generate_state_key(self, **kwargs: Any) -> str: return f"{self.__class__.__qualname__}{repr(kwargs)}" def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called before configure sharded model""" + """Called before configure sharded model.""" def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called before accelerator is being setup""" + """Called before accelerator is being setup.""" pass def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - """Called when fit, validate, test, predict, or tune begins""" + """Called when fit, validate, test, predict, or tune begins.""" pass def teardown(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: - """Called when fit, validate, test, predict, or tune ends""" + """Called when fit, validate, test, predict, or tune ends.""" pass def on_init_start(self, trainer: "pl.Trainer") -> None: @@ -82,11 +81,11 @@ def on_init_end(self, trainer: "pl.Trainer") -> None: pass def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when fit begins""" + """Called when fit begins.""" pass def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when fit ends""" + """Called when fit ends.""" pass def on_sanity_check_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -279,8 +278,7 @@ def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", e def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> dict: - """ - Called when saving a model checkpoint, use to persist state. + """Called when saving a model checkpoint, use to persist state. Args: trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ecb46eab446d4..04c24059da7e2 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -194,10 +194,7 @@ def on_validation_end(self, trainer, pl_module) -> None: self._run_early_stopping_check(trainer) def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None: - """ - Checks whether the early stopping condition is met - and if so tells the trainer to stop the training. - """ + """Checks whether the early stopping condition is met and if so tells the trainer to stop the training.""" logs = trainer.callback_metrics if trainer.fast_dev_run or not self._validate_condition_metric( # disable early_stopping with fast_dev_run diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 2636896d2fd47..fb412c2e71435 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -109,9 +109,8 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") - @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: - """ - This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules - with no children) and parent modules that have parameters directly themselves. + """This function is used to flatten a module or an iterable of modules into a list of its leaf modules + (modules with no children) and parent modules that have parameters directly themselves. Args: modules: A given module or an iterable of modules @@ -157,8 +156,7 @@ def filter_params( @staticmethod def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None: - """ - Unfreezes the parameters of the provided modules + """Unfreezes the parameters of the provided modules. Args: modules: A given module or an iterable of modules @@ -171,8 +169,7 @@ def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> @staticmethod def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None: - """ - Freezes the parameters of the provided modules + """Freezes the parameters of the provided modules. Args: modules: A given module or an iterable of modules @@ -192,9 +189,7 @@ def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: @staticmethod def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List: - """ - This function is used to exclude any parameter which already exists in - this optimizer + """This function is used to exclude any parameter which already exists in this optimizer. Args: optimizer: Optimizer used for parameter exclusion @@ -229,8 +224,7 @@ def unfreeze_and_add_param_group( initial_denom_lr: float = 10.0, train_bn: bool = True, ) -> None: - """ - Unfreezes a module and adds its parameters to an optimizer. + """Unfreezes a module and adds its parameters to an optimizer. Args: @@ -298,15 +292,11 @@ def on_train_epoch_start(self, trainer, pl_module): self._store(pl_module, opt_idx, num_param_groups, current_param_groups) def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int): - """ - Override to add your unfreeze logic - """ + """Override to add your unfreeze logic.""" raise NotImplementedError def freeze_before_training(self, pl_module: "pl.LightningModule"): - """ - Override to add your freeze logic - """ + """Override to add your freeze logic.""" raise NotImplementedError diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index 3a8d110d59376..e09af1ea57f67 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -180,7 +180,7 @@ def on_train_batch_end( @staticmethod def _get_gpu_ids(device_ids: List[int]) -> List[str]: - """Get the unmasked real GPU IDs""" + """Get the unmasked real GPU IDs.""" # All devices if `CUDA_VISIBLE_DEVICES` unset default = ",".join(str(i) for i in range(torch.cuda.device_count())) cuda_visible_devices: List[str] = os.getenv("CUDA_VISIBLE_DEVICES", default=default).split(",") @@ -216,7 +216,7 @@ def _to_float(x: str) -> float: def _parse_gpu_stats( device_ids: List[int], stats: List[List[float]], keys: List[Tuple[str, str]] ) -> Dict[str, float]: - """Parse the gpu stats into a loggable dict""" + """Parse the gpu stats into a loggable dict.""" logs = {} for i, device_id in enumerate(device_ids): for j, (x, unit) in enumerate(keys): @@ -224,7 +224,7 @@ def _parse_gpu_stats( return logs def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: - """Get the GPU stats keys""" + """Get the GPU stats keys.""" stat_keys = [] if self._log_stats.gpu_utilization: @@ -236,7 +236,7 @@ def _get_gpu_stat_keys(self) -> List[Tuple[str, str]]: return stat_keys def _get_gpu_device_stat_keys(self) -> List[Tuple[str, str]]: - """Get the device stats keys""" + """Get the device stats keys.""" stat_keys = [] if self._log_stats.fan_speed: diff --git a/pytorch_lightning/callbacks/lr_monitor.py b/pytorch_lightning/callbacks/lr_monitor.py index d7f350fecfeae..01fded3e8dc9c 100644 --- a/pytorch_lightning/callbacks/lr_monitor.py +++ b/pytorch_lightning/callbacks/lr_monitor.py @@ -94,10 +94,8 @@ def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = self.lr_sch_names = [] def on_train_start(self, trainer, *args, **kwargs): - """ - Called before training, determines unique names for all lr - schedulers in the case of multiple of the same type or in - the case of multiple parameter groups + """Called before training, determines unique names for all lr schedulers in the case of multiple of the + same type or in the case of multiple parameter groups. Raises: MisconfigurationException: @@ -183,9 +181,7 @@ def _extract_lr(self, param_group: Dict[str, Any], name: str) -> Dict[str, Any]: return {name: lr} def _remap_keys(self, names: List[str], token: str = "/pg1") -> None: - """ - This function is used the remap the keys if param groups for a given optimizer increased. - """ + """This function is used the remap the keys if param groups for a given optimizer increased.""" for new_name in names: old_name = new_name.replace(token, "") if token in new_name and old_name in self.lrs: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 51b27d1cc898c..62569843d2e34 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -330,8 +330,7 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self.save_checkpoint(trainer) def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """ - Save a checkpoint when training stops. + """Save a checkpoint when training stops. This will only save a checkpoint if `save_last` is also enabled as the monitor metrics logged during training/validation steps or end of epochs are not guaranteed to be available at this stage. @@ -364,10 +363,10 @@ def on_load_checkpoint( self.best_model_path = callback_state["best_model_path"] def save_checkpoint(self, trainer: "pl.Trainer") -> None: - """ - Performs the main logic around saving a checkpoint. This method runs on all ranks. - It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training, - i.e., saving only on rank 0 for data parallel use cases. + """Performs the main logic around saving a checkpoint. + + This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the + behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases. """ epoch = trainer.current_epoch global_step = trainer.global_step @@ -586,7 +585,6 @@ def format_checkpoint_name(self, metrics: Dict[str, _METRIC], ver: Optional[int] >>> ckpt = ModelCheckpoint(filename='{step}') >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0))) 'step=0.ckpt' - """ filename = self._format_checkpoint_name( self.filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name @@ -599,10 +597,8 @@ def format_checkpoint_name(self, metrics: Dict[str, _METRIC], ver: Optional[int] return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None: - """ - Determines model checkpoint save directory at runtime. References attributes from the - trainer's logger to determine where to save checkpoints. - The base path for saving weights is set in this priority: + """Determines model checkpoint save directory at runtime. References attributes from the trainer's logger + to determine where to save checkpoints. The base path for saving weights is set in this priority: 1. Checkpoint callback's path (if passed in) 2. The default_root_dir from trainer if trainer has no logger @@ -761,10 +757,8 @@ def _update_best_and_save( self._del_model(trainer, del_filepath) def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: - """ - Saves the `best_k_models` dict containing the checkpoint - paths with the corresponding scores to a YAML file. - """ + """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML + file.""" best_k = {k: v.item() for k, v in self.best_k_models.items()} if filepath is None: filepath = os.path.join(self.dirpath, "best_k_models.yaml") @@ -772,9 +766,7 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None: yaml.dump(best_k, fp) def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool: - """ - Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing - the internal state to diverge between ranks. - """ + """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal + state to diverge between ranks.""" exists = self._fs.exists(filepath) return trainer.training_type_plugin.broadcast(exists) diff --git a/pytorch_lightning/callbacks/prediction_writer.py b/pytorch_lightning/callbacks/prediction_writer.py index 841722546c005..195a7f8be1cdc 100644 --- a/pytorch_lightning/callbacks/prediction_writer.py +++ b/pytorch_lightning/callbacks/prediction_writer.py @@ -40,8 +40,7 @@ def on_epoch(self) -> bool: class BasePredictionWriter(Callback): - """ - Base class to implement how the predictions should be stored. + """Base class to implement how the predictions should be stored. Args: write_interval: When to write. diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 441c79a5ab1c6..3fa7b1afe6b44 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,5 @@ """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 -from pytorch_lightning.callbacks.progress.progress import ProgressBar, tqdm # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401 diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index c1963345fd2e4..fe4c89a66d665 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -56,51 +56,51 @@ def trainer(self): @property def train_batch_idx(self) -> int: - """ - The current batch index being processed during training. + """The current batch index being processed during training. + Use this to update your progress bar. """ return self._train_batch_idx @property def val_batch_idx(self) -> int: - """ - The current batch index being processed during validation. + """The current batch index being processed during validation. + Use this to update your progress bar. """ return self._val_batch_idx @property def test_batch_idx(self) -> int: - """ - The current batch index being processed during testing. + """The current batch index being processed during testing. + Use this to update your progress bar. """ return self._test_batch_idx @property def predict_batch_idx(self) -> int: - """ - The current batch index being processed during predicting. + """The current batch index being processed during predicting. + Use this to update your progress bar. """ return self._predict_batch_idx @property def total_train_batches(self) -> int: - """ - The total number of training batches during training, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - training dataloader is of infinite size. + """The total number of training batches during training, which may change from epoch to epoch. + + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training + dataloader is of infinite size. """ return self.trainer.num_training_batches @property def total_val_batches(self) -> int: - """ - The total number of validation batches during validation, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - validation dataloader is of infinite size. + """The total number of validation batches during validation, which may change from epoch to epoch. + + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation + dataloader is of infinite size. """ total_val_batches = 0 if self.trainer.enable_validation: @@ -111,33 +111,33 @@ def total_val_batches(self) -> int: @property def total_test_batches(self) -> int: - """ - The total number of testing batches during testing, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - test dataloader is of infinite size. + """The total number of testing batches during testing, which may change from epoch to epoch. + + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is + of infinite size. """ return sum(self.trainer.num_test_batches) @property def total_predict_batches(self) -> int: - """ - The total number of predicting batches during testing, which may change from epoch to epoch. - Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the - predict dataloader is of infinite size. + """The total number of predicting batches during testing, which may change from epoch to epoch. + + Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader + is of infinite size. """ return sum(self.trainer.num_predict_batches) def disable(self): - """ - You should provide a way to disable the progress bar. + """You should provide a way to disable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this to disable the output on processes that have a rank different from 0, e.g., in multi-node training. """ raise NotImplementedError def enable(self): - """ - You should provide a way to enable the progress bar. + """You should provide a way to enable the progress bar. + The :class:`~pytorch_lightning.trainer.trainer.Trainer` will call this in e.g. pre-training routines like the :ref:`learning rate finder ` to temporarily enable and disable the main progress bar. @@ -145,9 +145,7 @@ def enable(self): raise NotImplementedError def print(self, *args, **kwargs): - """ - You should provide a way to print without breaking the progress bar. - """ + """You should provide a way to print without breaking the progress bar.""" print(*args, **kwargs) def on_init_end(self, trainer): diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 8182933a32079..cfd0b3a36f2ce 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -80,8 +80,7 @@ def render(self, task) -> Text: class RichProgressBar(ProgressBarBase): - """ - Create a progress bar with `rich text formatting `_. + """Create a progress bar with `rich text formatting `_. Install it with pip: diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py similarity index 83% rename from pytorch_lightning/callbacks/progress/progress.py rename to pytorch_lightning/callbacks/progress/tqdm_progress.py index aaea8676807f3..51bee9f624b87 100644 --- a/pytorch_lightning/callbacks/progress/progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -30,14 +30,16 @@ _PAD_SIZE = 5 -class tqdm(_tqdm): - """ - Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering - """ +class Tqdm(_tqdm): + def __init__(self, *args, **kwargs): + """Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from + flickering.""" + # this just to make the make docs happy, otherwise it pulls docs which has some issues... + super().__init__(*args, **kwargs) @staticmethod def format_num(n) -> str: - """Add additional padding to the formatted numbers""" + """Add additional padding to the formatted numbers.""" should_be_padded = isinstance(n, (float, str)) if not isinstance(n, str): n = _tqdm.format_num(n) @@ -54,45 +56,42 @@ def format_num(n) -> str: class ProgressBar(ProgressBarBase): r""" - This is the default progress bar used by Lightning. It prints to `stdout` using the + This is the default progress bar used by Lightning. It prints to ``stdout`` using the :mod:`tqdm` package and shows up to four different bars: - - **sanity check progress:** the progress during the sanity check run - - **main progress:** shows training + validation progress combined. It also accounts for - multiple validation runs during training when - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. - - **validation progress:** only visible during validation; - shows total progress over all validation datasets. - - **test progress:** only active when testing; shows total progress over all test datasets. + - **sanity check progress:** the progress during the sanity check run + - **main progress:** shows training + validation progress combined. It also accounts for + multiple validation runs during training when + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used. + - **validation progress:** only visible during validation; + shows total progress over all validation datasets. + - **test progress:** only active when testing; shows total progress over all test datasets. For infinite datasets, the progress bar never ends. If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the - :class:`~pytorch_lightning.trainer.trainer.Trainer`: - - Example:: + :class:`~pytorch_lightning.trainer.trainer.Trainer`. - class LitProgressBar(ProgressBar): + Example: - def init_validation_tqdm(self): - bar = super().init_validation_tqdm() - bar.set_description('running validation ...') - return bar - - bar = LitProgressBar() - trainer = Trainer(callbacks=[bar]) + >>> class LitProgressBar(ProgressBar): + ... def init_validation_tqdm(self): + ... bar = super().init_validation_tqdm() + ... bar.set_description('running validation ...') + ... return bar + ... + >>> bar = LitProgressBar() + >>> from pytorch_lightning import Trainer + >>> trainer = Trainer(callbacks=[bar]) Args: - refresh_rate: - Determines at which rate (in number of batches) the progress bars get updated. - Set it to ``0`` to disable the display. By default, the - :class:`~pytorch_lightning.trainer.trainer.Trainer` uses this implementation of the progress - bar and sets the refresh rate to the value provided to the + refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the :class:`~pytorch_lightning.trainer.trainer.Trainer` + uses this implementation of the progress bar and sets the refresh rate to the value provided to the :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the :class:`~pytorch_lightning.trainer.trainer.Trainer`. - process_position: - Set this to a value greater than ``0`` to offset the progress bars by this many lines. + process_position: Set this to a value greater than ``0`` to offset the progress bars by this many lines. This is useful when you have progress bars defined elsewhere and want to show all of them together. This corresponds to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.process_position` in the @@ -140,9 +139,9 @@ def disable(self) -> None: def enable(self) -> None: self._enabled = True - def init_sanity_tqdm(self) -> tqdm: + def init_sanity_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for the validation sanity run.""" - bar = tqdm( + bar = Tqdm( desc="Validation sanity check", position=(2 * self.process_position), disable=self.is_disabled, @@ -152,9 +151,9 @@ def init_sanity_tqdm(self) -> tqdm: ) return bar - def init_train_tqdm(self) -> tqdm: + def init_train_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for training.""" - bar = tqdm( + bar = Tqdm( desc="Training", initial=self.train_batch_idx, position=(2 * self.process_position), @@ -166,9 +165,9 @@ def init_train_tqdm(self) -> tqdm: ) return bar - def init_predict_tqdm(self) -> tqdm: + def init_predict_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for predicting.""" - bar = tqdm( + bar = Tqdm( desc="Predicting", initial=self.train_batch_idx, position=(2 * self.process_position), @@ -180,11 +179,11 @@ def init_predict_tqdm(self) -> tqdm: ) return bar - def init_validation_tqdm(self) -> tqdm: + def init_validation_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for validation.""" # The main progress bar doesn't exist in `trainer.validate()` has_main_bar = self.main_progress_bar is not None - bar = tqdm( + bar = Tqdm( desc="Validating", position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, @@ -194,9 +193,9 @@ def init_validation_tqdm(self) -> tqdm: ) return bar - def init_test_tqdm(self) -> tqdm: + def init_test_tqdm(self) -> Tqdm: """Override this to customize the tqdm bar for testing.""" - bar = tqdm( + bar = Tqdm( desc="Testing", position=(2 * self.process_position), disable=self.is_disabled, @@ -209,7 +208,7 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.main_progress_bar = tqdm(disable=True) # dummy progress bar + self.main_progress_bar = Tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) @@ -313,7 +312,7 @@ def print( def _should_update(self, current, total) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) - def _update_bar(self, bar: Optional[tqdm]) -> None: + def _update_bar(self, bar: Optional[Tqdm]) -> None: """Updates the bar by the refresh rate without overshooting.""" if bar is None: return @@ -327,13 +326,16 @@ def _update_bar(self, bar: Optional[tqdm]) -> None: def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: - """The tqdm doesn't support inf/nan values. We have to convert it to None.""" + """The tqdm doesn't support inf/nan values. + + We have to convert it to None. + """ if x is None or math.isinf(x) or math.isnan(x): return None return x -def reset(bar: tqdm, total: Optional[int] = None, current: int = 0) -> None: +def reset(bar: Tqdm, total: Optional[int] = None, current: int = 0) -> None: """Resets the tqdm bar to the desired position and sets a new total, unless it is disabled.""" if not bar.disable: bar.reset(total=convert_inf(total)) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index bfcd97252f122..1edc4f153565c 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -78,9 +78,8 @@ def __init__( verbose: int = 0, prune_on_train_epoch_end: bool = True, ) -> None: - """ - Model pruning Callback, using PyTorch's prune utilities. - This callback is responsible of pruning networks parameters during training. + """Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning + networks parameters during training. To learn more about pruning with PyTorch, please take a look at `this tutorial `_. @@ -232,18 +231,14 @@ def __init__( self._verbose = verbose def filter_parameters_to_prune(self, parameters_to_prune: _PARAM_LIST = ()) -> _PARAM_LIST: - """ - This function can be overridden to control which module to prune. - """ + """This function can be overridden to control which module to prune.""" return parameters_to_prune def _create_pruning_fn(self, pruning_fn: str, **kwargs: Any) -> Union[Callable, pytorch_prune.BasePruningMethod]: - """ - This function takes `pruning_fn`, a function name. - - IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` - ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. + """This function takes `pruning_fn`, a function name. + IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, + pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. """ pruning_fn = ( _PYTORCH_PRUNING_METHOD[pruning_fn] @@ -265,8 +260,7 @@ def _wrap_pruning_fn(pruning_fn: Callable, **kwargs: Any) -> Callable: return partial(pruning_fn, **kwargs) def make_pruning_permanent(self, module: nn.Module) -> None: - """ - Removes pruning buffers from any pruned modules + """Removes pruning buffers from any pruned modules. Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 """ @@ -443,9 +437,8 @@ def on_save_checkpoint( def sanitize_parameters_to_prune( pl_module: LightningModule, parameters_to_prune: _PARAM_LIST = (), parameter_names: Sequence[str] = () ) -> _PARAM_LIST: - """ - This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. - If ``parameters_to_prune is None``, it will be generated with all parameters of the model. + """This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If + ``parameters_to_prune is None``, it will be generated with all parameters of the model. Raises: MisconfigurationException: diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py index e1cc0d44c0ce4..140c774dde716 100644 --- a/pytorch_lightning/callbacks/quantization.py +++ b/pytorch_lightning/callbacks/quantization.py @@ -30,10 +30,9 @@ def wrap_qat_forward_context( quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None ) -> Callable: - """ - Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility - Moreover this version has the (de)quantization conditional as it may not be needed for the training all the time - """ + """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out + compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the + training all the time.""" # todo: consider using registering hook before/after forward @functools.wraps(func) def wrapper(data) -> Any: @@ -54,9 +53,8 @@ def wrapper(data) -> Any: def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable: - """ - Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out compatibility - """ + """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out + compatibility.""" # todo: consider using registering hook before/after forward @functools.wraps(func) def wrapper(data) -> Any: @@ -69,7 +67,7 @@ def wrapper(data) -> Any: def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool: - """recursive check if model has some layers denoted with '.'""" + """recursive check if model has some layers denoted with '.'.""" if "." in attribs: attrib, attribs = attribs.split(".", 1) if hasattr(obj, attrib): @@ -79,12 +77,9 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool: class QuantizationAwareTraining(Callback): - """ - Quantization allows speeding up inference and decreasing memory requirements - by performing computations and storing tensors at lower bitwidths - (such as INT8 or FLOAT16) than floating point precision. - We use native PyTorch API so for more information - see `Quantization `_. + """Quantization allows speeding up inference and decreasing memory requirements by performing computations and + storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native + PyTorch API so for more information see `PyTorch Quantization`_. .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change. @@ -95,8 +90,7 @@ class QuantizationAwareTraining(Callback): - 'fbgemm' for server inference. - 'qnnpack' for mobile inference. - - a custom `torch.quantization.QConfig - `_. + - a custom `torch.quantization.QConfig`_. observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default) and ``HistogramObserver`` as "histogram" which is more computationally expensive. @@ -127,6 +121,8 @@ def custom_trigger_last(trainer): quantize_on_fit_end: perform the quantization in `on_fit_end`. Note that once converted, the model cannot be put in training mode again. + .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training + .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig """ OBSERVER_TYPES = ("histogram", "average") diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index 12a9ac8275adb..bde9c1b5c2407 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -241,9 +241,7 @@ def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.Lig dst_param.detach().copy_(src_param.to(dst_param.device)) def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): - """ - Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154 - """ + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.""" self.momenta = {} for module in pl_module.modules(): if not isinstance(module, nn.modules.batchnorm._BatchNorm): @@ -259,9 +257,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"): module.num_batches_tracked *= 0 def reset_momenta(self): - """ - Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165 - """ + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.""" for bn_module in self.momenta: bn_module.momentum = self.momenta[bn_module] @@ -269,9 +265,7 @@ def reset_momenta(self): def update_parameters( average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: torch.LongTensor, avg_fn: _AVG_FN ): - """ - Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112 - """ + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.""" for p_swa, p_model in zip(average_model.parameters(), model.parameters()): device = p_swa.device p_swa_ = p_swa.detach() @@ -284,7 +278,5 @@ def update_parameters( def avg_fn( averaged_model_parameter: torch.Tensor, model_parameter: torch.Tensor, num_averaged: torch.LongTensor ) -> torch.FloatTensor: - """ - Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97 - """ + """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.""" return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1) diff --git a/pytorch_lightning/callbacks/timer.py b/pytorch_lightning/callbacks/timer.py index 23894a1179c1f..ef7e586654225 100644 --- a/pytorch_lightning/callbacks/timer.py +++ b/pytorch_lightning/callbacks/timer.py @@ -36,9 +36,8 @@ class Interval(LightningEnum): class Timer(Callback): - """ - The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer - if the given time limit for the training loop is reached. + """The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the + Trainer if the given time limit for the training loop is reached. Args: duration: A string in the format DD:HH:MM:SS (days, hours, minutes seconds), or a :class:`datetime.timedelta`, diff --git a/pytorch_lightning/callbacks/xla_stats_monitor.py b/pytorch_lightning/callbacks/xla_stats_monitor.py index c74be10065f86..07e3008aa6cd1 100644 --- a/pytorch_lightning/callbacks/xla_stats_monitor.py +++ b/pytorch_lightning/callbacks/xla_stats_monitor.py @@ -29,9 +29,8 @@ class XLAStatsMonitor(Callback): - """ - Automatically monitors and logs XLA stats during training stage. ``XLAStatsMonitor`` - is a callback and in order to use it you need to assign a logger in the ``Trainer``. + """Automatically monitors and logs XLA stats during training stage. ``XLAStatsMonitor`` is a callback and in + order to use it you need to assign a logger in the ``Trainer``. Args: verbose: Set to ``True`` to print average peak and free memory, and epoch time @@ -47,7 +46,6 @@ class XLAStatsMonitor(Callback): >>> from pytorch_lightning.callbacks import XLAStatsMonitor >>> xla_stats = XLAStatsMonitor() # doctest: +SKIP >>> trainer = Trainer(callbacks=[xla_stats]) # doctest: +SKIP - """ def __init__(self, verbose: bool = True) -> None: diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 268b7eeb18f22..b16c9b0c51a8b 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -26,9 +26,8 @@ class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): - """ - A DataModule standardizes the training, val, test splits, data preparation and transforms. - The main advantage is consistent data splits, data preparation and transforms across models. + """A DataModule standardizes the training, val, test splits, data preparation and transforms. The main + advantage is consistent data splits, data preparation and transforms across models. Example:: @@ -107,11 +106,9 @@ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=N @property def train_transforms(self): - """ - Optional transforms (or collection of transforms) you can apply to train dataset + """Optional transforms (or collection of transforms) you can apply to train dataset. - .. deprecated:: v1.5 - Will be removed in v1.7.0. + .. deprecated:: v1.5 Will be removed in v1.7.0. """ rank_zero_deprecation( @@ -128,11 +125,9 @@ def train_transforms(self, t): @property def val_transforms(self): - """ - Optional transforms (or collection of transforms) you can apply to validation dataset + """Optional transforms (or collection of transforms) you can apply to validation dataset. - .. deprecated:: v1.5 - Will be removed in v1.7.0. + .. deprecated:: v1.5 Will be removed in v1.7.0. """ rank_zero_deprecation( @@ -149,11 +144,9 @@ def val_transforms(self, t): @property def test_transforms(self): - """ - Optional transforms (or collection of transforms) you can apply to test dataset + """Optional transforms (or collection of transforms) you can apply to test dataset. - .. deprecated:: v1.5 - Will be removed in v1.7.0. + .. deprecated:: v1.5 Will be removed in v1.7.0. """ rank_zero_deprecation( @@ -170,11 +163,9 @@ def test_transforms(self, t): @property def dims(self): - """ - A tuple describing the shape of your data. Extra functionality exposed in ``size``. + """A tuple describing the shape of your data. Extra functionality exposed in ``size``. - .. deprecated:: v1.5 - Will be removed in v1.7.0. + .. deprecated:: v1.5 Will be removed in v1.7.0. """ rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") return self._dims @@ -185,12 +176,10 @@ def dims(self, d): self._dims = d def size(self, dim=None) -> Union[Tuple, List[Tuple]]: - """ - Return the dimension of each input either as a tuple or list of tuples. You can index this - just as you would with a torch tensor. + """Return the dimension of each input either as a tuple or list of tuples. You can index this just as you + would with a torch tensor. - .. deprecated:: v1.5 - Will be removed in v1.7.0. + .. deprecated:: v1.5 Will be removed in v1.7.0. """ rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.") diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py index 88b66635d4eff..3936f2c6e9134 100644 --- a/pytorch_lightning/core/decorators.py +++ b/pytorch_lightning/core/decorators.py @@ -20,9 +20,8 @@ def parameter_validation(fn: Callable) -> Callable: - """ - Validates that the module parameter lengths match after moving to the device. It is useful - when tying weights on TPU's. + """Validates that the module parameter lengths match after moving to the device. It is useful when tying + weights on TPU's. Args: fn: ``model_to_device`` method diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 2d569b1f71b5f..f49b2f0fc0396 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -26,82 +26,61 @@ class ModelHooks: """Hooks to be used in LightningModule.""" def on_fit_start(self) -> None: - """ - Called at the very beginning of fit. + """Called at the very beginning of fit. + If on DDP it is called on every process """ def on_fit_end(self) -> None: - """ - Called at the very end of fit. + """Called at the very end of fit. + If on DDP it is called on every process """ def on_train_start(self) -> None: - """ - Called at the beginning of training after sanity check. - """ + """Called at the beginning of training after sanity check.""" def on_train_end(self) -> None: - """ - Called at the end of training before logger experiment is closed. - """ + """Called at the end of training before logger experiment is closed.""" def on_validation_start(self) -> None: - """ - Called at the beginning of validation. - """ + """Called at the beginning of validation.""" def on_validation_end(self) -> None: - """ - Called at the end of validation. - """ + """Called at the end of validation.""" def on_test_start(self) -> None: - """ - Called at the beginning of testing. - """ + """Called at the beginning of testing.""" def on_test_end(self) -> None: - """ - Called at the end of testing. - """ + """Called at the end of testing.""" def on_predict_start(self) -> None: - """ - Called at the beginning of predicting. - """ + """Called at the beginning of predicting.""" def on_predict_end(self) -> None: - """ - Called at the end of predicting. - """ + """Called at the end of predicting.""" def on_pretrain_routine_start(self) -> None: - """ - Called at the beginning of the pretrain routine (between fit and train start). + """Called at the beginning of the pretrain routine (between fit and train start). - fit - pretrain_routine start - pretrain_routine end - training_start - """ def on_pretrain_routine_end(self) -> None: - """ - Called at the end of the pretrain routine (between fit and train start). + """Called at the end of the pretrain routine (between fit and train start). - fit - pretrain_routine start - pretrain_routine end - training_start - """ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the training loop before anything happens for that batch. + """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. @@ -112,8 +91,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) """ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the training loop after the batch. + """Called in the training loop after the batch. Args: outputs: The outputs of training_step_end(training_step(x)) @@ -123,8 +101,7 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, d """ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the validation loop before anything happens for that batch. + """Called in the validation loop before anything happens for that batch. Args: batch: The batched data as it is returned by the validation DataLoader. @@ -135,8 +112,7 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: def on_validation_batch_end( self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - """ - Called in the validation loop after the batch. + """Called in the validation loop after the batch. Args: outputs: The outputs of validation_step_end(validation_step(x)) @@ -146,8 +122,7 @@ def on_validation_batch_end( """ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the test loop before anything happens for that batch. + """Called in the test loop before anything happens for that batch. Args: batch: The batched data as it is returned by the test DataLoader. @@ -158,8 +133,7 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) - def on_test_batch_end( self, outputs: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int ) -> None: - """ - Called in the test loop after the batch. + """Called in the test loop after the batch. Args: outputs: The outputs of test_step_end(test_step(x)) @@ -169,8 +143,7 @@ def on_test_batch_end( """ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the predict loop before anything happens for that batch. + """Called in the predict loop before anything happens for that batch. Args: batch: The batched data as it is returned by the test DataLoader. @@ -179,8 +152,7 @@ def on_predict_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int """ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the predict loop after the batch. + """Called in the predict loop after the batch. Args: outputs: The outputs of predict_step_end(test_step(x)) @@ -190,53 +162,36 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in """ def on_validation_model_eval(self) -> None: - """ - Sets the model to eval during the val loop - """ + """Sets the model to eval during the val loop.""" self.trainer.model.eval() def on_validation_model_train(self) -> None: - """ - Sets the model to train during the val loop - """ + """Sets the model to train during the val loop.""" self.trainer.model.train() def on_test_model_train(self) -> None: - """ - Sets the model to train during the test loop - """ + """Sets the model to train during the test loop.""" self.trainer.model.train() def on_test_model_eval(self) -> None: - """ - Sets the model to eval during the test loop - """ + """Sets the model to eval during the test loop.""" self.trainer.model.eval() def on_predict_model_eval(self) -> None: - """ - Sets the model to eval during the predict loop - """ + """Sets the model to eval during the predict loop.""" self.trainer.model.eval() def on_epoch_start(self) -> None: - """ - Called when either of train/val/test epoch begins. - """ + """Called when either of train/val/test epoch begins.""" def on_epoch_end(self) -> None: - """ - Called when either of train/val/test epoch ends. - """ + """Called when either of train/val/test epoch ends.""" def on_train_epoch_start(self) -> None: - """ - Called in the training loop at the very beginning of the epoch. - """ + """Called in the training loop at the very beginning of the epoch.""" def on_train_epoch_end(self) -> None: - """ - Called in the training loop at the very end of the epoch. + """Called in the training loop at the very end of the epoch. To access all batch outputs at the end of the epoch, either: @@ -245,38 +200,25 @@ def on_train_epoch_end(self) -> None: """ def on_validation_epoch_start(self) -> None: - """ - Called in the validation loop at the very beginning of the epoch. - """ + """Called in the validation loop at the very beginning of the epoch.""" def on_validation_epoch_end(self) -> None: - """ - Called in the validation loop at the very end of the epoch. - """ + """Called in the validation loop at the very end of the epoch.""" def on_test_epoch_start(self) -> None: - """ - Called in the test loop at the very beginning of the epoch. - """ + """Called in the test loop at the very beginning of the epoch.""" def on_test_epoch_end(self) -> None: - """ - Called in the test loop at the very end of the epoch. - """ + """Called in the test loop at the very end of the epoch.""" def on_predict_epoch_start(self) -> None: - """ - Called at the beginning of predicting. - """ + """Called at the beginning of predicting.""" def on_predict_epoch_end(self, results: List[Any]) -> None: - """ - Called at the end of predicting. - """ + """Called at the end of predicting.""" def on_before_zero_grad(self, optimizer: Optimizer) -> None: - """ - Called after ``training_step()`` and before ``optimizer.zero_grad()``. + """Called after ``training_step()`` and before ``optimizer.zero_grad()``. Called in the training loop after taking an optimizer step and before zeroing grads. Good place to inspect weight information with weights updated. @@ -296,8 +238,7 @@ def on_before_zero_grad(self, optimizer: Optimizer) -> None: """ def on_before_backward(self, loss: torch.Tensor) -> None: - """ - Called before ``loss.backward()``. + """Called before ``loss.backward()``. Args: loss: Loss divided by number of batches for gradient accumulation and scaled if using native AMP. @@ -305,8 +246,7 @@ def on_before_backward(self, loss: torch.Tensor) -> None: pass def on_after_backward(self) -> None: - """ - Called after ``loss.backward()`` and before optimizers are stepped. + """Called after ``loss.backward()`` and before optimizers are stepped. Note: If using native AMP, the gradients will not be unscaled at this point. @@ -314,8 +254,7 @@ def on_after_backward(self) -> None: """ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: - """ - Called before ``optimizer.step()``. + """Called before ``optimizer.step()``. The hook is only called if gradients do not need to be accumulated. See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`. @@ -339,10 +278,10 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx): """ def on_post_move_to_device(self) -> None: - """ - Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to` - is called. This is a good place to tie weights between modules after moving them to a device. Can be - used when training models with weight sharing properties on TPU. + """Called in the ``parameter_validation`` decorator after + :meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between + modules after moving them to a device. Can be used when training models with weight sharing properties on + TPU. Addresses the handling of shared weights on TPU: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks @@ -351,14 +290,12 @@ def on_post_move_to_device(self) -> None: def on_post_move_to_device(self): self.decoder.weight = self.encoder.weight - """ def configure_sharded_model(self) -> None: - """ - Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, - where we'd like to shard the model instantly, which is useful for extremely large models - which can save memory and initialization time. + """Hook to create modules in a distributed aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models which can save + memory and initialization time. The accelerator manages whether to call this hook at every given stage. For sharded plugins where model parallelism is required, the hook is usually on called once @@ -383,8 +320,7 @@ def __init__(self) -> None: self.prepare_data_per_node: bool = True def prepare_data(self) -> None: - """ - Use this to download and prepare data. + """Use this to download and prepare data. .. warning:: DO NOT set state to the model (use `setup` instead) since this is NOT called on every GPU in DDP/TPU @@ -432,10 +368,9 @@ def prepare_data(self): """ def setup(self, stage: Optional[str] = None) -> None: - """ - Called at the beginning of fit (train + validate), validate, test, and predict. - This is a good hook when you need to build models dynamically or adjust something about them. - This hook is called on every process when using DDP. + """Called at the beginning of fit (train + validate), validate, test, and predict. This is a good hook when + you need to build models dynamically or adjust something about them. This hook is called on every process + when using DDP. Args: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` @@ -456,20 +391,17 @@ def prepare_data(self): def setup(stage): data = Load_data(...) self.l1 = nn.Linear(28, data.num_classes) - """ def teardown(self, stage: Optional[str] = None) -> None: - """ - Called at the end of fit (train + validate), validate, test, predict, or tune. + """Called at the end of fit (train + validate), validate, test, predict, or tune. Args: stage: either ``'fit'``, ``'validate'``, ``'test'``, or ``'predict'`` """ def train_dataloader(self) -> TRAIN_DATALOADERS: - """ - Implement one or more PyTorch DataLoaders for training. + """Implement one or more PyTorch DataLoaders for training. Return: A collection of :class:`torch.utils.data.DataLoader` specifying training samples. @@ -538,7 +470,6 @@ def train_dataloader(self): ) # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar} return {'mnist': mnist_loader, 'cifar': cifar_loader} - """ raise NotImplementedError("`train_dataloader` must be implemented to be used with the Lightning Trainer") @@ -716,9 +647,8 @@ def on_predict_dataloader(self) -> None: """ def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ - Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors - wrapped in a custom data structure. + """Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors wrapped in a custom + data structure. The data types listed below (and any arbitrary nesting of them) are supported out of the box: @@ -774,8 +704,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): return move_data_to_device(batch, device) def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """ - Override to alter or apply batch augmentations to your batch before it is transferred to the device. + """Override to alter or apply batch augmentations to your batch before it is transferred to the device. Note: To check the current state of execution of this hook you can use @@ -810,8 +739,7 @@ def on_before_batch_transfer(self, batch, dataloader_idx): return batch def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: - """ - Override to alter or apply batch augmentations to your batch after it is transferred to the device. + """Override to alter or apply batch augmentations to your batch after it is transferred to the device. Note: To check the current state of execution of this hook you can use diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2a77abc980e21..e1b4d1f3f7e58 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -119,8 +119,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def optimizers( self, use_pl_optimizer: bool = True ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]: - """ - Returns the optimizer(s) that are being used during training. Useful for manual optimization. + """Returns the optimizer(s) that are being used during training. Useful for manual optimization. Args: use_pl_optimizer: If ``True``, will wrap the optimizer(s) in a @@ -142,8 +141,8 @@ def optimizers( return opts def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: - """ - Returns the learning rate scheduler(s) that are being used during training. Useful for manual optimization. + """Returns the learning rate scheduler(s) that are being used during training. Useful for manual + optimization. Returns: A single scheduler, or a list of schedulers in case multiple ones are present, or ``None`` if no @@ -164,8 +163,7 @@ def lr_schedulers(self) -> Optional[Union[Any, List[Any]]]: @property def example_input_array(self) -> Any: - """ - The example input array is a specification of what the module can consume in the :meth:`forward` method. + """The example input array is a specification of what the module can consume in the :meth:`forward` method. The return type is interpreted as follows: - Single tensor: It is assumed the model takes a single argument, i.e., @@ -183,12 +181,18 @@ def example_input_array(self, example: Any) -> None: @property def current_epoch(self) -> int: - """The current epoch in the Trainer. If no Trainer is attached, this propery is 0.""" + """The current epoch in the Trainer. + + If no Trainer is attached, this propery is 0. + """ return self.trainer.current_epoch if self.trainer else 0 @property def global_step(self) -> int: - """Total training batches seen across all epochs. If no Trainer is attached, this propery is 0.""" + """Total training batches seen across all epochs. + + If no Trainer is attached, this propery is 0. + """ return self.trainer.global_step if self.trainer else 0 @property @@ -221,17 +225,15 @@ def loaded_optimizer_states_dict(self, val: dict) -> None: @property def on_gpu(self): - """ - Returns ``True`` if this model is currently located on a GPU. + """Returns ``True`` if this model is currently located on a GPU. + Useful to set flags around the LightningModule for different CPU vs GPU behavior. """ return self.device.type == "cuda" @property def automatic_optimization(self) -> bool: - """ - If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``. - """ + """If set to ``False`` you are responsible for calling ``.backward()``, ``.step()``, ``.zero_grad()``.""" return self._automatic_optimization @automatic_optimization.setter @@ -240,8 +242,9 @@ def automatic_optimization(self, automatic_optimization: bool) -> None: @property def truncated_bptt_steps(self) -> int: - """ - Enables `Truncated Backpropagation Through Time` in the Trainer when set to a positive integer. It represents + """Enables `Truncated Backpropagation Through Time` in the Trainer when set to a positive integer. + + It represents the number of times :meth:`training_step` gets called before backpropagation. If this is > 0, the :meth:`training_step` receives an additional argument ``hiddens`` and is expected to return a hidden state. """ @@ -316,8 +319,7 @@ def log( metric_attribute: Optional[str] = None, rank_zero_only: Optional[bool] = None, ) -> None: - """ - Log a key, value pair. + """Log a key, value pair. Example:: @@ -489,8 +491,7 @@ def log_dict( batch_size: Optional[int] = None, rank_zero_only: Optional[bool] = None, ) -> None: - """ - Log a dictionary of values at once. + """Log a dictionary of values at once. Example:: @@ -680,10 +681,8 @@ def training_step(self, batch, batch_idx, hiddens): rank_zero_warn("`training_step` must be implemented to be used with the Lightning Trainer") def training_step_end(self, *args, **kwargs) -> STEP_OUTPUT: - """ - Use this when training with dp or ddp2 because :meth:`training_step` - will operate on only part of the batch. However, this is still optional - and only needed for things like softmax or NCE loss. + """Use this when training with dp or ddp2 because :meth:`training_step` will operate on only part of the + batch. However, this is still optional and only needed for things like softmax or NCE loss. Note: If you later switch to ddp or some other mode, this will still be called @@ -743,9 +742,8 @@ def training_step_end(self, training_step_outputs): """ def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """ - Called at the end of the training epoch with the outputs of all training steps. - Use this in case you need to do something with all the outputs returned by :meth:`training_step`. + """Called at the end of the training epoch with the outputs of all training steps. Use this in case you + need to do something with all the outputs returned by :meth:`training_step`. .. code-block:: python @@ -873,10 +871,8 @@ def validation_step(self, batch, batch_idx, dataloader_idx): """ def validation_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: - """ - Use this when validating with dp or ddp2 because :meth:`validation_step` - will operate on only part of the batch. However, this is still optional - and only needed for things like softmax or NCE loss. + """Use this when validating with dp or ddp2 because :meth:`validation_step` will operate on only part of + the batch. However, this is still optional and only needed for things like softmax or NCE loss. Note: If you later switch to ddp or some other mode, this will still be called @@ -929,8 +925,7 @@ def validation_step_end(self, val_step_outputs): """ def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """ - Called at the end of the validation epoch with the outputs of all validation steps. + """Called at the end of the validation epoch with the outputs of all validation steps. .. code-block:: python @@ -1054,10 +1049,8 @@ def test_step(self, batch, batch_idx, dataloader_idx): """ def test_step_end(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: - """ - Use this when testing with dp or ddp2 because :meth:`test_step` will operate - on only part of the batch. However, this is still optional - and only needed for things like softmax or NCE loss. + """Use this when testing with dp or ddp2 because :meth:`test_step` will operate on only part of the batch. + However, this is still optional and only needed for things like softmax or NCE loss. Note: If you later switch to ddp or some other mode, this will still be called @@ -1110,8 +1103,7 @@ def test_step_end(self, output_results): """ def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """ - Called at the end of a test epoch with the output of all test steps. + """Called at the end of a test epoch with the output of all test steps. .. code-block:: python @@ -1161,10 +1153,9 @@ def test_epoch_end(self, outputs): """ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: - """ - Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. - By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. - Override to add any processing logic. + """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it + calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. Override to add any processing + logic. The :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` is used to scale inference on multi-devices. @@ -1200,14 +1191,11 @@ def predicts_step(self, batch, batch_idx, dataloader_idx): return self(batch) def configure_callbacks(self): - """ - Configure model-specific callbacks. - When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, - the list returned here will be merged with the list of callbacks passed to the Trainer's ``callbacks`` argument. - If a callback returned here has the same type as one or several callbacks already present in - the Trainer's callbacks list, it will take priority and replace them. - In addition, Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` - callbacks run last. + """Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` + gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's + ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already + present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning + will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last. Return: A list of callbacks which will extend the list of callbacks in the Trainer. @@ -1394,9 +1382,8 @@ def configure_optimizers(self): rank_zero_warn("`configure_optimizers` must be implemented to be used with the Lightning Trainer") def manual_backward(self, loss: Tensor, *args, **kwargs) -> None: - """ - Call this directly from your :meth:`training_step` when doing optimizations manually. - By using this, Lightning can ensure that all the proper scaling gets applied when using mixed precision. + """Call this directly from your :meth:`training_step` when doing optimizations manually. By using this, + Lightning can ensure that all the proper scaling gets applied when using mixed precision. See :ref:`manual optimization` for more examples. @@ -1424,9 +1411,8 @@ def training_step(...): def backward( self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs ) -> None: - """ - Called to perform backward on the loss returned in :meth:`training_step`. - Override this hook with your own implementation if you need to. + """Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your + own implementation if you need to. Args: loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here @@ -1442,11 +1428,9 @@ def backward(self, loss, optimizer, optimizer_idx): loss.backward(*args, **kwargs) def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): - """ - Makes sure only the gradients of the current optimizer's parameters are calculated - in the training step to prevent dangling gradients in multiple-optimizer setup. - It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset. - Override for your own behavior. + """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step + to prevent dangling gradients in multiple-optimizer setup. It works with :meth:`untoggle_optimizer` to make + sure ``param_requires_grad_state`` is properly reset. Override for your own behavior. Args: optimizer: Current optimizer used in the training loop @@ -1475,9 +1459,8 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): self._param_requires_grad_state = param_requires_grad_state def untoggle_optimizer(self, optimizer_idx: int): - """ - Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. - Override for your own behavior. + """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Override for + your own behavior. Args: optimizer_idx: Current optimizer idx in the training loop @@ -1665,8 +1648,7 @@ def tbptt_split_batch(self, batch, split_size): return splits def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]: - """ - Summarize this LightningModule. + """Summarize this LightningModule. .. deprecated:: v1.5 This method was deprecated in v1.5 in favor of `pytorch_lightning.utilities.model_summary.summarize` @@ -1708,14 +1690,12 @@ def freeze(self) -> None: self.eval() def unfreeze(self) -> None: - """ - Unfreeze all parameters for training. + """Unfreeze all parameters for training. .. code-block:: python model = MyLightningModule(...) model.unfreeze() - """ for param in self.parameters(): param.requires_grad = True @@ -1777,10 +1757,9 @@ def _verify_is_manual_optimization(self, fn_name): @classmethod def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: - """ - Collect all module arguments in the current constructor and all child constructors. - The child constructors are all the ``__init__`` methods that reach the current class through - (chained) ``super().__init__()`` calls. + """Collect all module arguments in the current constructor and all child constructors. The child + constructors are all the ``__init__`` methods that reach the current class through (chained) + ``super().__init__()`` calls. Args: frame: instance frame @@ -1806,8 +1785,7 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: @torch.no_grad() def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs): - """ - Saves the model in ONNX format. + """Saves the model in ONNX format. Args: file_path: The path of the file the onnx model should be saved to. @@ -1860,12 +1838,11 @@ def to_torchscript( example_inputs: Optional[Any] = None, **kwargs, ) -> Union[ScriptModule, Dict[str, ScriptModule]]: - """ - By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. - If you want to use tracing, please provided the argument ``method='trace'`` and make sure that either the - `example_inputs` argument is provided, or the model has :attr:`example_input_array` set. - If you would like to customize the modules that are scripted you should override this method. - In case you want to return multiple modules, we recommend using a dictionary. + """By default compiles the whole model to a :class:`~torch.jit.ScriptModule`. If you want to use tracing, + please provided the argument ``method='trace'`` and make sure that either the `example_inputs` argument is + provided, or the model has :attr:`example_input_array` set. If you would like to customize the modules that + are scripted you should override this method. In case you want to return multiple modules, we recommend + using a dictionary. Args: file_path: Path where to save the torchscript. Default: None (no file saved). @@ -1943,9 +1920,8 @@ def model_size(self) -> float: return get_model_size_mb(self) def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - """ - Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. - To avoid issues with memory sharing, we cast the data to numpy. + """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory + sharing, we cast the data to numpy. Args: queue: the instance of the queue to append the data. @@ -1956,9 +1932,8 @@ def add_to_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: queue.put(callback_metrics) def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None: - """ - Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. - To preserve consistency, we cast back the data to ``torch.Tensor``. + """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency, + we cast back the data to ``torch.Tensor``. Args: queue: the instance of the queue from where to get the data. @@ -1986,9 +1961,9 @@ def __getstate__(self) -> Dict[str, Any]: return state def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None: - """ - Adds ShardedTensor state dict hooks if ShardedTensors are supported. These hooks ensure that - ShardedTensors are included when saving, and are loaded the LightningModule correctly. + """Adds ShardedTensor state dict hooks if ShardedTensors are supported. + + These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. """ if not _TORCH_SHARDED_TENSOR_AVAILABLE: return diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index d74a391609e30..e02790edddd1e 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -109,10 +109,9 @@ def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin": return super().to(*args, **kwargs) def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin": - """Moves all model parameters and buffers to the GPU. - This also makes associated parameters and buffers different objects. So - it should be called before constructing optimizer if the module will - live on GPU while being optimized. + """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers + different objects. So it should be called before constructing optimizer if the module will live on GPU + while being optimized. Arguments: device: if specified, all parameters will be diff --git a/pytorch_lightning/core/mixins/hparams_mixin.py b/pytorch_lightning/core/mixins/hparams_mixin.py index c864de23d3cf2..0e722f2bdb683 100644 --- a/pytorch_lightning/core/mixins/hparams_mixin.py +++ b/pytorch_lightning/core/mixins/hparams_mixin.py @@ -126,8 +126,7 @@ def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]): @property def hparams(self) -> Union[AttributeDict, dict, Namespace]: - """ - The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. + """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. Returns: @@ -139,8 +138,7 @@ def hparams(self) -> Union[AttributeDict, dict, Namespace]: @property def hparams_initial(self) -> AttributeDict: - """ - The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only. + """The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only. Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`. Returns: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 07f98faf825d3..ba81644b9bd9a 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -26,10 +26,8 @@ def do_nothing_closure(): class LightningOptimizer: - """ - This class is used to wrap the user optimizers and handle properly - the backward and optimizer_step logic across accelerators, AMP, accumulate_grad_batches - """ + """This class is used to wrap the user optimizers and handle properly the backward and optimizer_step logic + across accelerators, AMP, accumulate_grad_batches.""" def __init__(self, optimizer: Optimizer): @@ -105,8 +103,7 @@ def _untoggle_model(self): @contextmanager def toggle_model(self, sync_grad: bool = True): - """ - This function is just a helper for advanced users. + """This function is just a helper for advanced users. Considering the current optimizer as A and all other optimizers as B. Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False. @@ -131,10 +128,8 @@ def __optimizer_step(self, closure: Callable, profiler_name: str = None, **kwarg trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) def step(self, closure: Optional[Callable] = None, **kwargs): - """ - Call this directly from your training_step when doing optimizations manually. - By using this we can ensure that all the proper scaling when using 16-bit, accelerator etc - is been done properly for you. + """Call this directly from your training_step when doing optimizations manually. By using this we can + ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you. .. note:: In Manual Optimization, the user is expected to know when to call zero_grad, perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators @@ -196,7 +191,6 @@ def closure_dis(): with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): opt_dis.step(closure=closure_dis) - """ if closure is None: profiler_name = f"closure_{self._optimizer_idx}" diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 79608bfc1c5c1..942f0e32cd8fe 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -216,8 +216,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl # OPTIONAL HOOKS # ------------------------- def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None: - """ - Hook to do whatever you need right before Slurm manager saves the model. + """Hook to do whatever you need right before Slurm manager saves the model. Args: checkpoint: A dictionary in which you can save variables to save in a checkpoint. @@ -225,8 +224,7 @@ def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None: """ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: - """ - Hook to do whatever you need right before Slurm manager loads the model. + """Hook to do whatever you need right before Slurm manager loads the model. Args: checkpoint: A dictionary with variables from the checkpoint. @@ -246,8 +244,7 @@ def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Calla def update_hparams(hparams: dict, updates: dict) -> None: - """ - Overrides hparams with new values + """Overrides hparams with new values. >>> hparams = {'c': 4} >>> update_hparams(hparams, {'a': {'b': 2}, 'c': 1}) @@ -260,7 +257,6 @@ def update_hparams(hparams: dict, updates: dict) -> None: Args: hparams: the original params and also target object updates: new params to be used as update - """ for k, v in updates.items(): # if missing, add the key diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 76331801f9cd1..5e9ae2a94e335 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -45,8 +45,7 @@ def get_experiment(): class LightningLoggerBase(ABC): - """ - Base class for experiment loggers. + """Base class for experiment loggers. Args: agg_key_funcs: @@ -73,8 +72,7 @@ def __init__( self._agg_default_func = agg_default_func def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None: - """ - Called after model checkpoint callback saves a new checkpoint + """Called after model checkpoint callback saves a new checkpoint. Args: checkpoint_callback: the model checkpoint callback instance @@ -86,8 +84,7 @@ def update_agg_funcs( agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, agg_default_func: Callable[[Sequence[float]], float] = np.mean, ): - """ - Update aggregation methods. + """Update aggregation methods. Args: agg_key_funcs: @@ -111,8 +108,7 @@ def experiment(self) -> Any: def _aggregate_metrics( self, metrics: Dict[str, float], step: Optional[int] = None ) -> Tuple[int, Optional[Dict[str, float]]]: - """ - Aggregates metrics. + """Aggregates metrics. Args: metrics: Dictionary with metric names as keys and measured quantities as values @@ -155,9 +151,7 @@ def _finalize_agg_metrics(self): self.log_metrics(metrics=metrics_to_log, step=agg_step) def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): - """ - Aggregates and records metrics. - This method doesn't log the passed metrics instantaneously, but instead + """Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead it aggregates them and logs only if metrics are ready to be logged. Args: @@ -196,8 +190,7 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: @staticmethod def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: - """ - Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. + """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. Args: params: Dictionary containing the hyperparameters @@ -223,8 +216,7 @@ def _sanitize_callable(val): @staticmethod def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]: - """ - Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. + """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. Args: params: Dictionary containing the hyperparameters @@ -259,8 +251,7 @@ def _dict_generator(input_dict, prefixes=None): @staticmethod def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: - """ - Returns params with non-primitvies converted to strings for logging. + """Returns params with non-primitvies converted to strings for logging. >>> params = {"float": 0.3, ... "int": 1, @@ -289,8 +280,7 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: @abstractmethod def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): - """ - Record hyperparameters. + """Record hyperparameters. Args: params: :class:`~argparse.Namespace` containing the hyperparameters @@ -299,8 +289,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): """ def log_graph(self, model: "pl.LightningModule", input_array=None) -> None: - """ - Record model graph + """Record model graph. Args: model: lightning model @@ -313,8 +302,7 @@ def save(self) -> None: self._finalize_agg_metrics() def finalize(self, status: str) -> None: - """ - Do any processing that is necessary to finalize an experiment. + """Do any processing that is necessary to finalize an experiment. Args: status: Status that the experiment finished with (e.g. success, failed, aborted) @@ -327,10 +315,8 @@ def close(self) -> None: @property def save_dir(self) -> Optional[str]: - """ - Return the root directory where experiment logs get saved, or `None` if the logger does not - save data locally. - """ + """Return the root directory where experiment logs get saved, or `None` if the logger does not save data + locally.""" return None @property @@ -351,9 +337,7 @@ def _add_prefix(self, metrics: Dict[str, float]): class LoggerCollection(LightningLoggerBase): - """ - The :class:`LoggerCollection` class is used to iterate all logging actions over - the given `logger_iterable`. + """The :class:`LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`. Args: logger_iterable: An iterable collection of loggers @@ -380,9 +364,7 @@ def update_agg_funcs( @property def experiment(self) -> List[Any]: - """ - Returns a list of experiment objects for all the loggers in the logger collection. - """ + """Returns a list of experiment objects for all the loggers in the logger collection.""" return [logger.experiment for logger in self._logger_iterable] def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): @@ -415,29 +397,24 @@ def close(self) -> None: @property def save_dir(self) -> Optional[str]: - """ - Returns ``None`` as checkpoints should be saved to default / chosen location when using multiple loggers. - """ + """Returns ``None`` as checkpoints should be saved to default / chosen location when using multiple + loggers.""" # Checkpoints should be saved to default / chosen location when using multiple loggers return None @property def name(self) -> str: - """ - Returns the experiment names for all the loggers in the logger collection joined by an underscore. - """ + """Returns the experiment names for all the loggers in the logger collection joined by an underscore.""" return "_".join(str(logger.name) for logger in self._logger_iterable) @property def version(self) -> str: - """ - Returns the experiment versions for all the loggers in the logger collection joined by an underscore. - """ + """Returns the experiment versions for all the loggers in the logger collection joined by an underscore.""" return "_".join(str(logger.version) for logger in self._logger_iterable) class DummyExperiment: - """Dummy experiment""" + """Dummy experiment.""" def nop(self, *args, **kw): pass @@ -451,9 +428,9 @@ def __getitem__(self, idx) -> "DummyExperiment": class DummyLogger(LightningLoggerBase): - """ - Dummy logger for internal use. It is useful if we want to disable user's - logger for a feature, but still ensure that user code can run + """Dummy logger for internal use. + + It is useful if we want to disable user's logger for a feature, but still ensure that user code can run """ def __init__(self): @@ -491,9 +468,8 @@ def merge_dicts( agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, default_func: Callable[[Sequence[float]], float] = np.mean, ) -> Dict: - """ - Merge a sequence with dictionaries into one dictionary by aggregating the - same keys with some given function. + """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given + function. Args: dicts: diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 9bdafbd265352..43ecd46e853ae 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -268,8 +268,7 @@ def finalize(self, status: str) -> None: @property def save_dir(self) -> Optional[str]: - """ - Gets the save directory. + """Gets the save directory. Returns: The path to the save directory. @@ -278,8 +277,7 @@ def save_dir(self) -> Optional[str]: @property def name(self) -> str: - """ - Gets the project name. + """Gets the project name. Returns: The project name if it is specified, else "comet-default". @@ -295,8 +293,7 @@ def name(self) -> str: @property def version(self) -> str: - """ - Gets the version. + """Gets the version. Returns: The first one of the following that is set in the following order diff --git a/pytorch_lightning/loggers/csv_logs.py b/pytorch_lightning/loggers/csv_logs.py index 9decf9a7e33c6..2d0a2a3edb8ca 100644 --- a/pytorch_lightning/loggers/csv_logs.py +++ b/pytorch_lightning/loggers/csv_logs.py @@ -63,11 +63,11 @@ def __init__(self, log_dir: str) -> None: self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) def log_hparams(self, params: Dict[str, Any]) -> None: - """Record hparams""" + """Record hparams.""" self.hparams.update(params) def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: - """Record metrics""" + """Record metrics.""" def _handle_value(value): if isinstance(value, torch.Tensor): @@ -82,7 +82,7 @@ def _handle_value(value): self.metrics.append(metrics) def save(self) -> None: - """Save recorded hparams and metrics into files""" + """Save recorded hparams and metrics into files.""" hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) save_hparams_to_yaml(hparams_file, self.hparams) @@ -138,10 +138,10 @@ def __init__( @property def root_dir(self) -> str: - """ - Parent directory for all checkpoint subdirectories. - If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used - and the checkpoint will be saved in "save_dir/version_dir" + """Parent directory for all checkpoint subdirectories. + + If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used and the + checkpoint will be saved in "save_dir/version_dir" """ if not self.name: return self.save_dir @@ -149,10 +149,10 @@ def root_dir(self) -> str: @property def log_dir(self) -> str: - """ - The log directory for this run. By default, it is named - ``'version_${self.version}'`` but it can be overridden by passing a string value - for the constructor's version parameter instead of ``None`` or an int. + """The log directory for this run. + + By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the + constructor's version parameter instead of ``None`` or an int. """ # create a pseudo standard path ala test-tube version = self.version if isinstance(self.version, str) else f"version_{self.version}" @@ -161,8 +161,7 @@ def log_dir(self) -> str: @property def save_dir(self) -> Optional[str]: - """ - The current directory where logs are saved. + """The current directory where logs are saved. Returns: The path to current directory where logs are saved. @@ -210,8 +209,7 @@ def finalize(self, status: str) -> None: @property def name(self) -> str: - """ - Gets the name of the experiment. + """Gets the name of the experiment. Returns: The name of the experiment. @@ -220,8 +218,7 @@ def name(self) -> str: @property def version(self) -> int: - """ - Gets the version of the experiment. + """Gets the version of the experiment. Returns: The version of the experiment if it is specified, else the next version. diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 908c2b93a4286..ef225d06cd81a 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -53,8 +53,7 @@ def resolve_tags(tags=None): class MLFlowLogger(LightningLoggerBase): - """ - Log using `MLflow `_. + """Log using `MLflow `_. Install it with pip: @@ -172,8 +171,7 @@ def experiment(self) -> MlflowClient: @property def run_id(self) -> str: - """ - Create the experiment if it does not exist to get the run id. + """Create the experiment if it does not exist to get the run id. Returns: The run id. @@ -183,8 +181,7 @@ def run_id(self) -> str: @property def experiment_id(self) -> str: - """ - Create the experiment if it does not exist to get the experiment id. + """Create the experiment if it does not exist to get the experiment id. Returns: The experiment id. @@ -237,8 +234,7 @@ def finalize(self, status: str = "FINISHED") -> None: @property def save_dir(self) -> Optional[str]: - """ - The root file directory in which MLflow experiments are saved. + """The root file directory in which MLflow experiments are saved. Return: Local path to the root experiment directory if the tracking uri is local. @@ -249,8 +245,7 @@ def save_dir(self) -> Optional[str]: @property def name(self) -> str: - """ - Get the experiment id. + """Get the experiment id. Returns: The experiment id. @@ -259,8 +254,7 @@ def name(self) -> str: @property def version(self) -> str: - """ - Get the run id. + """Get the run id. Returns: The run id. diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index b810aebce47ba..b5a1de1a943d3 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -244,8 +244,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: @rank_zero_only def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None: - """ - Log metrics (numeric values) in Neptune experiments. + """Log metrics (numeric values) in Neptune experiments. Args: metrics: Dictionary with metric names as keys and measured quantities as values @@ -267,8 +266,8 @@ def finalize(self, status: str) -> None: @property def save_dir(self) -> Optional[str]: - """ - Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save locally. + """Gets the save directory of the experiment which in this case is ``None`` because Neptune does not save + locally. Returns: None @@ -278,8 +277,7 @@ def save_dir(self) -> Optional[str]: @property def name(self) -> str: - """ - Gets the name of the experiment. + """Gets the name of the experiment. Returns: The name of the experiment if not in offline mode else "offline-name". @@ -290,8 +288,7 @@ def name(self) -> str: @property def version(self) -> str: - """ - Gets the id of the experiment. + """Gets the id of the experiment. Returns: The id of the experiment if not in offline mode else "offline-id-1234". @@ -304,8 +301,7 @@ def version(self) -> str: def log_metric( self, metric_name: str, metric_value: Union[torch.Tensor, float, str], step: Optional[int] = None ) -> None: - """ - Log metrics (numeric values) in Neptune experiments. + """Log metrics (numeric values) in Neptune experiments. Args: metric_name: The name of log, i.e. mse, loss, accuracy. @@ -322,8 +318,7 @@ def log_metric( @rank_zero_only def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None: - """ - Log text data in Neptune experiments. + """Log text data in Neptune experiments. Args: log_name: The name of log, i.e. mse, my_text_data, timing_info. @@ -337,8 +332,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None @rank_zero_only def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None: - """ - Log image data in Neptune experiment + """Log image data in Neptune experiment. Args: log_name: The name of log, i.e. bboxes, visualisations, sample_images. @@ -365,8 +359,7 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None @rank_zero_only def set_property(self, key: str, value: Any) -> None: - """ - Set key-value pair as Neptune experiment property. + """Set key-value pair as Neptune experiment property. Args: key: Property key. @@ -376,8 +369,7 @@ def set_property(self, key: str, value: Any) -> None: @rank_zero_only def append_tags(self, tags: Union[str, Iterable[str]]) -> None: - """ - Appends tags to the neptune experiment. + """Appends tags to the neptune experiment. Args: tags: Tags to add to the current experiment. If str is passed, a single tag is added. diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index a9e1118b9d647..6abc809bc65d1 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -107,10 +107,10 @@ def __init__( @property def root_dir(self) -> str: - """ - Parent directory for all tensorboard checkpoint subdirectories. - If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used - and the checkpoint will be saved in "save_dir/version_dir" + """Parent directory for all tensorboard checkpoint subdirectories. + + If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used and the + checkpoint will be saved in "save_dir/version_dir" """ if self.name is None or len(self.name) == 0: return self.save_dir @@ -118,10 +118,10 @@ def root_dir(self) -> str: @property def log_dir(self) -> str: - """ - The directory for this run's tensorboard checkpoint. By default, it is named - ``'version_${self.version}'`` but it can be overridden by passing a string value - for the constructor's version parameter instead of ``None`` or an int. + """The directory for this run's tensorboard checkpoint. + + By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value for the + constructor's version parameter instead of ``None`` or an int. """ # create a pseudo standard path ala test-tube version = self.version if isinstance(self.version, str) else f"version_{self.version}" @@ -134,8 +134,7 @@ def log_dir(self) -> str: @property def save_dir(self) -> Optional[str]: - """ - Gets the save directory where the TensorBoard experiments are saved. + """Gets the save directory where the TensorBoard experiments are saved. Returns: The local path to the save directory where the TensorBoard experiments are saved. @@ -144,8 +143,7 @@ def save_dir(self) -> Optional[str]: @property def sub_dir(self) -> Optional[str]: - """ - Gets the sub directory where the TensorBoard experiments are saved. + """Gets the sub directory where the TensorBoard experiments are saved. Returns: The local path to the sub directory where the TensorBoard experiments are saved. @@ -177,10 +175,9 @@ def experiment(self) -> SummaryWriter: def log_hyperparams( self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None ) -> None: - """ - Record hyperparameters. TensorBoard logs with and without saved hyperparameters - are incompatible, the hyperparameters are then not displayed in the TensorBoard. - Please delete or move the previously saved logs to display the new ones with hyperparameters. + """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the + hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs + to display the new ones with hyperparameters. Args: params: a dictionary-like container with the hyperparameters @@ -270,8 +267,7 @@ def finalize(self, status: str) -> None: @property def name(self) -> str: - """ - Get the name of the experiment. + """Get the name of the experiment. Returns: The name of the experiment. @@ -280,8 +276,7 @@ def name(self) -> str: @property def version(self) -> int: - """ - Get the experiment version. + """Get the experiment version. Returns: The experiment version if specified else the next version. diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 9a3b6ccee64df..eeeedfab584f7 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -205,8 +205,7 @@ def close(self) -> None: @property def save_dir(self) -> Optional[str]: - """ - Gets the save directory. + """Gets the save directory. Returns: The path to the save directory. @@ -215,8 +214,7 @@ def save_dir(self) -> Optional[str]: @property def name(self) -> str: - """ - Gets the experiment name. + """Gets the experiment name. Returns: The experiment name if the experiment exists, else the name specified in the constructor. @@ -228,8 +226,7 @@ def name(self) -> str: @property def version(self) -> int: - """ - Gets the experiment version. + """Gets the experiment version. Returns: The experiment version if the experiment exists, else the next version. diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 7081e95d352a3..c93b8d02bca16 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -219,8 +219,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> @property def save_dir(self) -> Optional[str]: - """ - Gets the save directory. + """Gets the save directory. Returns: The path to the save directory. @@ -229,8 +228,7 @@ def save_dir(self) -> Optional[str]: @property def name(self) -> Optional[str]: - """ - Gets the name of the experiment. + """Gets the name of the experiment. Returns: The name of the experiment if the experiment exists else the name given to the constructor. @@ -240,8 +238,7 @@ def name(self) -> Optional[str]: @property def version(self) -> Optional[str]: - """ - Gets the id of the experiment. + """Gets the id of the experiment. Returns: The id of the experiment if the experiment exists else the id given to the constructor. diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index d5b528f53148e..503c417a2870f 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -26,8 +26,7 @@ class Loop(ABC): - """ - Basic Loops interface. All classes derived from this must implement the following properties and methods: + """Basic Loops interface. All classes derived from this must implement the following properties and methods: * :attr:`done` (property): Condition to break the loop * :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run` @@ -59,7 +58,7 @@ def trainer(self) -> "pl.Trainer": @trainer.setter def trainer(self, trainer: "pl.Trainer"): - """Connects this loop's trainer and its children""" + """Connects this loop's trainer and its children.""" if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." @@ -72,7 +71,7 @@ def trainer(self, trainer: "pl.Trainer"): @property @abstractmethod def done(self) -> bool: - """Property indicating when loop is finished""" + """Property indicating when loop is finished.""" @property def skip(self) -> bool: @@ -80,19 +79,20 @@ def skip(self) -> bool: return False def connect(self, **kwargs: "Loop") -> None: - """Optionally connect one or multiple loops to this one. Linked loops should form a tree.""" + """Optionally connect one or multiple loops to this one. - def on_skip(self) -> Optional[Any]: + Linked loops should form a tree. """ - The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. + + def on_skip(self) -> Optional[Any]: + """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`. Returns: the default output value of :meth:`on_run_end` """ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: - """ - The main entry point to the loop. + """The main entry point to the loop. Will frequently check the :attr:`done` condition and calls :attr:`advance` until :attr:`done` evaluates to ``True``. @@ -124,35 +124,40 @@ def reset(self) -> None: """Resets the internal state of the loop at the beginning of each call to :attr:`run`.""" def on_run_start(self, *args: Any, **kwargs: Any) -> None: - """ - Hook to be called as the first thing after entering :attr:`run` (except the state reset). + """Hook to be called as the first thing after entering :attr:`run` (except the state reset). Accepts all arguments passed to :attr:`run`. """ void(*args, **kwargs) def on_advance_start(self, *args: Any, **kwargs: Any) -> None: - """ - Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`. + """Hook to be called each time before :attr:`advance` is called. + + Accepts all arguments passed to :attr`run`. """ void(*args, **kwargs) @abstractmethod def advance(self, *args: Any, **kwargs: Any) -> None: - """Performs a single step. Accepts all arguments passed to :attr:`run`.""" + """Performs a single step. + + Accepts all arguments passed to :attr:`run`. + """ def on_advance_end(self) -> None: """Hook to be called each time after :attr:`advance` is called.""" def on_run_end(self) -> Any: - """Hook to be called at the end of the run. Its return argument is returned from :attr:`run`.""" + """Hook to be called at the end of the run. + + Its return argument is returned from :attr:`run`. + """ def teardown(self) -> None: """Use to release memory etc.""" def on_save_checkpoint(self) -> Dict: - """ - Called when saving a model checkpoint, use to persist loop state. + """Called when saving a model checkpoint, use to persist loop state. Returns: The current loop state. @@ -163,8 +168,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict: - """ - The state dict is determined by the state and progress of this loop and all its children. + """The state dict is determined by the state and progress of this loop and all its children. Args: destination: An existing dictionary to update with this loop's state. By default a new dictionary diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 0a86486dc322a..c2757e4035b48 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -53,12 +53,12 @@ def __init__(self) -> None: @property def done(self) -> bool: - """Returns if all batch splits have been processed already""" + """Returns if all batch splits have been processed already.""" return len(self._remaining_splits) == 0 @property def optimizer_freq_cumsum(self) -> int: - """Returns the cumulated sum of optimizer frequencies""" + """Returns the cumulated sum of optimizer frequencies.""" if self._optimizer_freq_cumsum is None: self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) return self._optimizer_freq_cumsum @@ -67,7 +67,7 @@ def connect(self, optimizer_loop: "Loop") -> None: self.optimizer_loop = optimizer_loop def run(self, batch: Any, batch_idx: int) -> AttributeDict: - """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks + """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks. Args: batch: the current batch to run the train step on @@ -96,12 +96,12 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict: return output def reset(self) -> None: - """Resets the loop state""" + """Resets the loop state.""" self._hiddens = None self.batch_outputs = [[] for _ in range(len(self.trainer.optimizers))] def on_run_start(self, batch: Any, batch_idx: int): - """Splits the data into tbptt splits + """Splits the data into tbptt splits. Args: batch: the current batch to run the trainstep on @@ -111,7 +111,7 @@ def on_run_start(self, batch: Any, batch_idx: int): self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch))) def advance(self, batch, batch_idx): - """Runs the train step together with optimization (if necessary) on the current batch split + """Runs the train step together with optimization (if necessary) on the current batch split. Args: batch: the current batch to run the training on (this is not the split!) @@ -142,7 +142,7 @@ def teardown(self) -> None: self._remaining_splits = None def num_active_optimizers(self, batch_idx: Optional[int] = None) -> int: - """Gets the number of active optimizers based on their frequency""" + """Gets the number of active optimizers based on their frequency.""" return len(self.get_active_optimizers(batch_idx)) def _run_optimization( @@ -174,10 +174,8 @@ def _make_closure( batch_idx: int, hiddens: Any, ) -> Closure: - """ - Build a closure object that captures the given arguments and runs the `training_step` function and optionally - other functions such as `backward` and `zero_grad`. - """ + """Build a closure object that captures the given arguments and runs the `training_step` function and + optionally other functions such as `backward` and `zero_grad`.""" step_fn = self._make_step_fn(split_batch, batch_idx, hiddens) backward_fn = None zero_grad_fn = None @@ -246,7 +244,7 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]: return splits def _update_running_loss(self, current_loss: Tensor) -> None: - """Updates the running loss value with the current value""" + """Updates the running loss value with the current value.""" if self.trainer.lightning_module.automatic_optimization: # track total loss for logging (avoid mem leaks) self.accumulated_loss.append(current_loss) @@ -261,8 +259,7 @@ def _update_running_loss(self, current_loss: Tensor) -> None: self.accumulated_loss.reset() def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: - """ - Returns the currently active optimizers. When multiple optimizers are used with different frequencies, + """Returns the currently active optimizers. When multiple optimizers are used with different frequencies, only one of the optimizers is active at a time. Returns: diff --git a/pytorch_lightning/loops/closure.py b/pytorch_lightning/loops/closure.py index ba6c3a43c2a52..8097d6e15a5d7 100644 --- a/pytorch_lightning/loops/closure.py +++ b/pytorch_lightning/loops/closure.py @@ -39,8 +39,7 @@ class ClosureResult: class AbstractClosure(ABC): - """ - Abstract base class for optimizer closures in Lightning. + """Abstract base class for optimizer closures in Lightning. Formally, a closure is binding variables from an external scope to a function that does a computation on these variables without taking them explicitly as input. This has the benefit that a closure can be passed to an @@ -55,8 +54,11 @@ def __init__(self) -> None: self._result: Optional[ClosureResult] = None def get_result(self) -> Optional[ClosureResult]: - """The cached result from the last time the closure was called. Once accessed, the internal reference - gets reset and the consumer will have to hold on to the reference as long as necessary.""" + """The cached result from the last time the closure was called. + + Once accessed, the internal reference gets reset and the consumer will have to hold on to the reference as long + as necessary. + """ result = self._result self._result = None # free memory return result @@ -73,8 +75,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: class Closure(AbstractClosure): - """ - An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary + """An implementation of a :class:`AbstractClosure` for optimization in Lightning that combines three elementary closures into one: ``training_step``, ``backward`` and ``zero_grad``. The Closure gets created by the training loop(s) and is then passed to the diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index 65521aea547d8..6b5fecd07e807 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -22,7 +22,7 @@ class DataLoaderLoop(Loop): - """Base class to loop over all dataloaders""" + """Base class to loop over all dataloaders.""" def __init__(self): super().__init__() @@ -31,30 +31,30 @@ def __init__(self): @property @abstractmethod def dataloaders(self) -> Sequence[DataLoader]: - """Returns the dataloaders to loop over""" + """Returns the dataloaders to loop over.""" @property def current_dataloader_idx(self) -> int: - """Returns the index of the current dataloader""" + """Returns the index of the current dataloader.""" return self.dataloader_progress.current.ready - 1 @property def current_dataloader(self) -> DataLoader: - """Returns the current dataloader""" + """Returns the current dataloader.""" return self.dataloaders[self.current_dataloader_idx] @property def num_dataloaders(self) -> int: - """Returns the number of dataloaders present""" + """Returns the number of dataloaders present.""" return len(self.dataloaders) if self.dataloaders is not None else 0 @property def done(self) -> bool: - """Returns whether all dataloaders have been processed""" + """Returns whether all dataloaders have been processed.""" return self.dataloader_progress.current.completed >= self.num_dataloaders def reset(self) -> None: - """Resets the internal state""" + """Resets the internal state.""" if not self.restarting: self.dataloader_progress.current.reset() diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 65babb7bfd448..094f3dc3c7776 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -37,7 +37,7 @@ def __init__(self): @property def num_dataloaders(self) -> int: - """Returns the total number of dataloaders""" + """Returns the total number of dataloaders.""" # case where user does: # return dl1, dl2 dataloaders = self.dataloaders @@ -50,7 +50,7 @@ def num_dataloaders(self) -> int: @property def dataloaders(self) -> Sequence[DataLoader]: - """Returns the validation or test dataloaders""" + """Returns the validation or test dataloaders.""" if self.trainer.testing: return self.trainer.test_dataloaders return self.trainer.val_dataloaders @@ -61,7 +61,7 @@ def connect(self, epoch_loop: EvaluationEpochLoop): @property def done(self) -> bool: - """Returns whether all dataloaders are processed or evaluation should be skipped altogether""" + """Returns whether all dataloaders are processed or evaluation should be skipped altogether.""" return super().done or self.skip @property @@ -71,7 +71,7 @@ def skip(self) -> bool: return sum(max_batches) == 0 def reset(self) -> None: - """Resets the internal state of the loop""" + """Resets the internal state of the loop.""" self._max_batches = self.get_max_batches() # bookkeeping self.outputs = [] @@ -85,7 +85,8 @@ def on_skip(self) -> List: return [] def on_run_start(self, *args: Any, **kwargs: Any) -> None: - """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks""" + """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` + hooks.""" void(*args, **kwargs) # hook self.on_evaluation_model_eval() @@ -94,7 +95,7 @@ def on_run_start(self, *args: Any, **kwargs: Any) -> None: self.on_evaluation_epoch_start() def advance(self, *args: Any, **kwargs: Any) -> None: - """Performs evaluation on one single dataloader""" + """Performs evaluation on one single dataloader.""" void(*args, **kwargs) dataloader_idx: int = self.current_dataloader_idx @@ -113,7 +114,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self._has_run = True def on_run_end(self) -> Any: - """Runs the ``on_evaluation_epoch_end`` hook""" + """Runs the ``on_evaluation_epoch_end`` hook.""" outputs = self.outputs # free memory @@ -141,7 +142,7 @@ def on_run_end(self) -> Any: return eval_loop_results def get_max_batches(self) -> List[Union[int, float]]: - """Returns the max number of batches for each dataloader""" + """Returns the max number of batches for each dataloader.""" if self.trainer.testing: max_batches = self.trainer.num_test_batches else: @@ -155,14 +156,14 @@ def get_max_batches(self) -> List[Union[int, float]]: return max_batches def reload_evaluation_dataloaders(self) -> None: - """Reloads dataloaders if necessary""" + """Reloads dataloaders if necessary.""" if self.trainer.testing: self.trainer.reset_test_dataloader() elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch: self.trainer.reset_val_dataloader() def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: - """Runs ``on_{validation/test}_start`` hooks""" + """Runs ``on_{validation/test}_start`` hooks.""" assert self._results is not None self._results.to(device=self.trainer.lightning_module.device) @@ -172,14 +173,14 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: self.trainer.call_hook("on_validation_start", *args, **kwargs) def on_evaluation_model_eval(self) -> None: - """Sets model to eval mode""" + """Sets model to eval mode.""" if self.trainer.testing: self.trainer.call_hook("on_test_model_eval") else: self.trainer.call_hook("on_validation_model_eval") def on_evaluation_model_train(self) -> None: - """Sets model to train mode""" + """Sets model to train mode.""" model_ref = self.trainer.lightning_module if self.trainer.testing: model_ref.on_test_model_train() @@ -187,7 +188,7 @@ def on_evaluation_model_train(self) -> None: model_ref.on_validation_model_train() def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: - """Runs ``on_{validation/test}_end`` hook""" + """Runs ``on_{validation/test}_end`` hook.""" if self.trainer.testing: self.trainer.call_hook("on_test_end", *args, **kwargs) else: @@ -197,7 +198,7 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None: self.trainer.logger_connector.reset(metrics=True) def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None: - """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks""" + """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks.""" self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start", *args, **kwargs) @@ -228,7 +229,7 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: model.validation_epoch_end(outputs) def on_evaluation_epoch_end(self) -> None: - """Runs ``on_{validation/test}_epoch_end`` hook""" + """Runs ``on_{validation/test}_epoch_end`` hook.""" hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" self.trainer.call_hook(hook_name) self.trainer.call_hook("on_epoch_end") diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 8010948b7e819..9bacf1df2f06f 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -11,7 +11,7 @@ class PredictionLoop(DataLoaderLoop): - """Loop to run over dataloaders for prediction""" + """Loop to run over dataloaders for prediction.""" def __init__(self): super().__init__() @@ -24,7 +24,7 @@ def __init__(self): @property def return_predictions(self) -> bool: - """Whether to return the predictions or not""" + """Whether to return the predictions or not.""" return self._return_predictions @return_predictions.setter @@ -41,7 +41,7 @@ def return_predictions(self, return_predictions: Optional[bool] = None) -> None: @property def num_dataloaders(self) -> int: - """Returns the number of prediction dataloaders""" + """Returns the number of prediction dataloaders.""" # case where user does: # return dl1, dl2 dataloaders = self.dataloaders @@ -60,7 +60,7 @@ def max_batches(self) -> List[int]: @property def dataloaders(self) -> Sequence[DataLoader]: - """Returns all prediction dataloaders""" + """Returns all prediction dataloaders.""" return self.trainer.predict_dataloaders @property @@ -72,17 +72,17 @@ def connect(self, epoch_loop: PredictionEpochLoop): self.epoch_loop = epoch_loop def reset(self) -> None: - """Resets the internal state of the loop for a new run""" + """Resets the internal state of the loop for a new run.""" super().reset() self.predictions = [] self.epoch_batch_indices = [] def on_run_start(self) -> None: - """Calls ``on_predict_start`` hook""" + """Calls ``on_predict_start`` hook.""" self.on_predict_start() def advance(self, *args: Any, **kwargs: Any) -> None: - """Predicts one entire dataloader""" + """Predicts one entire dataloader.""" void(*args, **kwargs) dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader) dataloader_iter = enumerate(dataloader) @@ -95,15 +95,15 @@ def advance(self, *args: Any, **kwargs: Any) -> None: self.epoch_batch_indices.append(dl_batch_indices) def on_run_end(self) -> Union[List[Any], List[List[Any]]]: - """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders""" + """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.""" results = self.on_predict_epoch_end() self.on_predict_end() return results def on_predict_start(self) -> None: - """ - Sets model to eval mode and disables gradients. Also calls ``on_predict_start`` and - ``on_predict_epoch_start`` hooks. + """Sets model to eval mode and disables gradients. + + Also calls ``on_predict_start`` and ``on_predict_epoch_start`` hooks. """ # enable eval mode + no grads self.on_predict_model_eval() @@ -127,7 +127,7 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]: return results[0] if self.num_dataloaders == 1 else results def on_predict_end(self) -> None: - """Resets previous gradient status and calls ``on_predict_end`` hook""" + """Resets previous gradient status and calls ``on_predict_end`` hook.""" # clear memory. the predictions are extracted in `on_predict_epoch_end`. self.predictions = [] self.epoch_batch_indices = [] @@ -136,6 +136,6 @@ def on_predict_end(self) -> None: self.trainer.call_hook("on_predict_end") def on_predict_model_eval(self): - """Calls ``on_predict_model_eval`` hook""" + """Calls ``on_predict_model_eval`` hook.""" model_ref = self.trainer.lightning_module model_ref.on_predict_model_eval() diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 9f8f77e806b87..4d6c654c93816 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -28,9 +28,10 @@ class EvaluationEpochLoop(Loop): - """ - This is the loop performing the evaluation. It mainly loops over the given dataloader and runs the validation - or test step (depending on the trainer's current state). + """This is the loop performing the evaluation. + + It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's current + state). """ def __init__(self) -> None: @@ -62,7 +63,7 @@ def reset(self) -> None: def on_run_start( self, data_fetcher: AbstractDataFetcher, dataloader_idx: int, dl_max_batches: int, num_dataloaders: int ) -> None: - """Adds the passed arguments to the loop's state if necessary + """Adds the passed arguments to the loop's state if necessary. Args: data_fetcher: the current data_fetcher wrapping the dataloader @@ -130,7 +131,7 @@ def advance( self.outputs.append(output) def on_run_end(self) -> EPOCH_OUTPUT: - """Returns the outputs of the whole run""" + """Returns the outputs of the whole run.""" outputs = self.outputs # free memory self.outputs = [] @@ -162,7 +163,7 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op return output def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - """Calls the `{validation/test}_step_end` hook""" + """Calls the `{validation/test}_step_end` hook.""" hook_name = "test_step_end" if self.trainer.testing else "validation_step_end" output = self.trainer.call_hook(hook_name, *args, **kwargs) return output @@ -205,7 +206,7 @@ def on_evaluation_batch_end( self.trainer.logger_connector.on_batch_end() def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Union[Any, int]]: - """Helper function to build the arguments for the current step + """Helper function to build the arguments for the current step. Args: batch: The current batch to run through the step @@ -228,7 +229,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict @lru_cache(1) def _should_track_batch_outputs_for_epoch_end(self) -> bool: - """Whether the batch outputs should be stored for later usage""" + """Whether the batch outputs should be stored for later usage.""" model = self.trainer.lightning_module if self.trainer.testing: return is_overridden("test_epoch_end", model) diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 73557f71ade73..8b89782353d84 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -28,7 +28,7 @@ def __init__(self) -> None: @property def done(self) -> bool: - """Ends prediction when the iteration count exceeds the total number of available batches""" + """Ends prediction when the iteration count exceeds the total number of available batches.""" return self.batch_progress.current.completed >= self._dl_max_batches @property @@ -41,7 +41,7 @@ def connect(self, **kwargs: "Loop") -> None: raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.") def reset(self) -> None: - """Resets the loops internal state""" + """Resets the loops internal state.""" self._all_batch_indices: List[int] = [] self.predictions: List[Any] = [] self.batch_progress.current.reset() @@ -54,8 +54,7 @@ def on_run_start( num_dataloaders: int, return_predictions: bool = False, ) -> None: - """ - Prepares the loops internal state + """Prepares the loops internal state. Args: dataloader_iter: the iterator over the current dataloader @@ -77,8 +76,7 @@ def advance( num_dataloaders: int, return_predictions: bool = False, ) -> None: - """ - Runs one prediction step. + """Runs one prediction step. Args: dataloader_iter: the iterator over the current dataloader @@ -100,7 +98,7 @@ def advance( self._predict_step(batch, batch_idx, dataloader_idx) def on_run_end(self) -> Tuple[Any, Any]: - """Returns the predictions and the corresponding batch indices""" + """Returns the predictions and the corresponding batch indices.""" predictions = self.predictions all_batch_indices = self._all_batch_indices # free memory @@ -109,8 +107,8 @@ def on_run_end(self) -> Tuple[Any, Any]: return predictions, all_batch_indices def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """Runs the actual predict step together with all the - necessary bookkeeping and the hooks tied to the predict step. + """Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the + predict step. Args: batch: the current batch to run the prediction on @@ -145,8 +143,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None self.predictions.append(move_data_to_device(predictions, torch.device("cpu"))) def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]: - """ - Assembles the keyword arguments for the ``predict_step`` + """Assembles the keyword arguments for the ``predict_step`` Args: batch: the current batch to run the prediction on @@ -162,7 +159,7 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict return step_kwargs def _store_batch_indices(self, dataloader_idx: int) -> None: - """Stores the batch indices if the predictions should be stored""" + """Stores the batch indices if the predictions should be stored.""" batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler if isinstance(batch_sampler, IndexBatchSamplerWrapper): self.current_batch_indices = batch_sampler.batch_indices diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 1db4f74008ce8..ae63e564511c9 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -26,8 +26,7 @@ class TrainingEpochLoop(loops.Loop): - """ - Runs over all batches in a dataloader (one epoch). + """Runs over all batches in a dataloader (one epoch). Args: min_steps: The minimum number of steps (batches) to process @@ -71,8 +70,9 @@ def batch_idx(self) -> int: @property def done(self) -> bool: """Returns whether the training should be stopped. - The criteria are that the number of steps reached the max steps, - the last batch is reached or the trainer signals to stop (e.g. by early stopping). + + The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer + signals to stop (e.g. by early stopping). """ max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) @@ -89,7 +89,7 @@ def connect( self.val_loop = val_loop def reset(self) -> None: - """Resets the internal state of the loop for a new run""" + """Resets the internal state of the loop for a new run.""" assert self.batch_loop is not None assert self.batch_loop.optimizer_loop is not None self.is_last_batch = False @@ -257,7 +257,8 @@ def _num_training_batches_reached(self, is_last_batch: bool = False) -> bool: return self.batch_progress.current.ready == self.trainer.num_training_batches or is_last_batch def _should_accumulate(self) -> bool: - """Checks if the optimizer step should be performed or gradients should be accumulated for the current step.""" + """Checks if the optimizer step should be performed or gradients should be accumulated for the current + step.""" accumulation_done = self._accumulated_batches_reached() is_final_batch = self._num_training_batches_reached() return not (accumulation_done or is_final_batch) @@ -265,7 +266,7 @@ def _should_accumulate(self) -> bool: def _track_epoch_end_reduce_metrics( self, epoch_output: List[List[STEP_OUTPUT]], batch_end_outputs: STEP_OUTPUT ) -> None: - """Adds the batch outputs to the epoch outputs and prepares reduction""" + """Adds the batch outputs to the epoch outputs and prepares reduction.""" hook_overridden = is_overridden("training_epoch_end", self.trainer.lightning_module) if not hook_overridden: return @@ -282,8 +283,7 @@ def _track_epoch_end_reduce_metrics( def _prepare_outputs( outputs: List[List[List["ResultCollection"]]], batch_mode: bool ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]: - """ - Extract required information from batch or epoch end results. + """Extract required information from batch or epoch end results. Args: outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions: @@ -335,7 +335,7 @@ def _prepare_outputs( return processed_outputs def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None: - """updates the lr schedulers based on the given interval""" + """updates the lr schedulers based on the given interval.""" if interval == "step" and self._should_accumulate(): return self.trainer.optimizer_connector.update_learning_rates( @@ -345,7 +345,7 @@ def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) - ) def _increment_accumulated_grad_global_step(self) -> None: - """Increments global step according to grads progress""" + """Increments global step according to grads progress.""" if not self._should_accumulate(): self.global_step = self.trainer.accelerator.update_global_step( self.batch_progress.current.ready, self.trainer.global_step @@ -377,7 +377,7 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool: return is_val_check_batch def _save_loggers_on_train_batch_end(self) -> None: - """Flushes loggers to disk""" + """Flushes loggers to disk.""" # when loggers should save to disk should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index bfcd3f15242da..98404f65110d8 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -26,8 +26,7 @@ class FitLoop(Loop): - """ - This Loop iterates over the epochs to run the training. + """This Loop iterates over the epochs to run the training. Args: min_epochs: The minimum number of epochs @@ -51,17 +50,17 @@ def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = @property def current_epoch(self) -> int: - """Return the current epoch""" + """Return the current epoch.""" return self.epoch_progress.current.completed @current_epoch.setter def current_epoch(self, value: int) -> None: - """Setter for the current epoch""" + """Setter for the current epoch.""" self.epoch_progress.current.completed = value @property def global_step(self) -> int: - """Returns the global step""" + """Returns the global step.""" return self.epoch_loop.global_step @global_step.setter @@ -81,13 +80,13 @@ def batch_idx(self) -> int: @property def split_idx(self) -> int: - """Returns the index of the current batch split (within the current batch) for bptt""" + """Returns the index of the current batch split (within the current batch) for bptt.""" return self.epoch_loop.batch_loop.split_idx @property def min_steps(self) -> int: # TODO(@justusschock): Why aren't we using the attribute in this class? - """Returns the minimum numnber of steps to run""" + """Returns the minimum numnber of steps to run.""" return self.epoch_loop.min_steps @min_steps.setter @@ -98,7 +97,7 @@ def min_steps(self, value: int) -> None: @property def max_steps(self) -> int: - """Returns the maximum number of steps to run""" + """Returns the maximum number of steps to run.""" return self.epoch_loop.max_steps @max_steps.setter @@ -111,7 +110,7 @@ def max_steps(self, value: int) -> None: @property def running_loss(self) -> TensorRunningAccum: - """Returns the running loss""" + """Returns the running loss.""" return self.epoch_loop.batch_loop.running_loss @property @@ -138,8 +137,8 @@ def _results(self) -> ResultCollection: @staticmethod def _is_max_limit_enabled(max_value: Optional[int]) -> bool: - """Checks whether the max_value is enabled. This can - be used for checking whether max_epochs or max_steps is enabled. + """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps + is enabled. Args: max_value: the value to check @@ -153,8 +152,8 @@ def _is_max_limit_enabled(max_value: Optional[int]) -> bool: def done(self) -> bool: """Evaluates when to leave the loop. - Returns True if trainer.should_stop was set (e.g. by early stopping) - or if the maximum number of steps or epochs is reached. + Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs + is reached. """ # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps @@ -187,7 +186,7 @@ def connect(self, epoch_loop: TrainingEpochLoop): self.epoch_loop = epoch_loop def reset(self) -> None: - """Resets the internal state of this loop""" + """Resets the internal state of this loop.""" def on_run_start(self) -> None: """Calls the ``on_train_start`` hook.""" @@ -195,7 +194,8 @@ def on_run_start(self) -> None: self.trainer.call_hook("on_train_start") def on_advance_start(self) -> None: - """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``""" + """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and + ``on_train_epoch_start``""" model = self.trainer.lightning_module # reset train dataloader @@ -241,7 +241,7 @@ def on_advance_end(self) -> None: self.epoch_progress.increment_completed() def on_run_end(self) -> None: - """Calls the ``on_train_end`` hook""" + """Calls the ``on_train_end`` hook.""" # NOTE: the current_epoch is already incremented # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit # To simulate that current behavior, we decrement here. @@ -255,7 +255,7 @@ def on_run_end(self) -> None: self.trainer.accelerator.on_train_end() def should_accumulate(self) -> bool: - """Whether the gradients should be accumulated""" + """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() def teardown(self) -> None: diff --git a/pytorch_lightning/loops/optimizer/optimizer_loop.py b/pytorch_lightning/loops/optimizer/optimizer_loop.py index b817207a5b639..5ec476787aed8 100644 --- a/pytorch_lightning/loops/optimizer/optimizer_loop.py +++ b/pytorch_lightning/loops/optimizer/optimizer_loop.py @@ -40,7 +40,10 @@ class OptimizerLoop(Loop): - """Runs over a sequence of optimizers. This loop implements what is known in Lightning as Automatic Optimization.""" + """Runs over a sequence of optimizers. + + This loop implements what is known in Lightning as Automatic Optimization. + """ def __init__(self): super().__init__() @@ -179,10 +182,8 @@ def _make_closure( optimizer: Optimizer, hiddens: Any, ) -> Closure: - """ - Build a closure object that captures the given arguments and runs the `training_step` function and optionally - other functions such as `backward` and `zero_grad`. - """ + """Build a closure object that captures the given arguments and runs the `training_step` function and + optionally other functions such as `backward` and `zero_grad`.""" step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx, hiddens) backward_fn = self._make_backward_fn(optimizer, opt_idx) zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer) @@ -201,8 +202,8 @@ def _make_step_fn( return partial(self._training_step, split_batch, batch_idx, opt_idx, hiddens) def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]: - """ - Build a `zero_grad` function that zeroes the gradients before back-propagation. + """Build a `zero_grad` function that zeroes the gradients before back-propagation. + Returns ``None`` in the case backward needs to be skipped. """ @@ -224,9 +225,10 @@ def _make_backward_fn( optimizer: Optimizer, opt_idx: int, ) -> Optional[Callable[[Tensor], Tensor]]: - """ - Build a `backward` function that handles back-propagation through the output produced by the `training_step` - function. Returns ``None`` in the case backward needs to be skipped. + """Build a `backward` function that handles back-propagation through the output produced by the + `training_step` function. + + Returns ``None`` in the case backward needs to be skipped. """ if self._skip_backward: return None @@ -248,7 +250,6 @@ def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer Args: opt_idx: the index of the optimizer to use optimizer: the optimizer to use - """ # make sure only the gradients of the current optimizer's parameters are calculated # in the training step to prevent dangling gradients in multiple-optimizer setup. diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index f90bc392e72d8..154680535ef73 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -68,7 +68,7 @@ def _check_training_step_output(model: "pl.LightningModule", training_step_outpu def _process_training_step_output( trainer: "pl.Trainer", training_step_output: STEP_OUTPUT ) -> Tuple[Optional[ResultCollection], Optional[Any]]: - """Adds the :param:`training_step_output` to the trainer's results + """Adds the :param:`training_step_output` to the trainer's results. Args: trainer: a reference to the trainer @@ -120,7 +120,7 @@ def _build_training_step_kwargs( opt_idx: Optional[int], hiddens: Optional[Tensor], ) -> Dict[str, Any]: - """Builds the keyword arguments for training_step + """Builds the keyword arguments for training_step. Args: lightning_module: the LightningModule with a `training_step` hook implementation @@ -165,7 +165,7 @@ def _build_training_step_kwargs( def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator: - """Attach the dataloader""" + """Attach the dataloader.""" if not isinstance(data_fetcher, DataLoaderIterDataFetcher): # restore iteration dataloader_iter = enumerate(data_fetcher, batch_idx) @@ -176,9 +176,8 @@ def _prepare_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) @contextmanager def _block_parallel_sync_behavior(trainer: "pl.Trainer", block: bool = True) -> Generator[None, None, None]: - """ - Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`. - This is useful for example when when accumulating gradients to reduce communication when it is not needed. + """Blocks synchronization in :class:`~pytorch_lightning.plugins.training_type.parallel.ParallelPlugin`. This is + useful for example when when accumulating gradients to reduce communication when it is not needed. Args: trainer: the trainer instance with a reference to a training type plugin diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py index c4a5a5c1d80d1..fc22902495820 100644 --- a/pytorch_lightning/overrides/base.py +++ b/pytorch_lightning/overrides/base.py @@ -23,9 +23,8 @@ class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): def __init__(self, pl_module: "pl.LightningModule") -> None: - """ - Wraps the user's LightningModule. Requires overriding all ``*_step`` methods and ``forward`` so that it can - safely be wrapped by a ``_LightningModuleWrapperBase`` and a ``*DataParallel``. + """Wraps the user's LightningModule. Requires overriding all ``*_step`` methods and ``forward`` so that it + can safely be wrapped by a ``_LightningModuleWrapperBase`` and a ``*DataParallel``. Args: pl_module: the model to wrap diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index ff9dc1077f2de..ad8cae469e351 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -35,12 +35,11 @@ def _ignore_scalar_return_in_dp(): class LightningParallelModule(_LightningModuleWrapperBase): - """ - Wraps the user's LightningModule and redirects the forward call to the appropriate - method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. - This class is used in combination with :class:`~torch.nn.parallel.DataParallel` as - shown in the example. It also takes care of converting Python scalars to Tensors and - un-squeezes 0-dimensional Tensors as it is required by :class:`~torch.nn.parallel.DataParallel`. + """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either + ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination with + :class:`~torch.nn.parallel.DataParallel` as shown in the example. It also takes care of converting Python + scalars to Tensors and un-squeezes 0-dimensional Tensors as it is required by + :class:`~torch.nn.parallel.DataParallel`. Example: @@ -52,7 +51,6 @@ class LightningParallelModule(_LightningModuleWrapperBase): Args: pl_module: the model to wrap - """ def __init__(self, pl_module: "pl.LightningModule") -> None: @@ -73,11 +71,10 @@ def output_transform(data: Any): return output def update_replica_device_attributes(self, inputs: Any) -> None: - """ - Updates the device information of LightningModule by reading the device from the inputs. - In :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass - are lost when the replicas get discarded. The only way to know the current device is from the - inputs passed into the model. + """Updates the device information of LightningModule by reading the device from the inputs. In + :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass are lost when + the replicas get discarded. The only way to know the current device is from the inputs passed into the + model. Args: inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors, diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 644361d06e486..0cf392dd44775 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -24,11 +24,9 @@ class LightningDistributedModule(_LightningModuleWrapperBase): def __init__(self, pl_module: "pl.LightningModule") -> None: - """ - Wraps the user's LightningModule and redirects the forward call to the appropriate - method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``. - This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as - shown in the example. + """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either + ``training_step``, ``validation_step``, ``test_step`` or ``predict``. This class is used in combination + with :class:`~torch.nn.parallel.DistributedDataParallel` as shown in the example. Example: @@ -40,7 +38,6 @@ def __init__(self, pl_module: "pl.LightningModule") -> None: Args: pl_module: the model to wrap - """ super().__init__(pl_module) @@ -81,16 +78,12 @@ def prepare_for_backward(model: DistributedDataParallel, output: Any): class UnrepeatedDistributedSampler(DistributedSampler): - """ - A fork of the pytorch DistributedSampler that doesn't repeat data, instead - allowing the number of batches per process to be off-by-one from each other. - This makes this sampler usable for predictions (it's deterministic and - doesn't require shuffling). It is potentially unsafe to use this sampler for - training, because during training the DistributedDataParallel syncs buffers - on each forward pass, so it could freeze if one of the processes runs one - fewer batch. During prediction, buffers are only synced on the first batch, - so this is safe to use as long as each process runs at least one batch. We - verify this in an assert. + """A fork of the pytorch DistributedSampler that doesn't repeat data, instead allowing the number of batches + per process to be off-by-one from each other. This makes this sampler usable for predictions (it's + deterministic and doesn't require shuffling). It is potentially unsafe to use this sampler for training, + because during training the DistributedDataParallel syncs buffers on each forward pass, so it could freeze if + one of the processes runs one fewer batch. During prediction, buffers are only synced on the first batch, so + this is safe to use as long as each process runs at least one batch. We verify this in an assert. Taken from https://github.com/jpuigcerver/PyLaia/blob/v1.0.0/laia/data/unpadded_distributed_sampler.py and https://github.com/pytorch/pytorch/issues/25162#issuecomment-634146002 diff --git a/pytorch_lightning/overrides/torch_distributed.py b/pytorch_lightning/overrides/torch_distributed.py index 3a1ceca7d951c..3cbbe5ea760ff 100644 --- a/pytorch_lightning/overrides/torch_distributed.py +++ b/pytorch_lightning/overrides/torch_distributed.py @@ -16,9 +16,7 @@ # https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160 def _rank_not_in_group(group): - """ - Helper that checks if the current process's rank is not in a given group. - """ + """Helper that checks if the current process's rank is not in a given group.""" if group is None: return False return group == GroupMember.NON_GROUP_MEMBER diff --git a/pytorch_lightning/plugins/environments/kubeflow_environment.py b/pytorch_lightning/plugins/environments/kubeflow_environment.py index cd7da406a61b3..10a020d35a529 100644 --- a/pytorch_lightning/plugins/environments/kubeflow_environment.py +++ b/pytorch_lightning/plugins/environments/kubeflow_environment.py @@ -21,10 +21,10 @@ class KubeflowEnvironment(ClusterEnvironment): - """ - Environment for distributed training using the - `PyTorchJob `_ - operator from `Kubeflow `_ + """Environment for distributed training using the `PyTorchJob`_ operator from `Kubeflow`_ + + .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ + .. _Kubeflow: https://www.kubeflow.org """ @staticmethod diff --git a/pytorch_lightning/plugins/environments/lightning_environment.py b/pytorch_lightning/plugins/environments/lightning_environment.py index 077ebf995eebf..b3558c23d6b94 100644 --- a/pytorch_lightning/plugins/environments/lightning_environment.py +++ b/pytorch_lightning/plugins/environments/lightning_environment.py @@ -20,8 +20,7 @@ class LightningEnvironment(ClusterEnvironment): - """ - The default environment used by Lightning for a single node or free cluster (not managed). + """The default environment used by Lightning for a single node or free cluster (not managed). There are two modes the Lightning environment can operate with: @@ -42,8 +41,8 @@ def __init__(self): self._world_size: int = 1 def creates_children(self) -> bool: - """ - Returns whether the cluster creates the processes or not. + """Returns whether the cluster creates the processes or not. + If at least :code:`LOCAL_RANK` is available as environment variable, Lightning assumes the user acts as the process launcher/job scheduler and Lightning will not launch new processes. """ @@ -83,10 +82,10 @@ def teardown(self) -> None: def find_free_network_port() -> int: - """ - Finds a free port on localhost. - It is useful in single-node training when we don't want to connect to a real master node but - have to set the `MASTER_PORT` environment variable. + """Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real master node but have to set the + `MASTER_PORT` environment variable. """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) diff --git a/pytorch_lightning/plugins/environments/lsf_environment.py b/pytorch_lightning/plugins/environments/lsf_environment.py index 249cf900ab0d9..af6bfbb8163c9 100644 --- a/pytorch_lightning/plugins/environments/lsf_environment.py +++ b/pytorch_lightning/plugins/environments/lsf_environment.py @@ -20,8 +20,7 @@ class LSFEnvironment(ClusterEnvironment): - """ - An environment for running on clusters managed by the LSF resource manager. + """An environment for running on clusters managed by the LSF resource manager. It is expected that any execution using this ClusterEnvironment was executed using the Job Step Manager i.e. ``jsrun``. @@ -104,10 +103,8 @@ def local_rank(self): return int(local_rank) def node_rank(self): - """ - The node rank is determined by the position of the current hostname in the list of hosts stored in - the environment variable `LSB_HOSTS`. - """ + """The node rank is determined by the position of the current hostname in the list of hosts stored in the + environment variable `LSB_HOSTS`.""" hosts = self._read_hosts() count = {} for host in hosts: @@ -135,8 +132,8 @@ def _get_master_address(self): @staticmethod def _get_master_port(): - """ - A helper function for accessing the master port. + """A helper function for accessing the master port. + Uses the LSF job ID so all ranks can compute the master port. """ # check for user-specified master port diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py index 575399af48df3..506936bc347e6 100644 --- a/pytorch_lightning/plugins/io/checkpoint_plugin.py +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -18,8 +18,7 @@ class CheckpointIO(ABC): - """ - Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. + """Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may require particular handling depending on the plugin. @@ -31,7 +30,6 @@ class CheckpointIO(ABC): For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not modifiable. - """ @abstractmethod @@ -46,8 +44,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @abstractmethod def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: - """ - Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. + """Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. Args: path: Path to checkpoint diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py index 2aa66e65cc30b..be377cb39c3da 100644 --- a/pytorch_lightning/plugins/io/torch_plugin.py +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -22,10 +22,8 @@ class TorchCheckpointIO(CheckpointIO): - """ - CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` - to save and load checkpoints respectively, common for most use cases. - """ + """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints + respectively, common for most use cases.""" def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: try: @@ -42,8 +40,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio def load_checkpoint( self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage ) -> Dict[str, Any]: - """ - Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. + """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of + files. Args: path: Path to checkpoint diff --git a/pytorch_lightning/plugins/plugins_registry.py b/pytorch_lightning/plugins/plugins_registry.py index 81307157f527e..baa09af8b6dce 100644 --- a/pytorch_lightning/plugins/plugins_registry.py +++ b/pytorch_lightning/plugins/plugins_registry.py @@ -23,8 +23,7 @@ class _TrainingTypePluginsRegistry(UserDict): - """ - This class is a Registry that stores information about the Training Type Plugins. + """This class is a Registry that stores information about the Training Type Plugins. The Plugins are mapped to strings. These strings are names that idenitify a plugin, e.g., "deepspeed". It also returns Optional description and @@ -45,7 +44,6 @@ def __init__(self, a, b): or TrainingTypePluginsRegistry.register("lightning", LightningPlugin, description="Super fast", a=1, b=True) - """ def register( @@ -56,8 +54,7 @@ def register( override: bool = False, **init_params: Any, ) -> Callable: - """ - Registers a plugin mapped to a name and with required metadata. + """Registers a plugin mapped to a name and with required metadata. Args: name : the name that identifies a plugin, e.g. "deepspeed_stage_3" @@ -89,9 +86,7 @@ def do_register(plugin: Callable) -> Callable: return do_register def get(self, name: str, default: Optional[Any] = None) -> Any: - """ - Calls the registered plugin with the required parameters - and returns the plugin object + """Calls the registered plugin with the required parameters and returns the plugin object. Args: name (str): the name that identifies a plugin, e.g. "deepspeed_stage_3" @@ -108,11 +103,11 @@ def get(self, name: str, default: Optional[Any] = None) -> Any: raise KeyError(err_msg.format(name, available_names)) def remove(self, name: str) -> None: - """Removes the registered plugin by name""" + """Removes the registered plugin by name.""" self.pop(name) def available_plugins(self) -> List: - """Returns a list of registered plugins""" + """Returns a list of registered plugins.""" return list(self.keys()) def __str__(self) -> str: diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 842a2821a1e02..297571d1174c8 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -55,7 +55,7 @@ def backward( *args: Any, **kwargs: Any, ) -> None: - """Run before precision plugin executes backward + """Run before precision plugin executes backward. Args: model: the model to be optimized @@ -68,7 +68,7 @@ def backward( @staticmethod def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None: - """Reinitializes schedulers with correct properties""" + """Reinitializes schedulers with correct properties.""" # Reinitialize optimizer.step properties added by schedulers for scheduler in schedulers: scheduler = scheduler["scheduler"] diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index c0a25ea894019..c127a2076455f 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -66,7 +66,5 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, model: Optional[Module] = None, ) -> None: - """ - DeepSpeed handles clipping gradients internally via the training type plugin. - """ + """DeepSpeed handles clipping gradients internally via the training type plugin.""" pass diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 8261fe0fbbb65..179daf9e91db8 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -25,8 +25,7 @@ class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase): - """ - LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double + """LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double (``torch.float64``) precision. Args: @@ -83,8 +82,9 @@ def connect( self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any] ) -> Tuple[nn.Module, List["Optimizer"], List[Any]]: """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert - incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or - `lr_schedulers`. + incoming floating point data to double (``torch.float64``) precision. + + Does not alter `optimizers` or `lr_schedulers`. """ model = cast(pl.LightningModule, model.double()) model = LightningDoublePrecisionModule(model) @@ -93,8 +93,8 @@ def connect( @contextmanager def train_step_context(self) -> Generator[None, None, None]: - """ - A context manager to change the default tensor type. + """A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` """ torch.set_default_tensor_type(torch.DoubleTensor) @@ -103,8 +103,8 @@ def train_step_context(self) -> Generator[None, None, None]: @contextmanager def val_step_context(self) -> Generator[None, None, None]: - """ - A context manager to change the default tensor type. + """A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` """ torch.set_default_tensor_type(torch.DoubleTensor) @@ -113,8 +113,8 @@ def val_step_context(self) -> Generator[None, None, None]: @contextmanager def test_step_context(self) -> Generator[None, None, None]: - """ - A context manager to change the default tensor type. + """A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` """ torch.set_default_tensor_type(torch.DoubleTensor) @@ -123,8 +123,8 @@ def test_step_context(self) -> Generator[None, None, None]: @contextmanager def predict_step_context(self) -> Generator[None, None, None]: - """ - A context manager to change the default tensor type. + """A context manager to change the default tensor type. + See: :meth:`torch.set_default_tensor_type` """ torch.set_default_tensor_type(torch.DoubleTensor) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index dedf274237f09..0872dea2a079c 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -21,7 +21,7 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): - """Mixed Precision for Full Sharded Training""" + """Mixed Precision for Full Sharded Training.""" precision = "mixed" diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py index 5c488ba3ea1f3..d950dfe7cc553 100644 --- a/pytorch_lightning/plugins/precision/ipu_precision.py +++ b/pytorch_lightning/plugins/precision/ipu_precision.py @@ -45,7 +45,7 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, model: Optional[Module] = None, ) -> None: - """Clips the gradients""" + """Clips the gradients.""" if clip_val is None: return diff --git a/pytorch_lightning/plugins/precision/mixed.py b/pytorch_lightning/plugins/precision/mixed.py index 7a01ad25643c3..52c8b96d42882 100644 --- a/pytorch_lightning/plugins/precision/mixed.py +++ b/pytorch_lightning/plugins/precision/mixed.py @@ -20,7 +20,7 @@ class MixedPrecisionPlugin(PrecisionPlugin): - """Base Class for mixed precision""" + """Base Class for mixed precision.""" backend: "AMPType" precision: Union[str, int] = "mixed" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index f18b353ffe959..9373625f66d02 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -29,8 +29,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): - """ - Plugin for native mixed precision training with :mod:`torch.cuda.amp`. + """Plugin for native mixed precision training with :mod:`torch.cuda.amp`. Args: precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16). @@ -116,25 +115,25 @@ def autocast_context_manager(self) -> torch.cuda.amp.autocast: @contextmanager def train_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" + """Enable autocast context.""" with self.autocast_context_manager(): yield @contextmanager def val_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" + """Enable autocast context.""" with self.autocast_context_manager(): yield @contextmanager def test_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" + """Enable autocast context.""" with self.autocast_context_manager(): yield @contextmanager def predict_step_context(self) -> Generator[None, None, None]: - """Enable autocast context""" + """Enable autocast context.""" with self.autocast_context_manager(): yield diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 86486bfc37cd9..ffeb2ccf90f4b 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -26,18 +26,17 @@ class PrecisionPlugin(CheckpointHooks): - """ - Base class for all plugins handling the precision-specific parts of the training. - The class attribute precision must be overwritten in child classes. - The default value reflects fp32 training. + """Base class for all plugins handling the precision-specific parts of the training. + + The class attribute precision must be overwritten in child classes. The default value reflects fp32 training. """ precision: Union[str, int] = 32 def master_params(self, optimizer: Optimizer) -> _PARAMETERS: - """ - The master params of the model. Returns the plain model params here. - Maybe different in other precision plugins. + """The master params of the model. + + Returns the plain model params here. Maybe different in other precision plugins. """ for group in optimizer.param_groups: yield from group["params"] @@ -45,11 +44,11 @@ def master_params(self, optimizer: Optimizer) -> _PARAMETERS: def connect( self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any] ) -> Tuple[Module, List[Optimizer], List[Any]]: - """Connects this plugin to the accelerator and the training process""" + """Connects this plugin to the accelerator and the training process.""" return model, optimizers, lr_schedulers def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - """Run before precision plugin executes backward + """Run before precision plugin executes backward. Args: model: the model to be optimized @@ -66,7 +65,7 @@ def backward( *args: Any, **kwargs: Any, ) -> None: - """Performs the actual backpropagation + """Performs the actual backpropagation. Args: model: the model to be optimized @@ -80,7 +79,7 @@ def backward( closure_loss.backward(*args, **kwargs) def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - """Run after precision plugin executes backward + """Run after precision plugin executes backward. Args: model: the model to be optimized @@ -113,7 +112,7 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, model: Optional[Module] = None, ) -> None: - """Clips the gradients""" + """Clips the gradients.""" if clip_val is None: return @@ -128,12 +127,12 @@ def clip_gradients( self.clip_grad_by_norm(optimizer, clip_val) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: - """Clip gradients by value""" + """Clip gradients by value.""" parameters = self.master_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: - """Clip gradients by norm""" + """Clip gradients by norm.""" parameters = self.master_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) @@ -148,20 +147,20 @@ def post_dispatch(self) -> None: @contextlib.contextmanager def train_step_context(self) -> Generator: - """A contextmanager for the training step""" + """A contextmanager for the training step.""" yield @contextlib.contextmanager def val_step_context(self) -> Generator: - """A contextmanager for the validation step""" + """A contextmanager for the validation step.""" yield @contextlib.contextmanager def test_step_context(self) -> Generator: - """A contextmanager for the test step""" + """A contextmanager for the test step.""" yield @contextlib.contextmanager def predict_step_context(self) -> Generator: - """A contextmanager for the predict step""" + """A contextmanager for the predict step.""" yield diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index a1eb23e478132..3904c36911003 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -22,7 +22,7 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): - """Mixed Precision for Sharded Training""" + """Mixed Precision for Sharded Training.""" def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None: super().__init__(precision, use_cpu=use_cpu) diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py index b86a949a889f0..4e1db6210e697 100644 --- a/pytorch_lightning/plugins/precision/tpu_bfloat.py +++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py @@ -21,7 +21,7 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin): - """Plugin that enables bfloats on TPUs""" + """Plugin that enables bfloats on TPUs.""" precision: int = 16 diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 143518262bb7a..8ffdbcf3e11b4 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -76,12 +76,10 @@ class DDPPlugin(ParallelPlugin): - """ - Plugin for multi-process single-device training on one or multiple nodes. + """Plugin for multi-process single-device training on one or multiple nodes. - The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, - where N is the number of devices (e.g. GPU) per node. - It is very similar to how :mod:`torch.distributed.launch` launches processes. + The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`, where N is the number of + devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.launch` launches processes. """ distributed_backend = "ddp" @@ -404,7 +402,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) def pre_backward(self, closure_loss: torch.Tensor) -> None: - """Run before precision plugin executes backward""" + """Run before precision plugin executes backward.""" if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) @@ -412,8 +410,7 @@ def model_to_device(self): self.model.to(self.root_device) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """ - Reduces a tensor from several distributed processes to one aggregated tensor. + """Reduces a tensor from several distributed processes to one aggregated tensor. Args: tensor: the tensor to sync and reduce @@ -475,9 +472,7 @@ def _share_information_to_prevent_deadlock(self): self._sync_dir = sync_dirs[self.node_rank] def _share_pids(self): - """ - Make all DDP processes aware of all processes pids. - """ + """Make all DDP processes aware of all processes pids.""" self.barrier() pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device)) pids = pids.cpu().numpy().tolist() diff --git a/pytorch_lightning/plugins/training_type/ddp2.py b/pytorch_lightning/plugins/training_type/ddp2.py index 9981d2a1fc260..ae3954093880c 100644 --- a/pytorch_lightning/plugins/training_type/ddp2.py +++ b/pytorch_lightning/plugins/training_type/ddp2.py @@ -35,9 +35,8 @@ def setup(self) -> None: # the difference to DDP is that we don't call children processes here def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """ - Reduces a collection of tensors from all processes. It can be applied to just a single tensor. - In DDP2, the reduction here is only across local devices within the node. + """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. In DDP2, + the reduction here is only across local devices within the node. Args: collection: The collection of tensors to sync and reduce. diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a45e70adffdde..f4ae9709823f8 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -55,10 +55,8 @@ class DDPSpawnPlugin(ParallelPlugin): - """ - Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after - training finishes. - """ + """Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training + finishes.""" distributed_backend = "ddp_spawn" @@ -323,13 +321,12 @@ def model_to_device(self): self.model.to(self.root_device) def pre_backward(self, closure_loss: torch.Tensor) -> None: - """Run before precision plugin executes backward""" + """Run before precision plugin executes backward.""" if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor: - """ - Reduces a tensor from several distributed processes to one aggregated tensor. + """Reduces a tensor from several distributed processes to one aggregated tensor. Args: tensor: the tensor to sync and reduce diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 5fa8739de7b2d..303cc985aad3a 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -125,10 +125,9 @@ def __init__( synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, ) -> None: - """ - Provides capabilities to run training using the DeepSpeed library, - with training optimizations for large billion parameter models. - `For more information: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed`. + """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large + billion parameter models. `For more information: https://pytorch- + lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed`. .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change. @@ -519,11 +518,10 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla self.model.step(**kwargs) def _handle_gradient_accumulation_steps(self): - """ - This functions overrides the trainer.accumulation_scheduler to generate - ``accumulate_grad_batches=1``. - Therefore, ``optimizer_step`` will be called on every batches seen - so DeepSpeed Engine handles the gradient accumulation logic internally. + """This functions overrides the trainer.accumulation_scheduler to generate ``accumulate_grad_batches=1``. + + Therefore, ``optimizer_step`` will be called on every batches seen so DeepSpeed Engine handles the gradient + accumulation logic internally. """ if self.config.get("gradient_accumulation_steps") > 1: self._original_accumulate_grad_batches = self.lightning_module.trainer.accumulate_grad_batches @@ -722,11 +720,9 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self._restore_zero_state(checkpoint) def _restore_zero_state(self, ckpt: Mapping[str, Any]) -> None: - """ - Overrides the normal load_state_dict behaviour in PyTorch to ensure - we gather parameters that may be sharded across processes before loading - the state dictionary when using ZeRO stage 3. - This is then automatically synced across processes. + """Overrides the normal load_state_dict behaviour in PyTorch to ensure we gather parameters that may be + sharded across processes before loading the state dictionary when using ZeRO stage 3. This is then + automatically synced across processes. Args: ckpt: The ckpt file. diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 5b0887c848322..fe970bb5a3bbc 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -25,10 +25,8 @@ class DataParallelPlugin(ParallelPlugin): - """ - Implements data-parallel training in a single process, i.e., the model gets replicated to each - device and each gets a split of the data. - """ + """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and + each gets a split of the data.""" def __init__( self, @@ -59,8 +57,7 @@ def setup(self) -> None: self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices) def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION: - """ - Reduces a collection of tensors from all processes. It can be applied to just a single tensor. + """Reduces a collection of tensors from all processes. It can be applied to just a single tensor. Args: collection: The collection of tensors to sync and reduce. diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 29c74439dd5ee..72338e2923c07 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -43,8 +43,7 @@ def __init__( cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, ): - """ - Plugin for Fully Sharded Data Parallel provided by FairScale. + """Plugin for Fully Sharded Data Parallel provided by FairScale. Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 19694e1bcda11..c7a6fefea6542 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -146,8 +146,7 @@ def join(self): hvd.join() def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"): - """ - Reduces a tensor from several distributed processes to one aggregated tensor. + """Reduces a tensor from several distributed processes to one aggregated tensor. Args: tensor: the tensor to sync and reduce diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 008710af9fc0e..16bc5e4e9be4b 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -55,9 +55,7 @@ def _move_float_tensors_to_half(self, batch: Any) -> Any: class IPUPlugin(ParallelPlugin): - """ - Plugin for training on IPU devices. - """ + """Plugin for training on IPU devices.""" def __init__( self, @@ -187,8 +185,8 @@ def _convert_to_poptorch_loader( @property def accumulate_grad_batches(self) -> int: - """ - Tracks lazily the set accumulate_grad_batches in the trainer. + """Tracks lazily the set accumulate_grad_batches in the trainer. + The IPUPlugin replaces the original accumulate_grad_batches. """ if self._original_accumulate_grad_batches is None: @@ -201,9 +199,8 @@ def accumulate_grad_batches(self) -> int: return self._original_accumulate_grad_batches def _handle_gradient_accumulation_steps(self): - """ - This functions overrides the trainer.accumulation_scheduler to generate - ``accumulate_grad_batches=1``. + """This functions overrides the trainer.accumulation_scheduler to generate ``accumulate_grad_batches=1``. + Therefore, ``optimizer_step`` will be called on every batch, and the IPU will handle grad accumulation. """ if self.accumulate_grad_batches > 1: @@ -262,16 +259,14 @@ def _compiled(self, model: Any): return model._executable is not None def _detach_models(self): - """ - Detaches all stage specific models from IPU devices. - """ + """Detaches all stage specific models from IPU devices.""" for k, model in self.poptorch_models.items(): if self._compiled(model) and model.isAttachedToDevice(): model.detachFromDevice() def _load_model(self, stage: str): - """ - Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices. + """Loads the stage specific accelerator model onto device if compiled and not attached to IPU devices. + Args: stage: The stage to load """ diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 31d2deb5f65e6..e6406ea444947 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -84,12 +84,10 @@ def distributed_sampler_kwargs(self): return distributed_sampler_kwargs def reconciliate_processes(self, trace: str): - """ - Function to re-conciliate processes on failure - """ + """Function to re-conciliate processes on failure.""" def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes""" + """Perform a all_gather on all processes.""" return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads) def reduce_boolean_decision(self, decision: bool) -> bool: @@ -107,8 +105,7 @@ def torch_distributed_backend(self): @staticmethod def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule": - """ - Add global batchnorm for a model spread across multiple GPUs and nodes. + """Add global batchnorm for a model spread across multiple GPUs and nodes. Override to synchronize batchnorm between specific process groups instead of the whole world or use a different sync_bn like `apex`'s version. @@ -123,8 +120,8 @@ def configure_sync_batchnorm(model: "pl.LightningModule") -> "pl.LightningModule @contextmanager def block_backward_sync(self): - """ - Blocks ddp sync gradients behaviour on backwards pass. + """Blocks ddp sync gradients behaviour on backwards pass. + This is useful for skipping sync when accumulating gradients, reducing communication overhead Returns: context manager with sync behaviour off """ diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index c92fead861c19..1737bf3b41ca8 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -43,9 +43,8 @@ def on_gpu(self) -> bool: return self.root_device.type == "cuda" and torch.cuda.is_available() def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> Union[Any, torch.Tensor]: - """ - Reduces a tensor from several distributed processes to one aggregated tensor. - As this plugin only operates with a single device, the reduction is simply the identity. + """Reduces a tensor from several distributed processes to one aggregated tensor. As this plugin only + operates with a single device, the reduction is simply the identity. Args: tensor: the tensor to sync and reduce @@ -58,7 +57,7 @@ def reduce(self, tensor: Union[Any, torch.Tensor], *args: Any, **kwargs: Any) -> return tensor def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes""" + """Perform a all_gather on all processes.""" return tensor @property diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6ee1ce77c8c24..8d35b130eac4d 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -32,9 +32,8 @@ class TrainingTypePlugin(ABC): - """ - Base class for all training type plugins that change the behaviour of the training, validation and test-loop. - """ + """Base class for all training type plugins that change the behaviour of the training, validation and test- + loop.""" def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None @@ -52,14 +51,14 @@ def checkpoint_io(self, plugin: CheckpointIO) -> None: self._checkpoint_io = plugin def connect(self, model: Module) -> None: - """Called by the accelerator to connect the accelerator and the model with this plugin""" + """Called by the accelerator to connect the accelerator and the model with this plugin.""" self.model = model def setup_environment(self) -> None: - """ - Setup any processes or distributed connections. - This is called before the LightningModule/DataModule setup hook - which allows the user to access the accelerator environment before setup is complete. + """Setup any processes or distributed connections. + + This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator + environment before setup is complete. """ def setup(self) -> None: @@ -68,24 +67,24 @@ def setup(self) -> None: @property @abstractmethod def on_gpu(self) -> bool: - """Returns whether the current process is done on GPU""" + """Returns whether the current process is done on GPU.""" raise NotImplementedError @property @abstractmethod def on_tpu(self) -> bool: - """Returns whether the current process is done on TPU""" + """Returns whether the current process is done on TPU.""" raise NotImplementedError @property @abstractmethod def root_device(self) -> torch.device: - """Returns the root device""" + """Returns the root device.""" raise NotImplementedError @abstractmethod def model_to_device(self) -> None: - """Moves the model to the correct device""" + """Moves the model to the correct device.""" @property @abstractmethod @@ -94,8 +93,7 @@ def is_global_zero(self) -> bool: @abstractmethod def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: - """ - Reduces the given tensor (e.g. across GPUs/processes). + """Reduces the given tensor (e.g. across GPUs/processes). Args: tensor: the tensor to sync and reduce @@ -105,32 +103,32 @@ def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> @abstractmethod def barrier(self, name: Optional[str] = None) -> None: - """Forces all possibly joined processes to wait for each other""" + """Forces all possibly joined processes to wait for each other.""" @abstractmethod def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: - """Broadcasts an object to all processes""" + """Broadcasts an object to all processes.""" @abstractmethod def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """Perform a all_gather on all processes""" + """Perform a all_gather on all processes.""" def reduce_boolean_decision(self, decision: bool) -> bool: - """Reduce the early stopping decision across all processes""" + """Reduce the early stopping decision across all processes.""" return decision def pre_backward(self, closure_loss: torch.Tensor) -> None: - """Run before precision plugin executes backward""" + """Run before precision plugin executes backward.""" def post_backward(self, closure_loss: torch.Tensor) -> None: - """Run after precision plugin executes backward""" + """Run after precision plugin executes backward.""" def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: """Hook to do something after each optimizer step.""" @property def model(self) -> Optional[Module]: - """Returns the potentially wrapped LightningModule""" + """Returns the potentially wrapped LightningModule.""" return self._model @model.setter @@ -139,13 +137,14 @@ def model(self, new_model: Optional[Module]) -> None: @property def lightning_module(self) -> "pl.LightningModule": - """Returns the pure LightningModule without potential wrappers""" + """Returns the pure LightningModule without potential wrappers.""" return unwrap_lightning_module(self._model) @property def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - """ - Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. The result is + """Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. + + The result is cached instead of returned directly, because some plugins require transmitting the results from one multiprocessing context to another in a separate step. For example, the plugins that use the "spawn" start-method send the result to the master process through a @@ -202,7 +201,7 @@ def test_step_end(self, output): return output def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: - """Wraps the dataloader if necessary + """Wraps the dataloader if necessary. Args: dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` @@ -217,10 +216,9 @@ def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Calla @property def setup_optimizers_in_pre_dispatch(self) -> bool: - """ - Override to delay setting optimizers and schedulers till after dispatch. - This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. - However this may break certain precision plugins such as APEX which require optimizers to be set. + """Override to delay setting optimizers and schedulers till after dispatch. This is useful when the + `TrainingTypePlugin` requires operating on the wrapped accelerator model. However this may break certain + precision plugins such as APEX which require optimizers to be set. Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. @@ -229,9 +227,8 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: @property def restore_checkpoint_after_pre_dispatch(self) -> bool: - """ - Override to delay restoring from checkpoint till after pre-dispatch. - This is useful when the plugin requires all the setup hooks to run before loading checkpoint. + """Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin + requires all the setup hooks to run before loading checkpoint. Returns: If true, restore checkpoint after pre_dispatch. @@ -240,15 +237,14 @@ def restore_checkpoint_after_pre_dispatch(self) -> bool: @property def lightning_restore_optimizer_and_schedulers(self) -> bool: - """ - Override to disable Lightning restoring optimizers/schedulers. + """Override to disable Lightning restoring optimizers/schedulers. + This is useful for plugins which manage restoring optimizers/schedulers. """ return True def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int: - """ - Provide a hook to count optimizer step calls. + """Provide a hook to count optimizer step calls. Args: total_batch_idx: Total number of batches seen for training @@ -275,8 +271,7 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: @contextlib.contextmanager def model_sharded_context(self) -> Generator: - """ - Provide hook to create modules in a distributed aware context. This is useful for when we'd like to + """Provide hook to create modules in a distributed aware context. This is useful for when we'd like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time. @@ -286,8 +281,8 @@ def model_sharded_context(self) -> Generator: @property def call_configure_sharded_model_hook(self) -> bool: - """ - Allow model parallel hook to be called in suitable environments determined by the training type plugin. + """Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook. """ @@ -299,8 +294,8 @@ def call_configure_sharded_model_hook(self, mode: bool) -> None: @abstractmethod def teardown(self) -> None: - """ - This method is called to teardown the training process. + """This method is called to teardown the training process. + It is the right place to release memory and free other resources. """ raise NotImplementedError @@ -347,9 +342,7 @@ def on_predict_end(self): pass def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: - """ - Called in the training loop before anything happens for that batch. - """ + """Called in the training loop before anything happens for that batch.""" pass def pre_dispatch(self) -> None: diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index e2021d9f6734b..45205cf36a899 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -1,6 +1,4 @@ -""" -Profiling your training run can help you understand if there are any bottlenecks in your code. - +"""Profiling your training run can help you understand if there are any bottlenecks in your code. Built-in checks --------------- @@ -194,7 +192,6 @@ def custom_processing_step(self, data): Or:: python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))' - """ from pytorch_lightning.profiler.advanced import AdvancedProfiler from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler diff --git a/pytorch_lightning/profiler/advanced.py b/pytorch_lightning/profiler/advanced.py index 01531e15141e6..b9c5ec5dce7ff 100644 --- a/pytorch_lightning/profiler/advanced.py +++ b/pytorch_lightning/profiler/advanced.py @@ -25,10 +25,10 @@ class AdvancedProfiler(BaseProfiler): - """ - This profiler uses Python's cProfiler to record more detailed information about - time spent in each function call recorded during a given action. The output is quite - verbose and you should only use this if you want very detailed reports. + """This profiler uses Python's cProfiler to record more detailed information about time spent in each function + call recorded during a given action. + + The output is quite verbose and you should only use this if you want very detailed reports. """ def __init__( diff --git a/pytorch_lightning/profiler/base.py b/pytorch_lightning/profiler/base.py index 7c06cb8fe4cfa..f14082bfd8295 100644 --- a/pytorch_lightning/profiler/base.py +++ b/pytorch_lightning/profiler/base.py @@ -49,9 +49,7 @@ def teardown(self, **kwargs: Any) -> None: class BaseProfiler(AbstractProfiler): - """ - If you wish to write a custom profiler, you should inherit from this class. - """ + """If you wish to write a custom profiler, you should inherit from this class.""" def __init__( self, @@ -69,8 +67,7 @@ def __init__( @contextmanager def profile(self, action_name: str) -> Generator: - """ - Yields a context manager to encapsulate the scope of a profiled action. + """Yields a context manager to encapsulate the scope of a profiled action. Example:: @@ -163,8 +160,7 @@ def setup( self.dirpath = self.dirpath or log_dir def teardown(self, stage: Optional[str] = None) -> None: - """ - Execute arbitrary post-profiling tear-down steps. + """Execute arbitrary post-profiling tear-down steps. Closes the currently open file and stream. """ @@ -191,8 +187,8 @@ def local_rank(self) -> int: class PassThroughProfiler(BaseProfiler): - """ - This class should be used when you don't want the (small) overhead of profiling. + """This class should be used when you don't want the (small) overhead of profiling. + The Trainer uses this class by default. """ diff --git a/pytorch_lightning/profiler/pytorch.py b/pytorch_lightning/profiler/pytorch.py index 0fd942c71e601..8bdbadffec15b 100644 --- a/pytorch_lightning/profiler/pytorch.py +++ b/pytorch_lightning/profiler/pytorch.py @@ -45,8 +45,8 @@ class RegisterRecordFunction: - """ - While profiling autograd operations, this class will add labels for module names around the forward function. + """While profiling autograd operations, this class will add labels for module names around the forward + function. The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: @@ -100,10 +100,8 @@ def __exit__(self, type: Any, value: Any, traceback: Any) -> None: class ScheduleWrapper: - """ - This class is used to override the schedule logic from the profiler and perform - recording for both `training_step`, `validation_step`. - """ + """This class is used to override the schedule logic from the profiler and perform recording for both + `training_step`, `validation_step`.""" def __init__(self, schedule: Callable) -> None: if not _KINETO_AVAILABLE: @@ -233,8 +231,8 @@ def __init__( record_module_names: bool = True, **profiler_kwargs: Any, ) -> None: - """ - This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of + """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. + different operators inside your model - both on the CPU and GPU Args: diff --git a/pytorch_lightning/profiler/simple.py b/pytorch_lightning/profiler/simple.py index 37570fdb44d01..96854f887e875 100644 --- a/pytorch_lightning/profiler/simple.py +++ b/pytorch_lightning/profiler/simple.py @@ -27,10 +27,8 @@ class SimpleProfiler(BaseProfiler): - """ - This profiler simply records the duration of actions (in seconds) and reports - the mean duration of each action and the total time spent over the entire training run. - """ + """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each + action and the total time spent over the entire training run.""" def __init__( self, diff --git a/pytorch_lightning/profiler/xla.py b/pytorch_lightning/profiler/xla.py index 402dfbd6ca0b8..e30f06f84e952 100644 --- a/pytorch_lightning/profiler/xla.py +++ b/pytorch_lightning/profiler/xla.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -XLA Profiler will help you debug and optimize training workload performance -for your models using Cloud TPU performance tools. +"""XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU +performance tools. Manual capture via TensorBoard @@ -38,7 +37,6 @@ 4. Once the capture is finished, the page will refresh and you could browse through the insights using the ``Tools`` dropdown at the top left - """ import logging from typing import Dict @@ -65,10 +63,8 @@ class XLAProfiler(BaseProfiler): } def __init__(self, port: int = 9012) -> None: - """ - This Profiler will help you debug and optimize training workload performance - for your models using Cloud TPU performance tools. - """ + """This Profiler will help you debug and optimize training workload performance for your models using Cloud + TPU performance tools.""" super().__init__(dirpath=None, filename=None) self.port = port self._recording_map: Dict = {} diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index 6b1a9d3bca335..da16d9475a41b 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -20,7 +20,7 @@ def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]: - """Load requirements from a file + """Load requirements from a file. >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['numpy...', 'torch...', ...] @@ -41,7 +41,7 @@ def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comme def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: - """Load readme as decribtion + """Load readme as decribtion. >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE '
...' diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 98abd994b531d..6226a75de42fb 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" - -""" +"""""" from pytorch_lightning.trainer.trainer import Trainer from pytorch_lightning.utilities.seed import seed_everything diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 757489d5f372a..b8931c415553b 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -288,22 +288,16 @@ def on_before_backward(self, loss: torch.Tensor) -> None: callback.on_before_backward(self, self.lightning_module, loss) def on_after_backward(self): - """ - Called after loss.backward() and before optimizers do anything. - """ + """Called after loss.backward() and before optimizers do anything.""" for callback in self.callbacks: callback.on_after_backward(self, self.lightning_module) def on_before_optimizer_step(self, optimizer, optimizer_idx): - """ - Called after on_after_backward() once the gradient is accumulated and before optimizer.step(). - """ + """Called after on_after_backward() once the gradient is accumulated and before optimizer.step().""" for callback in self.callbacks: callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx) def on_before_zero_grad(self, optimizer): - """ - Called after optimizer.step() and before optimizer.zero_grad(). - """ + """Called after optimizer.step() and before optimizer.zero_grad().""" for callback in self.callbacks: callback.on_before_zero_grad(self, self.lightning_module, optimizer) diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index bd0457404ed51..1c30768285fc1 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -158,7 +158,7 @@ def __verify_predict_loop_configuration(self, model: "pl.LightningModule") -> No ) def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> None: - """Raise Misconfiguration exception since these hooks are not supported in DP mode""" + """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") for hook in batch_transfer_hooks: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 85a2c7ee87d13..8576d6cd140cf 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -843,10 +843,8 @@ def _set_horovod_backend(self): self.num_processes = hvd.local_size() def check_interactive_compatibility(self): - """ - Raises a `MisconfigurationException` if the accelerator and/or plugin - is not compatible with an interactive environment - """ + """Raises a `MisconfigurationException` if the accelerator and/or plugin is not compatible with an + interactive environment.""" from pytorch_lightning.utilities import _IS_INTERACTIVE if _IS_INTERACTIVE and self._distrib_type is not None and not self._distrib_type.is_interactive_compatible(): diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index e2153a0cd9341..57e95b4446369 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -139,8 +139,8 @@ def attach_model_logging_functions(self, model): callback.log_dict = model.log_dict def _attach_model_callbacks(self) -> None: - """ - Attaches the callbacks defined in the model. + """Attaches the callbacks defined in the model. + If a callback returned by the model's configure_callback method has the same type as one or several callbacks already present in the trainer callbacks list, it will replace them. In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks @@ -167,8 +167,7 @@ def _attach_model_callbacks(self) -> None: @staticmethod def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]: - """ - Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of + """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of checkpoint callbacks is preserved, as well as the order of all other callbacks. Args: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index ce119d80c24eb..2503fc61f4f7f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -46,8 +46,7 @@ def hpc_resume_path(self) -> Optional[str]: return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") def resume_start(self) -> None: - """ - Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: + """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: 1. from HPC weights if found 2. from `resume_from_checkpoint` file if provided @@ -65,7 +64,8 @@ def resume_start(self) -> None: self._loaded_checkpoint = self.trainer.training_type_plugin.load_checkpoint(checkpoint_path) def resume_end(self) -> None: - """Signal the connector that all states have resumed and memory for the checkpoint object can be released.""" + """Signal the connector that all states have resumed and memory for the checkpoint object can be + released.""" if self.resume_checkpoint_path: rank_zero_info(f"Restored all states from the checkpoint file at {self.resume_checkpoint_path}") self.resume_checkpoint_path = None @@ -78,9 +78,8 @@ def resume_end(self) -> None: self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end") def restore(self, checkpoint_path: Optional[Union[Path, str]] = None) -> None: - """ - Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file - through file-read and state-restore, in this priority: + """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and + state-restore, in this priority: 1. from HPC weights if found 2. from `resume_from_checkpoint` file if provided @@ -115,10 +114,10 @@ def restore_datamodule(self) -> None: datamodule.on_load_checkpoint(self._loaded_checkpoint) def restore_model(self) -> None: - """ - Restores a model's weights from a PyTorch Lightning checkpoint. Hooks are called first go give - the LightningModule a chance to modify the contents, then finally the model gets updated with - the loaded weights. + """Restores a model's weights from a PyTorch Lightning checkpoint. + + Hooks are called first go give the LightningModule a chance to modify the contents, then finally the model gets + updated with the loaded weights. """ if not self._loaded_checkpoint: return @@ -151,9 +150,9 @@ def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> self.trainer.training_type_plugin.load_model_state_dict(checkpoint) def restore_training_state(self) -> None: - """ - Restore the trainer state from the pre-loaded checkpoint. This includes the precision settings, loop progress, - optimizer states and learning rate scheduler states. + """Restore the trainer state from the pre-loaded checkpoint. + + This includes the precision settings, loop progress, optimizer states and learning rate scheduler states. """ if not self._loaded_checkpoint: return @@ -180,8 +179,8 @@ def restore_callbacks(self) -> None: self.trainer.on_load_checkpoint(self._loaded_checkpoint) def restore_loops(self) -> None: - """ - Restores the loop progress from the pre-loaded checkpoint. + """Restores the loop progress from the pre-loaded checkpoint. + Calls hooks on the loops to give it a chance to restore its state from the checkpoint. """ if not self._loaded_checkpoint: @@ -384,11 +383,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: return checkpoint def hpc_load(self, checkpoint_path: str) -> None: - """ - Attempts to restore the full training and model state from a HPC checkpoint file. + """Attempts to restore the full training and model state from a HPC checkpoint file. - .. deprecated::v1.4 - Will be removed in v1.6. Use :meth:`restore` instead. + .. deprecated:: v1.4 Will be removed in v1.6. Use :meth:`restore` instead. """ rank_zero_deprecation( "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6." @@ -398,6 +395,7 @@ def hpc_load(self, checkpoint_path: str) -> None: def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = "ckpt_") -> Optional[int]: """List up files in `dir_path` with `name_key`, then yield maximum suffix number. + Args: dir_path: path of directory which may contain files whose name include `name_key` name_key: file name prefix diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 2f71e4627a968..8bd1a3f3207f9 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -76,7 +76,7 @@ def on_init_start( self.determine_data_use_amount(self.trainer.overfit_batches) def determine_data_use_amount(self, overfit_batches: float) -> None: - """Use less data for debugging purposes""" + """Use less data for debugging purposes.""" if overfit_batches > 0: self.trainer.limit_train_batches = overfit_batches self.trainer.limit_val_batches = overfit_batches diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index d3084e3e4ece5..4d130ca8e720b 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -19,10 +19,8 @@ def _defaults_from_env_vars(fn: Callable) -> Callable: - """ - Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which - input arguments should be moved automatically to the correct device. - """ + """Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should + be moved automatically to the correct device.""" @wraps(fn) def insert_env_defaults(self, *args, **kwargs): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 8c8546f0b61e3..a928122a2053a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -98,7 +98,7 @@ class _LogOptions(TypedDict): @classmethod def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None: - """Check if the given function name is allowed to log""" + """Check if the given function name is allowed to log.""" if fx_name not in cls.functions: raise RuntimeError( f"Logging inside `{fx_name}` is not implemented." diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index b9371a83a71c6..c9daae80e5c45 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -79,9 +79,8 @@ def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[Lig self.trainer.logger = logger def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None: - """Logs the metric dict passed in. - If `step` parameter is None and `step` key is presented is metrics, - uses metrics["step"] as a step + """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses + metrics["step"] as a step. Args: metrics: Metric values diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index ba8278592df0e..71dd88aee5c7b 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -294,8 +294,7 @@ def to(self, *args: Any, **kwargs: Any) -> "ResultMetric": class ResultMetricCollection(dict): - """ - Dict wrapper for easy access to metadata. + """Dict wrapper for easy access to metadata. All of the leaf items should be instances of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` @@ -378,10 +377,8 @@ def batch_size(self, value: int) -> None: @property def minimize(self) -> Optional[torch.Tensor]: - """ - The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss - will be saved as the ``minimize`` attribute. - """ + """The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss will be saved as the + ``minimize`` attribute.""" return self._minimize @minimize.setter @@ -480,7 +477,10 @@ def log( self.update_metrics(key, value) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: - """Create one ResultMetric object per value. Value can be provided as a nested collection""" + """Create one ResultMetric object per value. + + Value can be provided as a nested collection + """ def fn(v: _IN_METRIC) -> ResultMetric: metric = ResultMetric(meta, isinstance(v, torch.Tensor)) @@ -573,8 +573,7 @@ def any_tensor(_: Any) -> None: return metrics def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None: - """ - Reset the result collection + """Reset the result collection. Args: metrics: If True, only ``torchmetrics.Metric`` results are reset, diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 26c295b89d03a..1dd5c1c6fec36 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -282,8 +282,8 @@ def _get_distributed_sampler( return sampler def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None: - """Resets the train dataloader and initialises required variables - (number of batches, when to validate, etc.). + """Resets the train dataloader and initialises required variables (number of batches, when to validate, + etc.). Args: model: The `LightningModule` if calling this outside of the trainer scope. @@ -501,8 +501,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) ) def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None: - """ - Resets train and val dataloaders if none are attached to the trainer. + """Resets train and val dataloaders if none are attached to the trainer. The val dataloader must be initialized before training loop starts, as the training loop inspects the val dataloader to determine whether to run the evaluation loop. @@ -533,9 +532,8 @@ def request_dataloader( @staticmethod def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """ - Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. - """ + """Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is + enabled.""" dataloader.collate_fn = partial( _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn ) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 0701eab390333..544b18b8975d8 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -116,7 +116,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer): def configure_schedulers( self, schedulers: list, monitor: Optional[str], is_manual_optimization: bool ) -> List[Dict[str, Any]]: - """Convert each scheduler into dict structure with relevant information""" + """Convert each scheduler into dict structure with relevant information.""" lr_schedulers = [] default_config = _get_default_scheduler_config() for scheduler in schedulers: @@ -179,9 +179,8 @@ def configure_schedulers( class _MockOptimizer(Optimizer): - """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` - is returned from `configure_optimizers`. - """ + """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from + `configure_optimizers`.""" def __init__(self): super().__init__([torch.zeros(1)], {}) diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index fd7034f63e1da..4f53968125d4b 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -17,9 +17,7 @@ @dataclass class BaseProgress: - """ - Mixin that implements state-loading utilities for dataclasses. - """ + """Mixin that implements state-loading utilities for dataclasses.""" def state_dict(self) -> dict: return asdict(self) @@ -36,8 +34,7 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress": @dataclass class ReadyCompletedTracker(BaseProgress): - """ - Track an event's progress. + """Track an event's progress. Args: ready: Intended to track the number of events ready to start. @@ -55,19 +52,17 @@ def reset(self) -> None: self.completed = 0 def reset_on_restart(self) -> None: - """ - Reset the progress on restart. + """Reset the progress on restart. - If there is a failure before all attributes are increased, - restore the attributes to the last fully completed value. + If there is a failure before all attributes are increased, restore the attributes to the last fully completed + value. """ self.ready = self.completed @dataclass class StartedTracker(ReadyCompletedTracker): - """ - Track an event's progress. + """Track an event's progress. Args: ready: Intended to track the number of events ready to start. @@ -90,8 +85,7 @@ def reset_on_restart(self) -> None: @dataclass class ProcessedTracker(StartedTracker): - """ - Track an event's progress. + """Track an event's progress. Args: ready: Intended to track the number of events ready to start. @@ -116,8 +110,7 @@ def reset_on_restart(self) -> None: @dataclass class Progress(BaseProgress): - """ - Track aggregated and current progress. + """Track aggregated and current progress. Args: total: Intended to track the total progress of an event. @@ -163,9 +156,8 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass class DataLoaderProgress(Progress): - """ - Tracks the dataloader progress - These counters are local to a trainer rank. By default, they are not globally synced across all ranks. + """Tracks the dataloader progress These counters are local to a trainer rank. By default, they are not globally + synced across all ranks. Args: total: Tracks the total dataloader progress. @@ -178,9 +170,8 @@ class DataLoaderProgress(Progress): @dataclass class SchedulerProgress(Progress): - """ - Tracks the scheduler progress. - These counters are local to a trainer rank. By default, they are not globally synced across all ranks. + """Tracks the scheduler progress. These counters are local to a trainer rank. By default, they are not globally + synced across all ranks. Args: total: Tracks the total scheduler progress. @@ -193,8 +184,7 @@ class SchedulerProgress(Progress): @dataclass class OptimizerProgress(BaseProgress): - """ - Track optimizer progress. + """Track optimizer progress. Args: step: Tracks ``optimizer.step`` calls. @@ -215,8 +205,7 @@ def load_state_dict(self, state_dict: dict) -> None: @dataclass class OptimizationProgress(BaseProgress): - """ - Track optimization progress. + """Track optimization progress. Args: optimizer: Tracks optimizer progress. diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index a693653aa9ffa..22d8ee01d3871 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -220,8 +220,8 @@ def gpus(self) -> Optional[Union[List[int], str, int]]: @property def model(self) -> torch.nn.Module: - """ - The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel. + To access the pure LightningModule, use :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead. """ @@ -229,9 +229,8 @@ def model(self) -> torch.nn.Module: @model.setter def model(self, model: torch.nn.Module) -> None: - """ - Setter for the model, pass-through to accelerator and plugin where the model reference is stored. - Used by the Tuner to reset the state of Trainer and Accelerator. + """Setter for the model, pass-through to accelerator and plugin where the model reference is stored. Used + by the Tuner to reset the state of Trainer and Accelerator. Args: model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending @@ -347,8 +346,8 @@ def enable_validation(self) -> bool: @property def default_root_dir(self) -> str: - """ - The default location to save artifacts of loggers, checkpoints etc. + """The default location to save artifacts of loggers, checkpoints etc. + It is used as a fallback if logger or checkpoint callback do not define specific save paths. """ if get_filesystem(self._default_root_dir).protocol == "file": @@ -367,44 +366,34 @@ def weights_save_path(self) -> str: @property def early_stopping_callback(self) -> Optional[EarlyStopping]: - """ - The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` - callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. - """ + """The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the + Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.early_stopping_callbacks return callbacks[0] if len(callbacks) > 0 else None @property def early_stopping_callbacks(self) -> List[EarlyStopping]: - """ - A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` - found in the Trainer.callbacks list. - """ + """A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in + the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, EarlyStopping)] @property def prediction_writer_callbacks(self) -> List[BasePredictionWriter]: - """ - A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter` - found in the Trainer.callbacks list. - """ + """A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter` + found in the Trainer.callbacks list.""" return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)] @property def checkpoint_callback(self) -> Optional[ModelCheckpoint]: - """ - The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` - callback in the Trainer.callbacks list, or ``None`` if it doesn't exist. - """ + """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the + Trainer.callbacks list, or ``None`` if it doesn't exist.""" callbacks = self.checkpoint_callbacks return callbacks[0] if len(callbacks) > 0 else None @property def checkpoint_callbacks(self) -> List[ModelCheckpoint]: - """ - A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` - found in the Trainer.callbacks list. - """ + """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found + in the Trainer.callbacks list.""" return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] @property @@ -564,8 +553,9 @@ def fit_loop(self) -> FitLoop: @fit_loop.setter def fit_loop(self, loop: FitLoop): - """ - Attach a custom fit loop to this Trainer. It will run with + """Attach a custom fit loop to this Trainer. + + It will run with :meth:`~pytorch_lighting.trainer.trainer.Trainer.fit`. """ loop.trainer = self @@ -577,8 +567,9 @@ def validate_loop(self) -> EvaluationLoop: @validate_loop.setter def validate_loop(self, loop: EvaluationLoop): - """ - Attach a custom validation loop to this Trainer. It will run with + """Attach a custom validation loop to this Trainer. + + It will run with :meth:`~pytorch_lighting.trainer.trainer.Trainer.validate`. Note that this loop is different from the one running during training inside the :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call. """ @@ -591,8 +582,9 @@ def test_loop(self) -> EvaluationLoop: @test_loop.setter def test_loop(self, loop: EvaluationLoop): - """ - Attach a custom test loop to this Trainer. It will run with + """Attach a custom test loop to this Trainer. + + It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`. """ loop.trainer = self @@ -604,8 +596,9 @@ def predict_loop(self) -> PredictionLoop: @predict_loop.setter def predict_loop(self, loop: PredictionLoop): - """ - Attach a custom prediction loop to this Trainer. It will run with + """Attach a custom prediction loop to this Trainer. + + It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. """ loop.trainer = self diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index b6d52d62dc0f6..7f83dd76156ab 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -45,8 +45,7 @@ class TrainerFn(LightningEnum): @property def _setup_fn(self) -> "TrainerFn": - """ - ``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders. + """``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders. This is used for the ``setup()`` and ``teardown()`` hooks """ @@ -54,8 +53,7 @@ def _setup_fn(self) -> "TrainerFn": class RunningStage(LightningEnum): - """ - Enum for the current running stage. + """Enum for the current running stage. This stage complements :class:`TrainerFn` by specifying the current running stage for each function. More than one running stage value can be set while a :class:`TrainerFn` is running: @@ -89,7 +87,7 @@ def dataloader_prefix(self) -> Optional[str]: @dataclass class TrainerState: - """Dataclass to encapsulate the current :class:`~pytorch_lightning.trainer.trainer.Trainer` state""" + """Dataclass to encapsulate the current :class:`~pytorch_lightning.trainer.trainer.Trainer` state.""" status: TrainerStatus = TrainerStatus.INITIALIZING fn: Optional[TrainerFn] = None diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6909214fc95bc..0fa9e8c219df0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -37,8 +37,7 @@ class TensorRunningAccum: - """Tracks a running accumulation values (min, max, mean) without graph - references. + """Tracks a running accumulation values (min, max, mean) without graph references. Examples: >>> accum = TensorRunningAccum(5) @@ -117,9 +116,9 @@ def _agg_memory(self, how: str): class SharedCycleIteratorState: """A state shared between all CylceIterators in a CombinedLoader. - With a shared state, the iterators can decide to terminate based on the state of all others. - If the mode is *max_size_cycle*, all iterators need to have finished before the combined loading is considered - finished, and otherwise any iterator finishing early will lead to all iterators ending early. + With a shared state, the iterators can decide to terminate based on the state of all others. If the mode is + *max_size_cycle*, all iterators need to have finished before the combined loading is considered finished, and + otherwise any iterator finishing early will lead to all iterators ending early. """ mode: str = "max_size_cycle" @@ -143,9 +142,7 @@ def done(self) -> bool: class CycleIterator: - """ - Iterator for restarting a dataloader if it runs out of samples - """ + """Iterator for restarting a dataloader if it runs out of samples.""" def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None): """ @@ -173,8 +170,7 @@ def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycle self.state = state def __iter__(self) -> Any: - """ - Creates the internal iterator and returns self + """Creates the internal iterator and returns self. Returns: CycleIterator: self @@ -226,9 +222,7 @@ def __len__(self) -> Union[int, float]: class CombinedDataset: - """ - Combine multiple datasets and compute their statistics - """ + """Combine multiple datasets and compute their statistics.""" COMPUTE_FUNCS = {"min_size": min, "max_size_cycle": max} @@ -257,8 +251,7 @@ def min_len(self) -> Union[int, float]: return self._calc_num_data(self.datasets, "min_size") def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]: - """ - Compute the length of `CombinedDataset` according to the `mode`. + """Compute the length of `CombinedDataset` according to the `mode`. Args: datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset, @@ -314,12 +307,10 @@ def __len__(self) -> int: class CombinedLoader: - """ - Combines different dataloaders and allows sampling in parallel. - Supported modes are 'min_size', which raises StopIteration after the shortest loader - (the one with the lowest number of batches) is done, and 'max_size_cycle` which raises - StopIteration after the longest loader (the one with most batches) is done, while cycling - through the shorter loaders. + """Combines different dataloaders and allows sampling in parallel. Supported modes are 'min_size', which raises + StopIteration after the shortest loader (the one with the lowest number of batches) is done, and + 'max_size_cycle` which raises StopIteration after the longest loader (the one with most batches) is done, while + cycling through the shorter loaders. Examples: >>> loaders = {'a': torch.utils.data.DataLoader(range(6), batch_size=4), @@ -375,8 +366,7 @@ def _state_dict_fn(dataloader: DataLoader, iterator: Optional[Iterator], has_com return {} def state_dict(self, has_completed: bool = False) -> Dict: - """ - The state dict includes all states from wrapped dataloaders and their samplers through the + """The state dict includes all states from wrapped dataloaders and their samplers through the ``CaptureIterableDataset`` and fast-forward samplers. Args: @@ -476,8 +466,7 @@ def sampler(self) -> Union[Iterable, Sequence, Mapping]: return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None) def _wrap_loaders_max_size_cycle(self) -> Any: - """ - Wraps all loaders to make sure they are cycled until the longest loader is exhausted + """Wraps all loaders to make sure they are cycled until the longest loader is exhausted. Returns: the wrapped loaders @@ -496,9 +485,7 @@ def _wrap_loaders_max_size_cycle(self) -> Any: state.reset() def __iter__(self) -> Any: - """ - Create and return an iterator, `CombinedLoaderIterator`, for the combined loader. - """ + """Create and return an iterator, `CombinedLoaderIterator`, for the combined loader.""" # prevent `NotImplementedError` from PyTorch: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541 @@ -514,8 +501,7 @@ def __getstate__patch__(*_): @staticmethod def _calc_num_batches(loaders: Any) -> Union[int, float]: - """ - Compute the length (aka the number of batches) of `CombinedLoader`. + """Compute the length (aka the number of batches) of `CombinedLoader`. Args: loaders: a collections of loaders. @@ -534,9 +520,7 @@ def __len__(self) -> int: class CombinedLoaderIterator: - """ - Custom Iterator returning data from multple loaders, and allows sampling in parallel - """ + """Custom Iterator returning data from multple loaders, and allows sampling in parallel.""" def __init__(self, loaders: Any): """ @@ -548,9 +532,7 @@ def __init__(self, loaders: Any): @property def loader_iters(self) -> Any: - """ - Get the `_loader_iters` and create one if it is None. - """ + """Get the `_loader_iters` and create one if it is None.""" if self._loader_iters is None: self._loader_iters = self.create_loader_iters(self.loaders) @@ -560,8 +542,7 @@ def __iter__(self) -> Any: return self def __next__(self) -> Any: - """ - Fetches the next batch from multiple data loaders + """Fetches the next batch from multiple data loaders. Returns: a collections of batch data @@ -570,8 +551,7 @@ def __next__(self) -> Any: @staticmethod def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: - """ - Return the batch of data from multiple iterators. + """Return the batch of data from multiple iterators. Args: loader_iters: a collections of iterators @@ -585,8 +565,7 @@ def request_next_batch(loader_iters: Union[Iterator, Sequence, Mapping]) -> Any: def create_loader_iters( loaders: Union[Any, Iterator, Sequence, Mapping] ) -> Union[Any, Iterator, Sequence, Mapping]: - """ - Create and return a collection of iterators from loaders. + """Create and return a collection of iterators from loaders. Args: loaders: a collections of loaders diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py index c048ce0a42dd9..7e1181ebdb58c 100644 --- a/pytorch_lightning/tuner/batch_size_scaling.py +++ b/pytorch_lightning/tuner/batch_size_scaling.py @@ -171,9 +171,11 @@ def _run_power_scaling( def _run_binsearch_scaling( trainer: "pl.Trainer", model: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int ) -> int: - """Batch scaling mode where the size is initially is doubled at each iteration - until an OOM error is encountered. Hereafter, the batch size is further - refined using a binary search""" + """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is + encountered. + + Hereafter, the batch size is further refined using a binary search + """ low = 1 high = None count = 0 diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 46b44a7a26846..e44f3168633db 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -98,10 +98,8 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): self._total_batch_idx = 0 # for debug purpose def _exchange_scheduler(self, configure_optimizers: Callable): - """Decorate configure_optimizers methods such that it returns the users - originally specified optimizer together with a new scheduler that - that takes care of the learning rate search. - """ + """Decorate configure_optimizers methods such that it returns the users originally specified optimizer + together with a new scheduler that that takes care of the learning rate search.""" @wraps(configure_optimizers) def func(): @@ -171,14 +169,13 @@ def plot(self, suggest: bool = False, show: bool = False): return fig def suggestion(self, skip_begin: int = 10, skip_end: int = 1): - """This will propose a suggestion for choice of initial learning rate - as the point with the steepest negative gradient. + """This will propose a suggestion for choice of initial learning rate as the point with the steepest + negative gradient. Returns: lr: suggested initial learning rate to use skip_begin: how many samples to skip in the beginning. Prevent too naive estimates skip_end: how many samples to skip in the end. Prevent too optimistic estimates - """ try: loss = np.array(self.results["loss"][skip_begin:-skip_end]) @@ -302,9 +299,8 @@ def __lr_finder_restore_params(trainer, model): class _LRCallback(Callback): - """Special callback used by the learning rate finder. This callbacks log - the learning rate before each batch and log the corresponding loss after - each batch. + """Special callback used by the learning rate finder. This callbacks log the learning rate before each batch + and log the corresponding loss after each batch. Args: num_training: number of iterations done by the learning rate finder @@ -316,7 +312,6 @@ class _LRCallback(Callback): beta: smoothing value, the loss being logged is a running average of loss values logged until now. ``beta`` controls the forget rate i.e. if ``beta=0`` all past information is ignored. - """ def __init__( @@ -337,7 +332,7 @@ def __init__( self.progress_bar = None def on_batch_start(self, trainer, pl_module): - """Called before each training batch, logs the lr that will be used""" + """Called before each training batch, logs the lr that will be used.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return @@ -347,7 +342,7 @@ def on_batch_start(self, trainer, pl_module): self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0]) def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - """Called when the training batch ends, logs the calculated loss""" + """Called when the training batch ends, logs the calculated loss.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return @@ -376,8 +371,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data class _LinearLR(_LRScheduler): - """ - Linearly increases the learning rate between two boundaries over a number of iterations. + """Linearly increases the learning rate between two boundaries over a number of iterations. Args: @@ -415,8 +409,7 @@ def lr(self): class _ExponentialLR(_LRScheduler): - """Exponentially increases the learning rate between two boundaries - over a number of iterations. + """Exponentially increases the learning rate between two boundaries over a number of iterations. Arguments: diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py index 0a16eccf0c846..88c53f9c28103 100644 --- a/pytorch_lightning/tuner/tuning.py +++ b/pytorch_lightning/tuner/tuning.py @@ -21,7 +21,7 @@ class Tuner: - """Tuner class to tune your model""" + """Tuner class to tune your model.""" def __init__(self, trainer: "pl.Trainer") -> None: self.trainer = trainer @@ -57,7 +57,7 @@ def _tune( return result def _run(self, *args: Any, **kwargs: Any) -> None: - """`_run` wrapper to set the proper state during tuning, as this can be called multiple times""" + """`_run` wrapper to set the proper state during tuning, as this can be called multiple times.""" self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED` self.trainer.training = True self.trainer._run(*args, **kwargs) @@ -76,9 +76,8 @@ def scale_batch_size( batch_arg_name: str = "batch_size", train_dataloader=None, # TODO: remove with 1.6 ) -> Optional[int]: - """ - Iteratively try to find the largest batch size for a given model - that does not give an out of memory (OOM) error. + """Iteratively try to find the largest batch size for a given model that does not give an out of memory + (OOM) error. Args: model: Model to tune. @@ -146,9 +145,8 @@ def lr_find( update_attr: bool = False, train_dataloader=None, # TODO: remove with 1.6 ) -> Optional[_LRFinder]: - """ - Enables the user to do a range test of good initial learning rates, - to reduce the amount of guesswork in picking a good starting learning rate. + """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in + picking a good starting learning rate. Args: model: Model to tune. diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index ed9bb930001d8..66c7721599a7b 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""General utilities""" +"""General utilities.""" import numpy diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index c31e7dc13caba..2758262653ba7 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -72,8 +72,7 @@ def apply_to_collection( include_none: bool = True, **kwargs: Any, ) -> Any: - """ - Recursively applies a function to all elements of a certain dtype. + """Recursively applies a function to all elements of a certain dtype. Args: data: the collection to apply the function to @@ -146,8 +145,7 @@ def apply_to_collections( wrong_dtype: Optional[Union[type, Tuple[type]]] = None, **kwargs: Any, ) -> Any: - """ - Zips two collections and applies a function to their items of a certain dtype. + """Zips two collections and applies a function to their items of a certain dtype. Args: data1: The first collection @@ -227,9 +225,8 @@ def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any: - """ - Transfers a collection of data to the given device. Any object that defines a method - ``to(device)`` will be moved and all other objects in the collection will be left untouched. + """Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be + moved and all other objects in the collection will be left untouched. Args: batch: A tensor or collection of tensors or anything that has a method `.to(...)`. diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index b9757715a3267..d457bbd587209 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -34,8 +34,8 @@ def parse_argparser(cls, args: "ArgumentParser") -> Any: def from_argparse_args( cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any ) -> ParseArgparserDataType: - """Create an instance from CLI arguments. - Eventually use varibles from OS environement which are defined as "PL__" + """Create an instance from CLI arguments. Eventually use varibles from OS environement which are defined as + "PL__". Args: cls: Lightning class diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index eaf70f3f9b31f..d67c7d74231e3 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -30,9 +30,10 @@ class FastForwardSampler(Sampler): - """ - This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations - performed during an epoch. It maintains a state, saved with :meth:`state_dict`, that can be reloaded with + """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations + performed during an epoch. + + It maintains a state, saved with :meth:`state_dict`, that can be reloaded with :meth:`load_state_dict`. If the sampler is used in a multiprocessing context, the ``FastForwardSampler`` will record the state of the current worker. When reloading, the ``FastForwardSampler`` will "fast-forward" the wrapped sampler by iterating through all the @@ -54,7 +55,9 @@ def __getattr__(self, key: str) -> Any: return getattr(self._sampler, key, None) def setup(self, dataloader_batch_size: Optional[int] = None) -> None: - """Setup the ``FastForwardSampler``. This is required only when the provided dataset subclassed + """Setup the ``FastForwardSampler``. + + This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. """ self._dataloader_batch_size = dataloader_batch_size @@ -96,14 +99,17 @@ def __len__(self) -> int: return len(self._sampler) def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]: - """Returns the state of the sampler in the current worker. The worker id indexes the state dict.""" + """Returns the state of the sampler in the current worker. + + The worker id indexes the state dict. + """ return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}} def load_state_dict(self, state_dict: Dict[int, Any]) -> None: - """ - Loads the saved state for the wrapped sampler. - If the ``state_dict`` contains multiple states, it means there were multiple workers. - The state will be cached and fully reloaded (fast-forward) the first time :meth:`__iter__` is called. + """Loads the saved state for the wrapped sampler. + + If the ``state_dict`` contains multiple states, it means there were multiple workers. The state will be cached + and fully reloaded (fast-forward) the first time :meth:`__iter__` is called. """ # as workers aren't available, the ``state_dict``` is cached until workers are made available. state_dict = deepcopy(state_dict) @@ -111,10 +117,10 @@ def load_state_dict(self, state_dict: Dict[int, Any]) -> None: self.restarting = True def _compute_current_iteration(self, num_batches_processed: Optional[int] = None) -> int: - """ - This function is used to compute the effective iteration. - As DataLoader can perform ``prefecthing`` or training can fail while processing a batch, - the current iteration needs to be computed using the ``num_batches_processed`` processed information. + """This function is used to compute the effective iteration. + + As DataLoader can perform ``prefecthing`` or training can fail while processing a batch, the current iteration + needs to be computed using the ``num_batches_processed`` processed information. """ if num_batches_processed is not None: current_iteration = num_batches_processed @@ -148,9 +154,11 @@ def from_state_dict(cls, state_dict) -> "IteratorState": @dataclass class MergedIteratorState: - """This class is used to hold the current iterator state and lives on the iterator. It holds the current merged - states from all worker processes. Once an iterator advances, it can store updates of the worker states in this - merged iterator state.""" + """This class is used to hold the current iterator state and lives on the iterator. + + It holds the current merged states from all worker processes. Once an iterator advances, it can store updates of the + worker states in this merged iterator state. + """ state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) latest_worker_id: int = 0 @@ -259,12 +267,11 @@ def set_rng_states(rng_state_dict: Dict[str, Any]) -> None: class CaptureIterableDataset(IterableDataset): - """ - The ``CaptureIterableDataset`` is used to wrap an :class:`torch.utils.data.IterableDataset`. - On ``__iter__`` function call, the ``CaptureIterableDataset`` will wrap the wrapped dataset - generators into ``FastForwardSampler`` to keep track of progress. - On ``__next__`` function call, the ``CaptureIterableDataset`` will return a dictionary containing - user data and metadata containing the ``FastForwardSampler`` samplers state_dict. + """The ``CaptureIterableDataset`` is used to wrap an :class:`torch.utils.data.IterableDataset`. + + On ``__iter__`` function call, the ``CaptureIterableDataset`` will wrap the wrapped dataset generators into + ``FastForwardSampler`` to keep track of progress. On ``__next__`` function call, the ``CaptureIterableDataset`` will + return a dictionary containing user data and metadata containing the ``FastForwardSampler`` samplers state_dict. """ def __init__(self, dataset: IterableDataset) -> None: @@ -354,9 +361,7 @@ def __next__(self) -> Dict[str, Any]: def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardSampler]: - """ - If the ``DataLoader`` is wrapping a mapping based Dataset, return the ``FastForwardSampler``. - """ + """If the ``DataLoader`` is wrapping a mapping based Dataset, return the ``FastForwardSampler``.""" if isinstance(dataloader.sampler, FastForwardSampler): return dataloader.sampler @@ -365,9 +370,8 @@ def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardS def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> Iterator: - """ - This function is used to cycle back the DataLoader ``_MultiProcessingDataLoaderIter`` - workers and call the reset function. + """This function is used to cycle back the DataLoader ``_MultiProcessingDataLoaderIter`` workers and call the + reset function. Returns: iterator: Return the iterator generated from the provided ``DataLoader``. @@ -404,9 +408,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str def _dataloader_to_state_dict( dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None ) -> List[Dict[str, Any]]: - """ - Convert a dataloader to its associated state dict - """ + """Convert a dataloader to its associated state dict.""" out = {} if iterator is not None: out.update(_find_current_worker(iterator)) @@ -419,9 +421,7 @@ def _dataloader_to_state_dict( def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader: - """ - Reload ``DataLoader`` fast-forward sampler state dict. - """ + """Reload ``DataLoader`` fast-forward sampler state dict.""" fast_forward_sampler = _find_fast_forward_samplers(dataloader) if isinstance(fast_forward_sampler, Sampler): @@ -452,9 +452,9 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: - """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` - used in the worker processes. This function gets executed within the worker processes. - The structure will be: + """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or + :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker + processes. The structure will be: .. code-block:: python @@ -483,8 +483,7 @@ def patch_dataloader_iterator( data_fetcher: "pl.utilities.fetching.DataFetcher", num_batches_fetched: int = 0, ) -> None: - """ - Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is + """Patches the iterator of a PyTorch dataloader by injecting logic for fault-tolerant training when it is necessary to remove the sampler state dict from provided data batch. The custom data has this format: diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 2f9b23cbfe05a..7a9546ddf74bd 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -36,10 +36,10 @@ class LightningArgumentParser(ArgumentParser): - """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" + """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: - """Initialize argument parser that supports configuration file input + """Initialize argument parser that supports configuration file input. For full details of accepted arguments see `ArgumentParser.__init__ `_. @@ -68,8 +68,7 @@ def add_lightning_class_args( nested_key: str, subclass_mode: bool = False, ) -> List[str]: - """ - Adds arguments from a lightning class to a nested key of the parser + """Adds arguments from a lightning class to a nested key of the parser. Args: lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. @@ -103,8 +102,7 @@ def add_optimizer_args( nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: - """ - Adds arguments from an optimizer class to a nested key of the parser + """Adds arguments from an optimizer class to a nested key of the parser. Args: optimizer_class: Any subclass of torch.optim.Optimizer. @@ -128,8 +126,7 @@ def add_lr_scheduler_args( nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: - """ - Adds arguments from a learning rate scheduler class to a nested key of the parser + """Adds arguments from a learning rate scheduler class to a nested key of the parser. Args: lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. @@ -149,7 +146,7 @@ def add_lr_scheduler_args( class SaveConfigCallback(Callback): - """Saves a LightningCLI config to the log_dir when training starts + """Saves a LightningCLI config to the log_dir when training starts. Raises: RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run @@ -193,7 +190,7 @@ def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: class LightningCLI: - """Implementation of a configurable command line tool for pytorch-lightning""" + """Implementation of a configurable command line tool for pytorch-lightning.""" def __init__( self, @@ -213,9 +210,8 @@ def __init__( subclass_mode_data: bool = False, run: bool = True, ) -> None: - """ - Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are - called / instantiated using a parsed configuration file and / or command line args. + """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which + are called / instantiated using a parsed configuration file and / or command line args. Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full configuration yaml would be parsed from ``PL_CONFIG`` if set. @@ -333,8 +329,7 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None: self.link_optimizers_and_lr_schedulers(parser) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: - """ - Implement to add extra arguments to the parser or link arguments. + """Implement to add extra arguments to the parser or link arguments. Args: parser: The parser object to which arguments can be added @@ -404,8 +399,7 @@ def instantiate_classes(self) -> None: self.trainer = self.instantiate_trainer() def instantiate_trainer(self, **kwargs: Any) -> Trainer: - """ - Instantiates the trainer. + """Instantiates the trainer. Args: kwargs: Any custom trainer arguments. @@ -438,11 +432,10 @@ def _parser(self, subcommand: Optional[str]) -> ArgumentParser: return action_subcommand._name_parser_map[subcommand] def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: - """ - Adds to the model an automatically generated ``configure_optimizers`` method. + """Adds to the model an automatically generated ``configure_optimizers`` method. - If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', - then a `configure_optimizers` method is automatically implemented in the model class. + If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC', then a + `configure_optimizers` method is automatically implemented in the model class. """ parser = self._parser(subcommand) optimizers_and_lr_schedulers = parser.optimizers_and_lr_schedulers diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py index 86dcdc3b68a21..4669cc2020b16 100644 --- a/pytorch_lightning/utilities/data.py +++ b/pytorch_lightning/utilities/data.py @@ -23,8 +23,7 @@ def extract_batch_size(batch: BType) -> int: - """ - Recursively unpack a batch to find a torch.Tensor. + """Recursively unpack a batch to find a torch.Tensor. Returns: ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable. @@ -48,9 +47,8 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: def has_len(dataloader: DataLoader) -> bool: - """ - Checks if a given Dataloader has ``__len__`` method implemented i.e. if - it is a finite dataloader or infinite dataloader. + """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or + infinite dataloader. Raises: ValueError: @@ -78,7 +76,10 @@ def has_len(dataloader: DataLoader) -> bool: def get_len(dataloader: DataLoader) -> Union[int, float]: - """Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf').""" + """Return the length of the given DataLoader. + + If ``__len__`` method is not implemented, return float('inf'). + """ if has_len(dataloader): return len(dataloader) diff --git a/pytorch_lightning/utilities/deepspeed.py b/pytorch_lightning/utilities/deepspeed.py index bad90fd14fb33..58002c161f571 100644 --- a/pytorch_lightning/utilities/deepspeed.py +++ b/pytorch_lightning/utilities/deepspeed.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py +"""Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py. This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets copied into the top level checkpoint dir, so the user can easily do the conversion at any point in diff --git a/pytorch_lightning/utilities/device_parser.py b/pytorch_lightning/utilities/device_parser.py index 11d977416bf55..6656b9765ba00 100644 --- a/pytorch_lightning/utilities/device_parser.py +++ b/pytorch_lightning/utilities/device_parser.py @@ -132,9 +132,8 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[in def _sanitize_gpu_ids(gpus: List[int]) -> List[int]: - """ - Checks that each of the GPUs in the list is actually available. - Raises a MisconfigurationException if any of the GPUs is not available. + """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of + the GPUs is not available. Args: gpus: list of ints corresponding to GPU indices @@ -178,8 +177,7 @@ def _get_all_available_gpus() -> List[int]: def _check_unique(device_ids: List[int]) -> None: - """ - Checks that the device_ids are unique. + """Checks that the device_ids are unique. Args: device_ids: list of ints corresponding to gpus indices @@ -193,9 +191,8 @@ def _check_unique(device_ids: List[int]) -> None: def _check_data_type(device_ids: Any) -> None: - """ - Checks that the device_ids argument is one of: None, Int, String or List. - Raises a MisconfigurationException otherwise. + """Checks that the device_ids argument is one of: None, Int, String or List. Raises a MisconfigurationException + otherwise. Args: device_ids: gpus/tpu_cores parameter as passed to the Trainer diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 73de501fc9911..3db371b252490 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -109,9 +109,7 @@ def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]: - """ - Function to gather all tensors from several ddp processes onto a list that - is broadcasted to all processes + """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. Args: result: the value to sync @@ -164,8 +162,7 @@ def sync_ddp_if_available( def sync_ddp( result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: - """ - Function to reduce the tensors from several ddp processes to one master process + """Function to reduce the tensors from several ddp processes to one master process. Args: result: the value to sync and reduce (typically tensor or number) @@ -228,8 +225,7 @@ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: def all_gather_ddp_if_available( tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False ) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes + """Function to gather a tensor from several distributed processes. Args: tensor: tensor of shape (batch, ...) @@ -254,9 +250,7 @@ def register_ddp_comm_hook( ddp_comm_hook: Optional[Callable] = None, ddp_comm_wrapper: Optional[Callable] = None, ) -> None: - """ - Function to register communication hook for DDP model - https://pytorch.org/docs/master/ddp_comm_hooks.html + """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html. Args: model: @@ -373,9 +367,8 @@ def init_ddp_connection( world_size: Optional[int] = None, **kwargs: Any, ) -> None: - """ - Utility function to initialize DDP connection by setting env variables - and initiliazing the distributed process group. + """Utility function to initialize DDP connection by setting env variables and initiliazing the distributed + process group. Args: cluster_environment: ``ClusterEnvironment`` instance diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 73fafabe8f5d9..50c52fd57a7eb 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Enumerated utilities""" +"""Enumerated utilities.""" from enum import Enum from typing import List, Optional, Union @@ -85,7 +85,7 @@ class DistributedType(LightningEnum): @staticmethod def interactive_compatible_types() -> List["DistributedType"]: - """Returns a list containing interactive compatible DistributeTypes""" + """Returns a list containing interactive compatible DistributeTypes.""" return [ DistributedType.DP, DistributedType.DDP_SPAWN, @@ -94,7 +94,7 @@ def interactive_compatible_types() -> List["DistributedType"]: ] def is_interactive_compatible(self) -> bool: - """Returns whether self is interactive compatible""" + """Returns whether self is interactive compatible.""" return self in DistributedType.interactive_compatible_types() DP = "dp" diff --git a/pytorch_lightning/utilities/exceptions.py b/pytorch_lightning/utilities/exceptions.py index bf5258f4f5f36..164b4c7c8e6e1 100644 --- a/pytorch_lightning/utilities/exceptions.py +++ b/pytorch_lightning/utilities/exceptions.py @@ -14,12 +14,8 @@ class MisconfigurationException(Exception): - """ - Exception used to inform users of mis-use with PyTorch Lightning - """ + """Exception used to inform users of mis-use with PyTorch Lightning.""" class DeadlockDetectedException(Exception): - """ - Exception used when a deadlock has been detected and processes are being killed - """ + """Exception used when a deadlock has been detected and processes are being killed.""" diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index d37cd3a9c1e6f..f0f09401ab47e 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -37,9 +37,8 @@ class AbstractDataFetcher(ABC): - """ - This base class should be used to implement a fault tolerant ``DataFetcher``. - It is required to override the ``fetching_function`` with fetching logic. + """This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the + ``fetching_function`` with fetching logic. Example:: @@ -215,9 +214,8 @@ def teardown(self) -> None: class DataFetcher(AbstractDataFetcher): - """ - This class is used to control batch fetching flow. - By default, the ``fetching_function`` will pre-fetch a batch in advance to detect the end of the iteration. + """This class is used to control batch fetching flow. By default, the ``fetching_function`` will pre-fetch a + batch in advance to detect the end of the iteration. Args: prefetch_batches: Number of batches to be pre-fetched. Lightning will pre-fetch @@ -235,14 +233,14 @@ def __init__( @contextmanager def fetching_context(self): - """Hook to override to add context logic around batch fetching""" + """Hook to override to add context logic around batch fetching.""" yield def on_fetch_start(self) -> None: - """Hook to override to handle the logic before fetching a batch""" + """Hook to override to handle the logic before fetching a batch.""" def on_fetch_end(self, batch, on_fetch_start_output: Optional[Any] = None) -> None: - """Hook to extend which handles the logic after fetching a batch""" + """Hook to extend which handles the logic after fetching a batch.""" if self.store_on_device: batch = self.move_data_to_device(batch) self.append_batch(batch) @@ -321,9 +319,8 @@ def move_data_to_device(self, batch: Any) -> Any: class InterBatchParallelDataFetcher(DataFetcher): - """ - This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy - of input batches behind computationally intensive operations. + """This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy of + input batches behind computationally intensive operations. code-block:: @@ -351,7 +348,7 @@ def __init__( @contextmanager def fetching_context(self): - """Wrap the batch fetching logic under a cuda stream""" + """Wrap the batch fetching logic under a cuda stream.""" with torch.cuda.stream(self.cuda_stream): yield @@ -375,10 +372,8 @@ def wait(self) -> None: class StepFuncDataLoaderIter: - """ - This class is a wrapper to keep track of dataloader iterator fetching event - while left entirely to user control. - """ + """This class is a wrapper to keep track of dataloader iterator fetching event while left entirely to user + control.""" def __init__(self, iterator: Iterator, data_fetcher: "AbstractDataFetcher"): self.iterator = iterator @@ -399,10 +394,8 @@ def __next__(self) -> Any: class DataLoaderIterDataFetcher(AbstractDataFetcher): - """ - This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step - for users to implement their own pre-fetching logic. - This feature can be activated as follows: + """This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step for + users to implement their own pre-fetching logic. This feature can be activated as follows: Example:: diff --git a/pytorch_lightning/utilities/finite_checks.py b/pytorch_lightning/utilities/finite_checks.py index dc5a1cb0a84cc..4dfc5843de8c2 100644 --- a/pytorch_lightning/utilities/finite_checks.py +++ b/pytorch_lightning/utilities/finite_checks.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Helper functions to detect NaN/Inf values. """ +"""Helper functions to detect NaN/Inf values.""" import logging @@ -29,8 +29,7 @@ def print_nan_gradients(model: nn.Module) -> None: def detect_nan_parameters(model: nn.Module) -> None: - """ - Iterates over model parameters and prints gradients if any parameter is not finite. + """Iterates over model parameters and prints gradients if any parameter is not finite. Raises: ValueError: diff --git a/pytorch_lightning/utilities/grads.py b/pytorch_lightning/utilities/grads.py index 1a92bd4f6cb10..c1dfb2277b53a 100644 --- a/pytorch_lightning/utilities/grads.py +++ b/pytorch_lightning/utilities/grads.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Utilities to describe gradients -""" +"""Utilities to describe gradients.""" from typing import Dict, Union import torch diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index eb64e03a559a9..b78bf1a512985 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""General utilities""" +"""General utilities.""" import importlib import operator import os @@ -25,8 +25,7 @@ def _module_available(module_path: str) -> bool: - """ - Check if a path is available in your environment + """Check if a path is available in your environment. >>> _module_available('os') True @@ -44,8 +43,7 @@ def _module_available(module_path: str) -> bool: def _compare_version(package: str, op, version) -> bool: - """ - Compare package version with some requirements + """Compare package version with some requirements. >>> _compare_version("torch", operator.ge, "0.1") True diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py index 02636ba54e37d..4dd213552f6f3 100644 --- a/pytorch_lightning/utilities/memory.py +++ b/pytorch_lightning/utilities/memory.py @@ -124,8 +124,7 @@ def get_memory_profile(mode: str) -> Dict[str, float]: def get_gpu_memory_map() -> Dict[str, float]: - """ - Get the current gpu usage. + """Get the current gpu usage. Return: A dictionary in which the keys are device ids as integers and @@ -154,8 +153,7 @@ def get_gpu_memory_map() -> Dict[str, float]: def get_model_size_mb(model: Module) -> float: - """ - Calculates the size of a Module in megabytes by saving the model to a temporary file and reading its size. + """Calculates the size of a Module in megabytes by saving the model to a temporary file and reading its size. The computation includes everything in the :meth:`~torch.nn.Module.state_dict`, i.e., by default the parameteters and buffers. diff --git a/pytorch_lightning/utilities/metrics.py b/pytorch_lightning/utilities/metrics.py index d10f9a8045fec..c1519c59292ee 100644 --- a/pytorch_lightning/utilities/metrics.py +++ b/pytorch_lightning/utilities/metrics.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Helper functions to operate on metric values. """ +"""Helper functions to operate on metric values.""" import numbers from typing import Any @@ -22,8 +22,7 @@ def metrics_to_scalars(metrics: Any) -> Any: - """ - Recursively walk through a collection and convert single-item tensors to scalar values + """Recursively walk through a collection and convert single-item tensors to scalar values. Raises: MisconfigurationException: diff --git a/pytorch_lightning/utilities/model_summary.py b/pytorch_lightning/utilities/model_summary.py index d664d4774e870..727779162f629 100644 --- a/pytorch_lightning/utilities/model_summary.py +++ b/pytorch_lightning/utilities/model_summary.py @@ -36,9 +36,8 @@ class LayerSummary: - """ - Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. - It collects the following information: + """Summary class for a single layer in a :class:`~pytorch_lightning.core.lightning.LightningModule`. It + collects the following information: - Type of the layer (e.g. Linear, BatchNorm1d, ...) - Input shape @@ -64,7 +63,6 @@ class LayerSummary: Args: module: A module to summarize - """ def __init__(self, module: nn.Module): @@ -78,11 +76,10 @@ def __del__(self): self.detach_hook() def _register_hook(self) -> Optional[RemovableHandle]: - """ - Registers a hook on the module that computes the input- and output size(s) on the first forward pass. - If the hook is called, it will remove itself from the from the module, meaning that - recursive models will only record their input- and output shapes once. - Registering hooks on :class:`~torch.jit.ScriptModule` is not supported. + """Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If + the hook is called, it will remove itself from the from the module, meaning that recursive models will only + record their input- and output shapes once. Registering hooks on :class:`~torch.jit.ScriptModule` is not + supported. Return: A handle for the installed hook, or ``None`` if registering the hook is not possible. @@ -101,8 +98,8 @@ def hook(module, inp, out): return handle def detach_hook(self): - """ - Removes the forward hook if it was not already removed in the forward pass. + """Removes the forward hook if it was not already removed in the forward pass. + Will be called after the summary is created. """ if self._hook_handle is not None: @@ -128,8 +125,7 @@ def num_parameters(self) -> int: class ModelSummary: - """ - Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. + """Generates a summary of all layers in a :class:`~pytorch_lightning.core.lightning.LightningModule`. Args: model: The model to summarize (also referred to as the root module). @@ -300,8 +296,7 @@ def _forward_example_input(self) -> None: model.train(mode) # restore mode of module def __str__(self): - """ - Makes a summary listing with: + """Makes a summary listing with: Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size """ @@ -337,11 +332,8 @@ def parse_batch_shape(batch: Any) -> Union[str, List]: def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str: - """ - Takes in a number of arrays, each specifying a column in - the summary table, and combines them all into one big - string defining the summary table that are nicely formatted. - """ + """Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one + big string defining the summary table that are nicely formatted.""" n_rows = len(cols[0][1]) n_cols = 1 + len(cols) @@ -383,9 +375,7 @@ def get_formatted_model_size(total_model_size: float) -> float: def get_human_readable_count(number: int) -> str: - """ - Abbreviates an integer number with K, M, B, T for thousands, millions, - billions and trillions, respectively. + """Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively. Examples: >>> get_human_readable_count(123) @@ -406,7 +396,6 @@ def get_human_readable_count(number: int) -> str: Return: A string formatted according to the pattern described above. - """ assert number >= 0 labels = PARAMETER_NUM_UNITS @@ -438,8 +427,7 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool: def summarize( lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None ) -> Optional[ModelSummary]: - """ - Summarize the LightningModule specified by `lightning_module`. + """Summarize the LightningModule specified by `lightning_module`. Args: lightning_module: `LightningModule` to summarize. diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 8bb055da87781..0cec4aa8db5e2 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -30,12 +30,10 @@ def str_to_bool_or_str(val: str) -> Union[str, bool]: - """Possibly convert a string representation of truth to bool. - Returns the input otherwise. - Based on the python implementation distutils.utils.strtobool + """Possibly convert a string representation of truth to bool. Returns the input otherwise. Based on the python + implementation distutils.utils.strtobool. - True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values - are 'n', 'no', 'f', 'false', 'off', and '0'. + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values are 'n', 'no', 'f', 'false', 'off', and '0'. """ lower = val.lower() if lower in ("y", "yes", "t", "true", "on", "1"): @@ -88,7 +86,7 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]: def is_picklable(obj: object) -> bool: - """Tests if an object can be pickled""" + """Tests if an object can be pickled.""" try: pickle.dumps(obj) @@ -98,7 +96,7 @@ def is_picklable(obj: object) -> bool: def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: - """Removes all unpicklable entries from hparams""" + """Removes all unpicklable entries from hparams.""" hparams_dict = hparams if isinstance(hparams, Namespace): @@ -112,7 +110,7 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]: - """Parse key words for standard self, *args and **kwargs + """Parse key words for standard self, *args and **kwargs. >>> class Model(): ... def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): @@ -163,8 +161,7 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]: def collect_init_args( frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False ) -> List[Dict[str, Any]]: - """ - Recursively collects the arguments passed to the child constructors in the inheritance tree. + """Recursively collects the arguments passed to the child constructors in the inheritance tree. Args: frame: the current stack frame @@ -303,9 +300,10 @@ def __repr__(self) -> str: def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]: - """ - Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. - Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. + """Special attribute finding for Lightning. + + Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams + namespace/dict, and the datamodule. """ trainer = getattr(model, "trainer", None) @@ -328,10 +326,10 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str) -> Optional[Any]: - """ - Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. - Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, - returns the last one that has it. + """Special attribute finding for Lightning. + + Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams + namespace/dict, and the datamodule, returns the last one that has it. """ holders = _lightning_get_all_attr_holders(model, attribute) if len(holders) == 0: @@ -341,17 +339,16 @@ def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool: - """ - Special hasattr for Lightning. Checks for attribute in model namespace, - the old hparams namespace/dict, and the datamodule. + """Special hasattr for Lightning. + + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ return _lightning_get_first_attr_holder(model, attribute) is not None def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]: - """ - Special getattr for Lightning. Checks for attribute in model namespace, - the old hparams namespace/dict, and the datamodule. + """Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and + the datamodule. Raises: AttributeError: @@ -371,9 +368,7 @@ def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[A def lightning_setattr(model: "pl.LightningModule", attribute: str, value: Any) -> None: - """ - Special setattr for Lightning. Checks for attribute in model namespace - and the old hparams namespace/dict. + """Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. Will also set the attribute on datamodule, if it exists. Raises: diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 732f2d8136b9e..dc64ffb78bade 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Helper functions to help with reproducibility of models. """ +"""Helper functions to help with reproducibility of models.""" import logging import os @@ -28,10 +28,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: - """ - Function that sets seed for pseudo-random number generators in: - pytorch, numpy, python.random - In addition, sets the following environment variables: + """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random In addition, + sets the following environment variables: - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend). - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``. @@ -79,8 +77,8 @@ def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> def reset_seed() -> None: - """ - Reset the seed to the value that :func:`pytorch_lightning.utilities.seed.seed_everything` previously set. + """Reset the seed to the value that :func:`pytorch_lightning.utilities.seed.seed_everything` previously set. + If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing. """ seed = os.environ.get("PL_GLOBAL_SEED", None) @@ -90,9 +88,9 @@ def reset_seed() -> None: def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover - """ - The worker_init_fn that Lightning automatically adds to your dataloader if you previously set - set the seed with ``seed_everything(seed, workers=True)``. + """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed + with ``seed_everything(seed, workers=True)``. + See also the PyTorch documentation on `randomness in DataLoaders `_. """ diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 1949f7ec3e378..5a01e2a1e941d 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Warning-related utilities""" +"""Warning-related utilities.""" import warnings from functools import partial diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index 3a9073cfa122a..b922a749e7742 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -53,15 +53,14 @@ def wrapper(*args: Any, **kwargs: Any) -> Union[bool, Any]: class XLADeviceUtils: - """Used to detect the type of XLA device""" + """Used to detect the type of XLA device.""" _TPU_AVAILABLE = False @staticmethod @pl_multi_process def _is_device_tpu() -> bool: - """ - Check if TPU devices are available + """Check if TPU devices are available. Return: A boolean value indicating if TPU devices are available @@ -77,8 +76,7 @@ def _is_device_tpu() -> bool: @staticmethod def xla_available() -> bool: - """ - Check if XLA library is installed + """Check if XLA library is installed. Return: A boolean value indicating if a XLA is installed @@ -87,8 +85,7 @@ def xla_available() -> bool: @staticmethod def tpu_device_exists() -> bool: - """ - Runs XLA device check within a separate process + """Runs XLA device check within a separate process. Return: A boolean value indicating if a TPU device exists on the system diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index c643e4d73fd92..0e0d73b9e2aa6 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Diagnose your system and show basic information +"""Diagnose your system and show basic information. This server mainly to get detail info for better bug reporting. - """ import os diff --git a/tests/accelerators/ddp_model.py b/tests/accelerators/ddp_model.py index 9b3246a0a40de..a102f2690797f 100644 --- a/tests/accelerators/ddp_model.py +++ b/tests/accelerators/ddp_model.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Runs either `.fit()` or `.test()` on a single node across multiple gpus. -""" +"""Runs either `.fit()` or `.test()` on a single node across multiple gpus.""" import os from argparse import ArgumentParser diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index f4793a61b87c2..650b7949ac1ba 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -360,7 +360,7 @@ def _test_accelerator_choice_ddp_cpu_and_plugin(tmpdir, ddp_plugin_class): ) @mock.patch("torch.cuda.device_count", return_value=0) def test_accelerator_choice_ddp_cpu_custom_cluster(_, tmpdir): - """Test that we choose the custom cluster even when SLURM or TE flags are around""" + """Test that we choose the custom cluster even when SLURM or TE flags are around.""" class CustomCluster(LightningEnvironment): def master_address(self): diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index e67fb166f815b..61f0a1e247215 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -89,7 +89,7 @@ def configure_sharded_model(self): def test_configure_sharded_model_false(tmpdir): - """Ensure ``configure_sharded_model`` is not called, when turned off""" + """Ensure ``configure_sharded_model`` is not called, when turned off.""" class CustomPlugin(SingleDevicePlugin): @property @@ -110,7 +110,8 @@ def call_configure_sharded_model_hook(self) -> bool: def test_accelerator_configure_sharded_model_called_once(tmpdir): - """Ensure that the configure sharded model hook is called, and set to False after to ensure not called again.""" + """Ensure that the configure sharded model hook is called, and set to False after to ensure not called + again.""" model = DummyModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) @@ -120,7 +121,7 @@ def test_accelerator_configure_sharded_model_called_once(tmpdir): def test_configure_sharded_model_called_once(tmpdir): - """Ensure ``configure_sharded_model`` is only called once""" + """Ensure ``configure_sharded_model`` is only called once.""" model = DummyModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index d695e4c63f43e..f95d182f9e5e1 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -15,10 +15,8 @@ @pytest.mark.parametrize("delay_dispatch", [True, False]) def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch): - """ - Test when using a custom training type plugin that delays setup optimizers, - we do not call setup optimizers till ``pre_dispatch``. - """ + """Test when using a custom training type plugin that delays setup optimizers, we do not call setup optimizers + till ``pre_dispatch``.""" class TestModel(BoringModel): def on_fit_start(self): @@ -42,9 +40,7 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: def test_restore_checkpoint_after_pre_dispatch_default(): - """ - Assert default for restore_checkpoint_after_pre_dispatch is False. - """ + """Assert default for restore_checkpoint_after_pre_dispatch is False.""" plugin = SingleDevicePlugin(torch.device("cpu")) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert not accelerator.restore_checkpoint_after_pre_dispatch @@ -53,10 +49,8 @@ def test_restore_checkpoint_after_pre_dispatch_default(): @pytest.mark.parametrize("restore_after_pre_dispatch", [True, False]) def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch): - """ - Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after - pre-dispatch is called. - """ + """Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after pre- + dispatch is called.""" class TestPlugin(SingleDevicePlugin): predispatched_called = False diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py index dc83e4ad4f02e..e03bc467a453a 100644 --- a/tests/accelerators/test_ddp.py +++ b/tests/accelerators/test_ddp.py @@ -80,9 +80,7 @@ def test_multi_gpu_model_ddp_fit_test(tmpdir, as_module): @RunIf(skip_windows=True) @pytest.mark.skipif(torch.cuda.is_available(), reason="test doesn't requires GPU machine") def test_torch_distributed_backend_env_variables(tmpdir): - """ - This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError. - """ + """This test set `undefined` as torch backend and should raise an `Backend.UNDEFINED` ValueError.""" _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"} with patch.dict(os.environ, _environ), patch("torch.cuda.device_count", return_value=2): with pytest.raises(ValueError, match="Invalid backend: 'undefined'"): @@ -97,9 +95,7 @@ def test_torch_distributed_backend_env_variables(tmpdir): @mock.patch("torch.cuda.set_device") @mock.patch.dict(os.environ, {"PL_TORCH_DISTRIBUTED_BACKEND": "gloo"}, clear=True) def test_ddp_torch_dist_is_available_in_setup(mock_set_device, mock_is_available, mock_device_count, tmpdir): - """ - Test to ensure torch distributed is available within the setup hook using ddp - """ + """Test to ensure torch distributed is available within the setup hook using ddp.""" class TestModel(BoringModel): def setup(self, stage: Optional[str] = None) -> None: @@ -123,9 +119,7 @@ def test_ddp_wrapper_32(tmpdir): def _test_ddp_wrapper(tmpdir, precision): - """ - Test parameters to ignore are carried over for DDP. - """ + """Test parameters to ignore are carried over for DDP.""" class WeirdModule(torch.nn.Module): def _save_to_state_dict(self, destination, prefix, keep_vars): diff --git a/tests/accelerators/test_dp.py b/tests/accelerators/test_dp.py index efaf761cb7116..8e09460551dec 100644 --- a/tests/accelerators/test_dp.py +++ b/tests/accelerators/test_dp.py @@ -55,7 +55,10 @@ def test_step_end(self, outputs): @RunIf(min_gpus=2) def test_multi_gpu_early_stop_dp(tmpdir): - """Make sure DDP works. with early stopping""" + """Make sure DDP works. + + with early stopping + """ tutils.set_random_master_port() dm = ClassifDataModule() @@ -136,9 +139,7 @@ def training_epoch_end(self, outputs): def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch): - """ - Test that an exception is raised when overriding batch_transfer_hooks in DP model. - """ + """Test that an exception is raised when overriding batch_transfer_hooks in DP model.""" monkeypatch.setattr("torch.cuda.device_count", lambda: 2) class CustomModel(BoringModel): @@ -179,7 +180,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx): @RunIf(min_gpus=2) def test_dp_training_step_dict(tmpdir): - """This test verifies that dp properly reduces dictionaries""" + """This test verifies that dp properly reduces dictionaries.""" model = ReductionTestModel() model.training_step_end = None model.validation_step_end = None diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index ba2de43da110e..d76cf68d32801 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -265,7 +265,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @RunIf(ipu=True) def test_stages_correct(tmpdir): - """Ensure all stages correctly are traced correctly by asserting the output for each stage""" + """Ensure all stages correctly are traced correctly by asserting the output for each stage.""" class StageModel(IPUModel): def training_step(self, batch, batch_idx): @@ -363,10 +363,8 @@ def test_manual_poptorch_opts(tmpdir): @RunIf(ipu=True) def test_manual_poptorch_opts_custom(tmpdir): - """ - Ensure if the user passes manual poptorch Options with custom parameters set, - we respect them in our poptorch options and the dataloaders. - """ + """Ensure if the user passes manual poptorch Options with custom parameters set, we respect them in our + poptorch options and the dataloaders.""" model = IPUModel() training_opts = poptorch.Options() @@ -418,10 +416,8 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: @RunIf(ipu=True) def test_replication_factor(tmpdir): - """ - Ensure if the user passes manual poptorch Options with custom parameters set, - we set them correctly in the dataloaders. - """ + """Ensure if the user passes manual poptorch Options with custom parameters set, we set them correctly in the + dataloaders.""" plugin = IPUPlugin() trainer = Trainer(ipus=2, default_root_dir=tmpdir, fast_dev_run=True, plugins=plugin) @@ -430,9 +426,7 @@ def test_replication_factor(tmpdir): @RunIf(ipu=True) def test_default_opts(tmpdir): - """ - Ensure default opts are set correctly in the IPUPlugin. - """ + """Ensure default opts are set correctly in the IPUPlugin.""" model = IPUModel() @@ -450,9 +444,7 @@ def test_default_opts(tmpdir): @RunIf(ipu=True) def test_multi_optimizers_fails(tmpdir): - """ - Ensure if there are multiple optimizers, we throw an exception - """ + """Ensure if there are multiple optimizers, we throw an exception.""" class TestModel(IPUModel): def configure_optimizers(self): @@ -467,9 +459,7 @@ def configure_optimizers(self): @RunIf(ipu=True) def test_precision_plugin(tmpdir): - """ - Ensure precision plugin value is set correctly. - """ + """Ensure precision plugin value is set correctly.""" plugin = IPUPrecisionPlugin(precision=16) assert plugin.precision == 16 diff --git a/tests/accelerators/test_multi_nodes_gpu.py b/tests/accelerators/test_multi_nodes_gpu.py index ea591e47041f8..a1abed776c12d 100644 --- a/tests/accelerators/test_multi_nodes_gpu.py +++ b/tests/accelerators/test_multi_nodes_gpu.py @@ -33,9 +33,7 @@ @pytest.mark.skip("Multi-node testing is currently disabled") @RunIf(special=True) def test_logging_sync_dist_true_ddp(tmpdir): - """ - Tests to ensure that the sync_dist flag works with CPU (should just return the original value) - """ + """Tests to ensure that the sync_dist flag works with CPU (should just return the original value)""" fake_result = 1 class TestModel(BoringModel): @@ -72,9 +70,7 @@ def validation_step(self, batch, batch_idx): @pytest.mark.skip("Multi-node testing is currently disabled") @RunIf(special=True) def test__validation_step__log(tmpdir): - """ - Tests that validation_step can log - """ + """Tests that validation_step can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 99ac579eb99b0..7f7bad327f515 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -47,7 +47,7 @@ def forward(self, x): @RunIf(tpu=True) @pl_multi_process_test def test_resume_training_on_cpu(tmpdir): - """Checks if training can be resumed from a saved checkpoint on CPU""" + """Checks if training can be resumed from a saved checkpoint on CPU.""" # Train a model on TPU model = BoringModel() trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8) @@ -83,10 +83,7 @@ def test_if_test_works_after_train(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_weight_tying_warning(tmpdir, capsys=None): - """ - Ensure a warning is thrown if model parameter lengths do not match - post moving to device. - """ + """Ensure a warning is thrown if model parameter lengths do not match post moving to device.""" model = WeightSharingModule() trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) @@ -98,8 +95,8 @@ def test_weight_tying_warning(tmpdir, capsys=None): @RunIf(tpu=True) @pl_multi_process_test def test_if_weights_tied(tmpdir, capsys=None): - """ - Test if weights are properly tied on `on_post_move_to_device`. + """Test if weights are properly tied on `on_post_move_to_device`. + Ensure no warning for parameter mismatch is thrown. """ diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 206f471c443a9..26a88618f0f56 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -18,16 +18,16 @@ class ConfigureOptimizersPool(ABC): def configure_optimizers(self): - """ - return whatever optimizers we want here. + """return whatever optimizers we want here. + :return: list of optimizers """ optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer def configure_optimizers__lbfgs(self): - """ - return whatever optimizers we want here. + """return whatever optimizers we want here. + :return: list of optimizers """ optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate) diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 5165df37fce9c..f557f2bfdf797 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -46,8 +46,7 @@ class EvalModelTemplate( ConfigureOptimizersPool, LightningModule, ): - """ - This template houses all combinations of model configurations we want to test + """This template houses all combinations of model configurations we want to test. >>> model = EvalModelTemplate() """ diff --git a/tests/base/model_test_dataloaders.py b/tests/base/model_test_dataloaders.py index 5008dd2f2352c..2082ccad48cad 100644 --- a/tests/base/model_test_dataloaders.py +++ b/tests/base/model_test_dataloaders.py @@ -19,7 +19,7 @@ class TestDataloaderVariations(ABC): @abstractmethod def dataloader(self, *args, **kwargs): - """placeholder""" + """placeholder.""" def test_dataloader(self): return self.dataloader(train=False) diff --git a/tests/base/model_test_epoch_ends.py b/tests/base/model_test_epoch_ends.py index f94feb036f6ed..746ceb94a5de0 100644 --- a/tests/base/model_test_epoch_ends.py +++ b/tests/base/model_test_epoch_ends.py @@ -20,8 +20,8 @@ class TestEpochEndVariations(ABC): def test_epoch_end(self, outputs): - """ - Called at the end of test epoch to aggregate outputs + """Called at the end of test epoch to aggregate outputs. + :param outputs: list of individual outputs of each validation step :return: """ @@ -53,8 +53,8 @@ def test_epoch_end(self, outputs): return result def test_epoch_end__multiple_dataloaders(self, outputs): - """ - Called at the end of test epoch to aggregate outputs + """Called at the end of test epoch to aggregate outputs. + :param outputs: list of individual outputs of each validation step :return: """ diff --git a/tests/base/model_test_steps.py b/tests/base/model_test_steps.py index 99353f0702d22..5bb0c641e4c13 100644 --- a/tests/base/model_test_steps.py +++ b/tests/base/model_test_steps.py @@ -18,13 +18,11 @@ class TestStepVariations(ABC): - """ - Houses all variations of test steps - """ + """Houses all variations of test steps.""" def test_step(self, batch, batch_idx, *args, **kwargs): - """ - Default, baseline test_step + """Default, baseline test_step. + :param batch: :return: """ @@ -57,8 +55,8 @@ def test_step(self, batch, batch_idx, *args, **kwargs): return output def test_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Default, baseline test_step + """Default, baseline test_step. + :param batch: :return: """ diff --git a/tests/base/model_train_dataloaders.py b/tests/base/model_train_dataloaders.py index 70ff21d2d34fa..49646fd3cf025 100644 --- a/tests/base/model_train_dataloaders.py +++ b/tests/base/model_train_dataloaders.py @@ -19,7 +19,7 @@ class TrainDataloaderVariations(ABC): @abstractmethod def dataloader(self, train: bool, *args, **kwargs): - """placeholder""" + """placeholder.""" def train_dataloader(self): return self.dataloader(train=True) @@ -37,7 +37,7 @@ def train_dataloader__zero_length(self): return dataloader def train_dataloader__multiple_mapping(self): - """Return a mapping loaders with different lengths""" + """Return a mapping loaders with different lengths.""" # List[DataLoader] loaders_a_b = [self.dataloader(num_samples=100, train=True), self.dataloader(num_samples=50, train=True)] diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index 70a7691d69386..eb1dca615d011 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -15,12 +15,10 @@ class TrainingStepVariations(ABC): - """ - Houses all variations of training steps - """ + """Houses all variations of training steps.""" def training_step(self, batch, batch_idx, optimizer_idx=None): - """Lightning calls this inside the training loop""" + """Lightning calls this inside the training loop.""" self.training_step_called = True # forward pass @@ -33,7 +31,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): return {"loss": loss_train} def training_step__multiple_dataloaders(self, batch, batch_idx, optimizer_idx=None): - """Training step for multiple train loaders""" + """Training step for multiple train loaders.""" assert isinstance(batch, dict) assert len(batch) == 2 diff --git a/tests/base/model_valid_dataloaders.py b/tests/base/model_valid_dataloaders.py index 1df885d8b6f4d..3fc19002b24a0 100644 --- a/tests/base/model_valid_dataloaders.py +++ b/tests/base/model_valid_dataloaders.py @@ -19,7 +19,7 @@ class ValDataloaderVariations(ABC): @abstractmethod def dataloader(self, *args, **kwargs): - """placeholder""" + """placeholder.""" def val_dataloader(self): return self.dataloader(train=False) diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py index bbe06d3c8d203..718d2b699da23 100644 --- a/tests/base/model_valid_epoch_ends.py +++ b/tests/base/model_valid_epoch_ends.py @@ -17,13 +17,10 @@ class ValidationEpochEndVariations(ABC): - """ - Houses all variations of validation_epoch_end steps - """ + """Houses all variations of validation_epoch_end steps.""" def validation_epoch_end(self, outputs): - """ - Called at the end of validation to aggregate outputs + """Called at the end of validation to aggregate outputs. Args: outputs: list of individual outputs of each validation step @@ -47,8 +44,7 @@ def _mean(res, key): self.log("val_acc", val_acc_mean, prog_bar=True) def validation_epoch_end__multiple_dataloaders(self, outputs): - """ - Called at the end of validation to aggregate outputs + """Called at the end of validation to aggregate outputs. Args: outputs: list of individual outputs of each validation step diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py index 6f131436680d2..96d2f1bc9e129 100644 --- a/tests/base/model_valid_steps.py +++ b/tests/base/model_valid_steps.py @@ -18,13 +18,11 @@ class ValidationStepVariations(ABC): - """ - Houses all variations of validation steps - """ + """Houses all variations of validation steps.""" def validation_step(self, batch, batch_idx, *args, **kwargs): - """ - Lightning calls this inside the validation loop + """Lightning calls this inside the validation loop. + :param batch: :return: """ @@ -62,8 +60,8 @@ def validation_step__dp(self, batch, batch_idx, *args, **kwargs): return loss_val def validation_step__multiple_dataloaders(self, batch, batch_idx, dataloader_idx, **kwargs): - """ - Lightning calls this inside the validation loop + """Lightning calls this inside the validation loop. + :param batch: :return: """ diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 1a2ecae7b94c0..45a0c364e1936 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -19,9 +19,7 @@ @pytest.mark.parametrize("single_cb", [False, True]) def test_train_step_no_return(tmpdir, single_cb: bool): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class CB(Callback): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 049e409842564..7cab0d8776056 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -56,8 +56,8 @@ def on_train_epoch_end(self, trainer, pl_module): def test_resume_early_stopping_from_checkpoint(tmpdir): - """ - Prevent regressions to bugs: + """Prevent regressions to bugs: + https://github.com/PyTorchLightning/pytorch-lightning/issues/1464 https://github.com/PyTorchLightning/pytorch-lightning/issues/1463 """ @@ -255,9 +255,9 @@ def validation_epoch_end(self, outputs): @pytest.mark.parametrize("step_freeze, min_steps, min_epochs", [(5, 1, 1), (5, 1, 3), (3, 15, 1)]) def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int): - """Excepted Behaviour: - IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered, - THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop. + """Excepted Behaviour: IF `min_steps` was set to a higher value than the `trainer.global_step` when + `early_stopping` is being triggered, THEN the trainer should continue until reaching `trainer.global_step` == + `min_steps`, and stop. IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered, diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 9a28c7a8fc478..c014c8e736874 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -40,7 +40,7 @@ def on_train_epoch_start(self, trainer, pl_module): def test_finetuning_callback(tmpdir): - """Test finetuning callbacks works as expected""" + """Test finetuning callbacks works as expected.""" seed_everything(42) @@ -89,7 +89,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): def test_finetuning_callback_warning(tmpdir): - """Test finetuning callbacks works as expected""" + """Test finetuning callbacks works as expected.""" seed_everything(42) @@ -133,7 +133,7 @@ def configure_optimizers(self): def test_freeze_unfreeze_function(tmpdir): - """Test freeze properly sets requires_grad on the modules""" + """Test freeze properly sets requires_grad on the modules.""" seed_everything(42) @@ -167,7 +167,7 @@ def __init__(self): def test_unfreeze_and_add_param_group_function(tmpdir): - """Test unfreeze_and_add_param_group properly unfreeze parameters and add to the correct param_group""" + """Test unfreeze_and_add_param_group properly unfreeze parameters and add to the correct param_group.""" seed_everything(42) @@ -220,7 +220,8 @@ def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: O def test_base_finetuning_internal_optimizer_metadata(tmpdir): - """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks""" + """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning + Callbacks.""" seed_everything(42) @@ -263,10 +264,8 @@ def configure_optimizers(self): def test_on_before_accelerator_backend_setup(tmpdir): - """ - `on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before - before configure_optimizers function call. - """ + """`on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before before + configure_optimizers function call.""" class TestCallback(Callback): def on_before_accelerator_backend_setup(self, trainer, pl_module): @@ -289,10 +288,8 @@ def configure_optimizers(self): def test_complex_nested_model(): - """ - Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters - directly themselves rather than exclusively their submodules containing parameters. - """ + """Test flattening, freezing, and thawing of models which contain parent (non-leaf) modules with parameters + directly themselves rather than exclusively their submodules containing parameters.""" class ConvBlock(nn.Module): def __init__(self, in_channels, out_channels): @@ -362,10 +359,8 @@ def configure_optimizers(self): def test_callbacks_restore(tmpdir): - """ - Test callbacks restore is called after optimizers have been re-created - but before optimizer states reload - """ + """Test callbacks restore is called after optimizers have been re-created but before optimizer states + reload.""" chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = FinetuningBoringModel() @@ -412,10 +407,8 @@ def test_callbacks_restore(tmpdir): def test_callbacks_restore_backbone(tmpdir): - """ - Test callbacks restore is called after optimizers have been re-created - but before optimizer states reload - """ + """Test callbacks restore is called after optimizers have been re-created but before optimizer states + reload.""" class BackboneBoringModel(BoringModel): def __init__(self): diff --git a/tests/callbacks/test_gpu_stats_monitor.py b/tests/callbacks/test_gpu_stats_monitor.py index eaba4d30684f3..5ed3f533b5588 100644 --- a/tests/callbacks/test_gpu_stats_monitor.py +++ b/tests/callbacks/test_gpu_stats_monitor.py @@ -29,9 +29,7 @@ @RunIf(min_gpus=1) def test_gpu_stats_monitor(tmpdir): - """ - Test GPU stats are logged using a logger. - """ + """Test GPU stats are logged using a logger.""" model = BoringModel() gpu_stats = GPUStatsMonitor(intra_step_time=True) logger = CSVLogger(tmpdir) @@ -65,9 +63,7 @@ def test_gpu_stats_monitor(tmpdir): @RunIf(min_gpus=1) def test_gpu_stats_monitor_no_queries(tmpdir): - """ - Test GPU logger doesn't fail if no "nvidia-smi" queries are to be performed. - """ + """Test GPU logger doesn't fail if no "nvidia-smi" queries are to be performed.""" model = BoringModel() gpu_stats = GPUStatsMonitor( memory_utilization=False, @@ -96,18 +92,14 @@ def test_gpu_stats_monitor_no_queries(tmpdir): @pytest.mark.skipif(torch.cuda.is_available(), reason="test requires CPU machine") def test_gpu_stats_monitor_cpu_machine(tmpdir): - """ - Test GPUStatsMonitor on CPU machine. - """ + """Test GPUStatsMonitor on CPU machine.""" with pytest.raises(MisconfigurationException, match="NVIDIA driver is not installed"): GPUStatsMonitor() @RunIf(min_gpus=1) def test_gpu_stats_monitor_no_logger(tmpdir): - """ - Test GPUStatsMonitor with no logger in Trainer. - """ + """Test GPUStatsMonitor with no logger in Trainer.""" model = BoringModel() gpu_stats = GPUStatsMonitor() @@ -119,9 +111,7 @@ def test_gpu_stats_monitor_no_logger(tmpdir): @RunIf(min_gpus=1) def test_gpu_stats_monitor_no_gpu_warning(tmpdir): - """ - Test GPUStatsMonitor raises a warning when not training on GPU device. - """ + """Test GPUStatsMonitor raises a warning when not training on GPU device.""" model = BoringModel() gpu_stats = GPUStatsMonitor() diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index d742781599d77..32b2551245885 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -91,9 +91,7 @@ def configure_optimizers(self): def test_log_momentum_no_momentum_optimizer(tmpdir): - """ - Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True. - """ + """Test that if optimizer doesn't have momentum then a warning is raised with log_momentum=True.""" class LogMomentumModel(BoringModel): def configure_optimizers(self): diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 7634e18bede43..8356994b7b018 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -24,7 +24,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase -from pytorch_lightning.callbacks.progress import tqdm +from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -228,9 +228,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal @pytest.mark.parametrize("limit_val_batches", (0, 5)) def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): - """ - Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument. - """ + """Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.""" class CurrentProgressBar(ProgressBar): val_pbar_total = 0 @@ -321,9 +319,10 @@ def init_test_tqdm(self): def test_main_progress_bar_update_amount( tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list ): - """ - Test that the main progress updates with the correct amount together with the val progress. At the end of - the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate. + """Test that the main progress updates with the correct amount together with the val progress. + + At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh + rate. """ model = BoringModel() progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) @@ -345,9 +344,7 @@ def test_main_progress_bar_update_amount( @pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [[1, 3, [1]], [3, 1, [1, 1, 1]], [5, 3, [3, 2]]]) def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, test_deltas: list): - """ - Test that test progress updates with the correct amount. - """ + """Test that test progress updates with the correct amount.""" model = BoringModel() progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) trainer = Trainer( @@ -363,7 +360,7 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate def test_tensor_to_float_conversion(tmpdir): - """Check tensor gets converted to float""" + """Check tensor gets converted to float.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -399,8 +396,8 @@ def training_step(self, batch, batch_idx): ], ) def test_tqdm_format_num(input_num: Union[str, int, float], expected: str): - """Check that the specialized tqdm.format_num appends 0 to floats and strings""" - assert tqdm.format_num(input_num) == expected + """Check that the specialized tqdm.format_num appends 0 to floats and strings.""" + assert Tqdm.format_num(input_num) == expected class PrintModel(BoringModel): @@ -421,7 +418,7 @@ def predict_step(self, *args, **kwargs): return super().predict_step(*args, **kwargs) -@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") def test_progress_bar_print(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar.""" model = PrintModel() @@ -448,7 +445,7 @@ def test_progress_bar_print(tqdm_write, tmpdir): ] -@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") def test_progress_bar_print_no_train(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar without training.""" model = PrintModel() @@ -475,7 +472,7 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir): @mock.patch("builtins.print") -@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): """Test that printing in LightningModule goes through built-in print function when progress bar is disabled.""" model = PrintModel() diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index fe6e14d1084d9..cf02b0a4bce75 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -294,11 +294,8 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool def test_permanent_when_model_is_saved_multiple_times( tmpdir, caplog, prune_on_train_epoch_end, save_on_train_epoch_end ): - """ - When a model is saved multiple times and make_permanent=True, we need to - make sure a copy is pruned and not the trained model if we want to continue - with the same pruning buffers. - """ + """When a model is saved multiple times and make_permanent=True, we need to make sure a copy is pruned and not + the trained model if we want to continue with the same pruning buffers.""" if prune_on_train_epoch_end and save_on_train_epoch_end: pytest.xfail( "Pruning sets the `grad_fn` of the parameters so we can't save" diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py index 23d0cbc9d5581..73d1a458b52ed 100644 --- a/tests/callbacks/test_quantization.py +++ b/tests/callbacks/test_quantization.py @@ -32,7 +32,7 @@ @pytest.mark.parametrize("convert", [True, False]) @RunIf(quantization=True) def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): - """Parity test for quant model""" + """Parity test for quant model.""" seed_everything(42) dm = RegressDataModule() trainer_args = dict(default_root_dir=tmpdir, max_epochs=7, gpus=int(torch.cuda.is_available())) @@ -76,7 +76,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): @RunIf(quantization=True) def test_quantize_torchscript(tmpdir): - """Test converting to torchscipt""" + """Test converting to torchscipt.""" dm = RegressDataModule() qmodel = RegressionModel() qcb = QuantizationAwareTraining(input_compatible=False) @@ -92,7 +92,7 @@ def test_quantize_torchscript(tmpdir): @RunIf(quantization=True) def test_quantization_exceptions(tmpdir): - """Test wrong fuse layers""" + """Test wrong fuse layers.""" with pytest.raises(MisconfigurationException, match="Unsupported qconfig"): QuantizationAwareTraining(qconfig=["abc"]) @@ -130,7 +130,7 @@ def custom_trigger_last(trainer): ) @RunIf(quantization=True) def test_quantization_triggers(tmpdir, trigger_fn: Union[None, int, Callable], expected_count: int): - """Test how many times the quant is called""" + """Test how many times the quant is called.""" dm = RegressDataModule() qmodel = RegressionModel() qcb = QuantizationAwareTraining(collect_quantization=trigger_fn) diff --git a/tests/callbacks/test_stochastic_weight_avg.py b/tests/callbacks/test_stochastic_weight_avg.py index 0bfaa359bb1a8..6800694eb3fcf 100644 --- a/tests/callbacks/test_stochastic_weight_avg.py +++ b/tests/callbacks/test_stochastic_weight_avg.py @@ -192,7 +192,7 @@ def test_swa_raises(): @pytest.mark.parametrize("stochastic_weight_avg", [False, True]) @pytest.mark.parametrize("use_callbacks", [False, True]) def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks: bool, stochastic_weight_avg: bool): - """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer""" + """Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -217,7 +217,7 @@ def configure_optimizers(self): def test_swa_deepcopy(tmpdir): - """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM""" + """Test to ensure SWA Callback doesn't deepcopy dataloaders and datamodule potentially leading to OOM.""" class TestSWA(StochasticWeightAveraging): def __init__(self, *args, **kwargs): diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index 92643ba51b82c..94ee3e87bc3e3 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -104,7 +104,7 @@ def test_timer_time_remaining(time_mock): def test_timer_stops_training(tmpdir, caplog): - """Test that the timer stops training before reaching max_epochs""" + """Test that the timer stops training before reaching max_epochs.""" model = BoringModel() duration = timedelta(milliseconds=100) timer = Timer(duration=duration) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f5d17ac5333ed..c7f4bd0e802a9 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -72,10 +72,8 @@ def validation_epoch_end(self, outputs): def test_model_checkpoint_score_and_ckpt( tmpdir, validation_step_none: bool, val_dataloaders_none: bool, monitor: str, reduce_lr_on_plateau: bool ): - """ - Test that when a model checkpoint is saved, it saves with - the correct score appended to ckpt_path and checkpoint data - """ + """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and + checkpoint data.""" max_epochs = 3 limit_train_batches = 5 limit_val_batches = 7 @@ -184,10 +182,8 @@ def on_validation_epoch_end(self): def test_model_checkpoint_score_and_ckpt_val_check_interval( tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned ): - """ - Test that when a model checkpoint is saved, it saves with the correct - score appended to ckpt_path and checkpoint data with val_check_interval - """ + """Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and + checkpoint data with val_check_interval.""" max_epochs = 3 limit_train_batches = 12 limit_val_batches = 7 @@ -297,7 +293,7 @@ def _make_assertions(epoch, ix, version=""): @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int): - """Test that dirpath=None in checkpoint callback is valid and that ckpt_path is set correctly""" + """Test that dirpath=None in checkpoint callback is valid and that ckpt_path is set correctly.""" tutils.reset_seed() model = LogInTwoMethods() @@ -316,7 +312,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int): @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int): - """Test that None in checkpoint callback is valid and that chkp_path is set correctly""" + """Test that None in checkpoint callback is valid and that chkp_path is set correctly.""" tutils.reset_seed() model = LogInTwoMethods() @@ -334,7 +330,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k: int): @pytest.mark.parametrize("logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")]) def test_model_checkpoint_path(tmpdir, logger_version: Union[None, int, str], expected: str): - """Test that "version_" prefix is only added when logger's version is an integer""" + """Test that "version_" prefix is only added when logger's version is an integer.""" tutils.reset_seed() model = LogInTwoMethods() logger = TensorBoardLogger(str(tmpdir), version=logger_version) @@ -458,9 +454,7 @@ class ModelCheckpointExtensionTest(ModelCheckpoint): def test_model_checkpoint_file_extension(tmpdir): - """ - Test ModelCheckpoint with different file extension. - """ + """Test ModelCheckpoint with different file extension.""" model = LogInTwoMethods() model_checkpoint = ModelCheckpointExtensionTest( @@ -549,10 +543,8 @@ def test_invalid_every_n_train_steps(tmpdir): def test_invalid_trigger_combination(tmpdir): - """ - Test that a MisconfigurationException is raised if more than one of - every_n_epochs, every_n_train_steps, and train_time_interval are enabled together. - """ + """Test that a MisconfigurationException is raised if more than one of every_n_epochs, every_n_train_steps, and + train_time_interval are enabled together.""" with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_epochs=2) with pytest.raises(MisconfigurationException, match=r".*Combination of parameters every_n_train_steps"): @@ -788,7 +780,7 @@ def test_default_checkpoint_behavior(tmpdir): def test_model_checkpoint_save_last_warning( tmpdir, caplog, max_epochs: int, should_validate: bool, save_last: bool, verbose: bool ): - """Tests 'Saving latest checkpoint...' log""" + """Tests 'Saving latest checkpoint...' log.""" model = LogInTwoMethods() if not should_validate: model.validation_step = None @@ -868,9 +860,7 @@ def validation_epoch_end(self, outputs): def test_checkpoint_repeated_strategy(tmpdir): - """ - This test validates that the checkpoint can be called when provided to callbacks list - """ + """This test validates that the checkpoint can be called when provided to callbacks list.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}") class ExtendedBoringModel(BoringModel): @@ -913,10 +903,8 @@ def validation_step(self, batch, batch_idx): def test_checkpoint_repeated_strategy_extended(tmpdir): - """ - This test validates checkpoint can be called several times without - increasing internally its global step if nothing run. - """ + """This test validates checkpoint can be called several times without increasing internally its global step if + nothing run.""" class ExtendedBoringModel(BoringModel): def validation_step(self, batch, batch_idx): @@ -1050,7 +1038,7 @@ def test_val_check_interval_checkpoint_files(tmpdir): def test_current_score(tmpdir): - """Check that the current_score value is correct and was saved""" + """Check that the current_score value is correct and was saved.""" class TestModel(BoringModel): def training_step(self, *args): @@ -1083,7 +1071,7 @@ def training_step(self, *args): @pytest.mark.parametrize("mode", ["min", "max"]) def test_current_score_when_nan(tmpdir, mode: str): - """Check that ModelCheckpoint handles NaN values correctly""" + """Check that ModelCheckpoint handles NaN values correctly.""" class TestModel(BoringModel): def training_step(self, *args): @@ -1138,10 +1126,8 @@ def __init__(self, hparams): def test_ckpt_version_after_rerun_new_trainer(tmpdir): - """ - Check that previous checkpoints are renamed to have the correct - version suffix when new trainer instances are used - """ + """Check that previous checkpoints are renamed to have the correct version suffix when new trainer instances + are used.""" epochs = 2 for i in range(epochs): mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="{epoch}") @@ -1167,10 +1153,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir): def test_ckpt_version_after_rerun_same_trainer(tmpdir): - """ - Check that previous checkpoints are renamed to have the correct - version suffix when the same trainer instance is used - """ + """Check that previous checkpoints are renamed to have the correct version suffix when the same trainer + instance is used.""" mc = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="epoch", filename="test") mc.STARTING_VERSION = 9 trainer = Trainer( diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 6a8192ef0149e..739dc98a22834 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -23,9 +23,7 @@ def test_finetuning_with_resume_from_checkpoint(tmpdir): - """ - This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test - """ + """This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test.""" checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) @@ -81,9 +79,7 @@ def validation_step(self, batch, batch_idx): def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir): - """ - This test validates that accumulated gradient is properly recomputed and reset on the trainer. - """ + """This test validates that accumulated gradient is properly recomputed and reset on the trainer.""" ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = BoringModel() diff --git a/tests/conftest.py b/tests/conftest.py index a9fd1af75ce6e..9efe445ce812f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -111,10 +111,9 @@ class ThreadingHTTPServer(ThreadingMixIn, HTTPServer): @pytest.fixture def single_process_pg(): - """ - Initialize the default process group with only the current process for - testing purposes. The process group is destroyed when the with block is - exited. + """Initialize the default process group with only the current process for testing purposes. + + The process group is destroyed when the with block is exited. """ if torch.distributed.is_initialized(): raise RuntimeError("Can't use `single_process_pg` when the default process group is already initialized.") diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index c3bd24546f4f3..2f84032593472 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -431,10 +431,8 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): def test_dm_reload_dataloaders_every_n_epochs(tmpdir): - """ - Test datamodule, where trainer argument - reload_dataloaders_every_n_epochs is set to a non negative integer - """ + """Test datamodule, where trainer argument reload_dataloaders_every_n_epochs is set to a non negative + integer.""" class CustomBoringDataModule(BoringDataModule): def __init__(self): diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 4068bc5504b5b..2348089ed37e5 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -25,9 +25,7 @@ def test_lightning_optimizer(tmpdir): - """ - Test that optimizer are correctly wrapped by our LightningOptimizer - """ + """Test that optimizer are correctly wrapped by our LightningOptimizer.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -47,8 +45,9 @@ def configure_optimizers(self): def test_lightning_optimizer_from_user(tmpdir): - """ - Test that the user can use our LightningOptimizer. Not recommended. + """Test that the user can use our LightningOptimizer. + + Not recommended. """ class TestModel(BoringModel): @@ -70,8 +69,9 @@ def configure_optimizers(self): def test_lightning_optimizer_manual_optimization_and_accumulated_gradients(tmpdir): - """ - Test that the user can use our LightningOptimizer. Not recommended. + """Test that the user can use our LightningOptimizer. + + Not recommended. """ class TestModel(BoringModel): @@ -174,9 +174,7 @@ def test_state(tmpdir): def test_lightning_optimizer_automatic_optimization_optimizer_zero_grad(tmpdir): - """ - Test overriding zero_grad works in automatic_optimization - """ + """Test overriding zero_grad works in automatic_optimization.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx=None): @@ -210,9 +208,7 @@ def configure_optimizers(self): def test_lightning_optimizer_automatic_optimization_optimizer_step(tmpdir): - """ - Test overriding step works in automatic_optimization - """ + """Test overriding step works in automatic_optimization.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx=None): @@ -257,10 +253,8 @@ def configure_optimizers(self): def test_lightning_optimizer_automatic_optimization_lbfgs_zero_grad(tmpdir): - """ - Test zero_grad is called the same number of times as LBFGS requires - for reevaluation of the loss in automatic_optimization. - """ + """Test zero_grad is called the same number of times as LBFGS requires for reevaluation of the loss in + automatic_optimization.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -306,15 +300,13 @@ def __init__(self, model): super().__init__(self.params, {"lr": 0.01}) def _save_input(self, mod, i): - """Saves input of layer""" + """Saves input of layer.""" if mod.training: self.state[mod]["x"] = i[0] def _save_grad_output(self, mod, _, grad_output): - """ - Saves grad on output of layer to - grad is scaled with batch_size since gradient is spread over samples in mini batch - """ + """Saves grad on output of layer to grad is scaled with batch_size since gradient is spread over samples in + mini batch.""" batch_size = grad_output[0].shape[0] if mod.training: self.state[mod]["grad"] = grad_output[0] * batch_size diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 98e884eedbf4b..1f4b72744c03b 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -97,7 +97,7 @@ def _ddp_test_fn(rank, worldsize): @RunIf(skip_windows=True, min_gpus=2) def test_result_reduce_ddp(): - """Make sure result logging works with DDP""" + """Make sure result logging works with DDP.""" tutils.set_random_master_port() worldsize = 2 @@ -208,9 +208,7 @@ def my_sync_dist(x, *_, **__): def test_result_collection_restoration(tmpdir): - """ - This test make sure metrics are properly reloaded on failure. - """ + """This test make sure metrics are properly reloaded on failure.""" result = ResultCollection(True, torch.device("cpu")) metric_a = DummyMetric() @@ -379,10 +377,8 @@ def __repr__(self) -> str: def result_collection_reload(**kwargs): - """ - This test is going to validate ResultCollection is properly being reload - and final accumulation with Fault Tolerant Training is correct. - """ + """This test is going to validate ResultCollection is properly being reload and final accumulation with Fault + Tolerant Training is correct.""" if not _fault_tolerant_training(): pytest.skip("Fault tolerant not available") @@ -499,7 +495,7 @@ def test_result_collection_reload_2_gpus(tmpdir): def test_metric_collections(tmpdir): - """This test ensures the metric attribute is properly found even with complex nested metric structure""" + """This test ensures the metric attribute is properly found even with complex nested metric structure.""" class TestModel(BoringModel): def __init__(self): diff --git a/tests/core/test_results.py b/tests/core/test_results.py index 9d164b989f434..1033699ef398c 100644 --- a/tests/core/test_results.py +++ b/tests/core/test_results.py @@ -40,7 +40,7 @@ def _ddp_test_fn(rank, worldsize): @RunIf(skip_windows=True) def test_result_reduce_ddp(): - """Make sure result logging works with DDP""" + """Make sure result logging works with DDP.""" tutils.set_random_master_port() worldsize = 2 mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize) diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index ccfae3ec8dcf2..1026981f75307 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test deprecated functionality which will be removed in vX.Y.Z""" +"""Test deprecated functionality which will be removed in vX.Y.Z.""" import sys from contextlib import contextmanager from typing import Optional diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 695b9c8260402..a9d17601153ae 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test deprecated functionality which will be removed in v1.5.0""" +"""Test deprecated functionality which will be removed in v1.5.0.""" import pytest from pytorch_lightning import Trainer diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py index 98c5c4e0320ea..fec29ed6b47f8 100644 --- a/tests/deprecated_api/test_remove_1-6.py +++ b/tests/deprecated_api/test_remove_1-6.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Test deprecated functionality which will be removed in v1.6.0 """ +"""Test deprecated functionality which will be removed in v1.6.0.""" import pytest from pytorch_lightning import Trainer diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 1b4b6ef0701d1..488e14a498f3d 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Test deprecated functionality which will be removed in v1.7.0 """ +"""Test deprecated functionality which will be removed in v1.7.0.""" from unittest import mock import pytest diff --git a/tests/deprecated_api/test_remove_2-0.py b/tests/deprecated_api/test_remove_2-0.py index 9c372c8f1a9c6..a4a3eb2b3726d 100644 --- a/tests/deprecated_api/test_remove_2-0.py +++ b/tests/deprecated_api/test_remove_2-0.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test deprecated functionality which will be removed in v1.4.0""" +"""Test deprecated functionality which will be removed in v1.4.0.""" import pytest diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py index 5835b1aa2c7ca..4036d34663a9f 100644 --- a/tests/helpers/boring_model.py +++ b/tests/helpers/boring_model.py @@ -70,8 +70,7 @@ def __len__(self): class BoringModel(LightningModule): def __init__(self): - """ - Testing PL Module + """Testing PL Module. Use as follows: - subclass @@ -85,7 +84,6 @@ def training_step(...): model = BaseTestModel() model.training_epoch_end = None - """ super().__init__() self.layer = torch.nn.Linear(32, 2) diff --git a/tests/helpers/dataloaders.py b/tests/helpers/dataloaders.py index 81417bc12d819..14dde1c8424b2 100644 --- a/tests/helpers/dataloaders.py +++ b/tests/helpers/dataloaders.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Custom dataloaders for testing""" +"""Custom dataloaders for testing.""" class CustomInfDataloader: @@ -38,7 +38,7 @@ def __next__(self): class CustomNotImplementedErrorDataloader(CustomInfDataloader): def __len__(self): - """raise NotImplementedError""" + """raise NotImplementedError.""" raise NotImplementedError def __next__(self): diff --git a/tests/helpers/datasets.py b/tests/helpers/datasets.py index 10a4c6b6e7ca7..561642ae8cfbe 100644 --- a/tests/helpers/datasets.py +++ b/tests/helpers/datasets.py @@ -24,9 +24,8 @@ class MNIST(Dataset): - """ - Customized `MNIST `_ dataset for testing Pytorch Lightning - without the torchvision dependency. + """Customized `MNIST `_ dataset for testing Pytorch Lightning without the + torchvision dependency. Part of the code was copied from https://github.com/pytorch/vision/blob/build/v0.5.0/torchvision/datasets/mnist.py @@ -134,7 +133,7 @@ def normalize_tensor(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> Ten class TrialMNIST(MNIST): - """Constrained MNIST dataset + """Constrained MNIST dataset. Args: num_samples: number of examples per selected class/digit diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 6abc2e377e6dc..cc6463eb61f6b 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -44,13 +44,12 @@ class RunIf: - """ - RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: + """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark:: - @RunIf(min_torch="0.0") - @pytest.mark.parametrize("arg1", [1, 2.0]) - def test_wrapper(arg1): - assert arg1 > 0.0 + @RunIf(min_torch="0.0") + @pytest.mark.parametrize("arg1", [1, 2.0]) + def test_wrapper(arg1): + assert arg1 > 0.0 """ def __new__( diff --git a/tests/helpers/test_models.py b/tests/helpers/test_models.py index b6d853f2ac594..8e5f85632bbc5 100644 --- a/tests/helpers/test_models.py +++ b/tests/helpers/test_models.py @@ -34,7 +34,7 @@ ], ) def test_models(tmpdir, data_class, model_class): - """Test simple models""" + """Test simple models.""" dm = data_class() if data_class else data_class model = model_class() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index eb1d53ccc62e3..e13232eafca09 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -240,7 +240,10 @@ def name(self): ], ) def test_loggers_pickle_all(tmpdir, monkeypatch, logger_class): - """Test that the logger objects can be pickled. This test only makes sense if the packages are installed.""" + """Test that the logger objects can be pickled. + + This test only makes sense if the packages are installed. + """ _patch_comet_atexit(monkeypatch) try: _test_loggers_pickle(tmpdir, monkeypatch, logger_class) @@ -281,7 +284,7 @@ def _test_loggers_pickle(tmpdir, monkeypatch, logger_class): ], ) def test_logger_reset_correctly(tmpdir, extra_params): - """Test that the tuners do not alter the logger reference""" + """Test that the tuners do not alter the logger reference.""" class CustomModel(BoringModel): def __init__(self, lr=0.1, batch_size=1): @@ -318,7 +321,7 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_ ) @RunIf(skip_windows=True) def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class): - """Test that loggers get replaced by dummy loggers on global rank > 0""" + """Test that loggers get replaced by dummy loggers on global rank > 0.""" _patch_comet_atexit(monkeypatch) try: _test_logger_created_on_rank_zero_only(tmpdir, logger_class) @@ -344,9 +347,7 @@ def _test_logger_created_on_rank_zero_only(tmpdir, logger_class): def test_logger_with_prefix_all(tmpdir, monkeypatch): - """ - Test that prefix is added at the beginning of the metric keys. - """ + """Test that prefix is added at the beginning of the metric keys.""" prefix = "tmp" # Comet diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 7c8673c34956f..7a1bf9a56f70a 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -79,10 +79,8 @@ def finalize(self, status): @property def save_dir(self) -> Optional[str]: - """ - Return the root directory where experiment logs get saved, or `None` if the logger does not - save data locally. - """ + """Return the root directory where experiment logs get saved, or `None` if the logger does not save data + locally.""" return None @property diff --git a/tests/loggers/test_csv.py b/tests/loggers/test_csv.py index f28377471dc3e..2640ede1bf39f 100644 --- a/tests/loggers/test_csv.py +++ b/tests/loggers/test_csv.py @@ -23,7 +23,7 @@ def test_file_logger_automatic_versioning(tmpdir): - """Verify that automatic versioning works""" + """Verify that automatic versioning works.""" root_dir = tmpdir.mkdir("exp") root_dir.mkdir("version_0") @@ -35,7 +35,7 @@ def test_file_logger_automatic_versioning(tmpdir): def test_file_logger_manual_versioning(tmpdir): - """Verify that manual versioning works""" + """Verify that manual versioning works.""" root_dir = tmpdir.mkdir("exp") root_dir.mkdir("version_0") @@ -48,7 +48,7 @@ def test_file_logger_manual_versioning(tmpdir): def test_file_logger_named_version(tmpdir): - """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'""" + """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'.""" exp_name = "exp" tmpdir.mkdir(exp_name) @@ -64,7 +64,7 @@ def test_file_logger_named_version(tmpdir): @pytest.mark.parametrize("name", ["", None]) def test_file_logger_no_name(tmpdir, name): - """Verify that None or empty name works""" + """Verify that None or empty name works.""" logger = CSVLogger(save_dir=tmpdir, name=name) logger.save() assert logger.root_dir == tmpdir diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index 8523edf69a980..7bf0ae67adb23 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -184,9 +184,7 @@ def training_epoch_end(self, *args, **kwargs): @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): - """ - Test that the logger experiment_id retrieved only once. - """ + """Test that the logger experiment_id retrieved only once.""" logger = MLFlowLogger("test", save_dir=tmpdir) _ = logger.experiment _ = logger.experiment @@ -197,9 +195,7 @@ def test_mlflow_experiment_id_retrieved_once(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir): - """ - Test that the logger raises warning with special characters not accepted by MLFlow. - """ + """Test that the logger raises warning with special characters not accepted by MLFlow.""" logger = MLFlowLogger("test", save_dir=tmpdir) metrics = {"[some_metric]": 10} @@ -210,9 +206,7 @@ def test_mlflow_logger_with_unexpected_characters(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): - """ - Test that the logger raises warning with special characters not accepted by MLFlow. - """ + """Test that the logger raises warning with special characters not accepted by MLFlow.""" logger = MLFlowLogger("test", save_dir=tmpdir) value = "test" * 100 key = "test_param" @@ -226,9 +220,7 @@ def test_mlflow_logger_with_long_param_value(client, mlflow, tmpdir): @mock.patch("pytorch_lightning.loggers.mlflow.mlflow") @mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient") def test_mlflow_logger_experiment_calls(client, mlflow, time, tmpdir): - """ - Test that the logger calls methods on the mlflow experiment correctly. - """ + """Test that the logger calls methods on the mlflow experiment correctly.""" time.return_value = 1 logger = MLFlowLogger("test", save_dir=tmpdir, artifact_location="my_artifact_location") diff --git a/tests/loggers/test_neptune.py b/tests/loggers/test_neptune.py index c58bb4ef59fbe..84cece3ae2101 100644 --- a/tests/loggers/test_neptune.py +++ b/tests/loggers/test_neptune.py @@ -103,7 +103,7 @@ def test_neptune_additional_methods(neptune): @patch("pytorch_lightning.loggers.neptune.neptune") def test_neptune_leave_open_experiment_after_fit(neptune, tmpdir): - """Verify that neptune experiment was closed after training""" + """Verify that neptune experiment was closed after training.""" model = BoringModel() def _run_training(logger): diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index a1c66c0559d75..027a29d94fc80 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -71,7 +71,7 @@ def __init__(self, b1=0.5, b2=0.999): def test_tensorboard_automatic_versioning(tmpdir): - """Verify that automatic versioning works""" + """Verify that automatic versioning works.""" root_dir = tmpdir / "tb_versioning" root_dir.mkdir() @@ -83,7 +83,7 @@ def test_tensorboard_automatic_versioning(tmpdir): def test_tensorboard_manual_versioning(tmpdir): - """Verify that manual versioning works""" + """Verify that manual versioning works.""" root_dir = tmpdir / "tb_versioning" root_dir.mkdir() @@ -97,7 +97,7 @@ def test_tensorboard_manual_versioning(tmpdir): def test_tensorboard_named_version(tmpdir): - """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'""" + """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'.""" name = "tb_versioning" (tmpdir / name).mkdir() @@ -113,7 +113,7 @@ def test_tensorboard_named_version(tmpdir): @pytest.mark.parametrize("name", ["", None]) def test_tensorboard_no_name(tmpdir, name): - """Verify that None or empty name works""" + """Verify that None or empty name works.""" logger = TensorBoardLogger(save_dir=tmpdir, name=name) logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written assert logger.root_dir == tmpdir @@ -223,9 +223,7 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir): @pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)]) def test_tensorboard_log_graph(tmpdir, example_input_array): - """test that log graph works with both model.example_input_array and - if array is passed externaly - """ + """test that log graph works with both model.example_input_array and if array is passed externaly.""" model = BoringModel() if example_input_array is not None: model.example_input_array = None @@ -235,7 +233,7 @@ def test_tensorboard_log_graph(tmpdir, example_input_array): def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): - """test that log graph throws warning if model.example_input_array is None""" + """test that log graph throws warning if model.example_input_array is None.""" model = BoringModel() model.example_input_array = None logger = TensorBoardLogger(tmpdir, log_graph=True) @@ -249,7 +247,7 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir): - """Tests to ensure that tensorboard log properly when accumulated_gradients > 1""" + """Tests to ensure that tensorboard log properly when accumulated_gradients > 1.""" class TestModel(BoringModel): def __init__(self): @@ -308,10 +306,8 @@ def test_tensorboard_save_hparams_to_yaml_once(tmpdir): @mock.patch("pytorch_lightning.loggers.tensorboard.log") def test_tensorboard_with_symlink(log, tmpdir): - """ - Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and - relative paths. - """ + """Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, + and relative paths.""" os.chdir(tmpdir) # need to use relative paths source = os.path.join(".", "lightning_logs") dest = os.path.join(".", "sym_lightning_logs") @@ -326,7 +322,7 @@ def test_tensorboard_with_symlink(log, tmpdir): def test_tensorboard_missing_folder_warning(tmpdir, caplog): - """Verify that the logger throws a warning for invalid directory""" + """Verify that the logger throws a warning for invalid directory.""" name = "fake_dir" logger = TensorBoardLogger(save_dir=tmpdir, name=name) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 0684727e84ac8..8388d7877ab7e 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -27,7 +27,9 @@ @mock.patch("pytorch_lightning.loggers.wandb.wandb") def test_wandb_logger_init(wandb): """Verify that basic functionality of wandb logger works. - Wandb doesn't work well with pytest so we have to mock it out here.""" + + Wandb doesn't work well with pytest so we have to mock it out here. + """ # test wandb.init called when there is no W&B run wandb.run = None @@ -83,8 +85,8 @@ def test_wandb_logger_init(wandb): @mock.patch("pytorch_lightning.loggers.wandb.wandb") def test_wandb_pickle(wandb, tmpdir): - """ - Verify that pickling trainer with wandb logger works. + """Verify that pickling trainer with wandb logger works. + Wandb doesn't work well with pytest so we have to mock it out here. """ @@ -215,9 +217,9 @@ def test_wandb_log_model(wandb, tmpdir): def test_wandb_sanitize_callable_params(tmpdir): - """ - Callback function are not serializiable. Therefore, we get them a chance to return - something and if the returned type is not accepted, return None. + """Callback function are not serializiable. + + Therefore, we get them a chance to return something and if the returned type is not accepted, return None. """ opt = "--max_epochs 1".split(" ") parser = ArgumentParser() @@ -246,6 +248,6 @@ def wrapper_something(): @mock.patch("pytorch_lightning.loggers.wandb.wandb") def test_wandb_logger_offline_log_model(wandb, tmpdir): - """Test that log_model=True raises an error in offline mode""" + """Test that log_model=True raises an error in offline mode.""" with pytest.raises(MisconfigurationException, match="checkpoints cannot be uploaded in offline mode"): _ = WandbLogger(save_dir=str(tmpdir), offline=True, log_model=True) diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py index f851a7de0837e..71acb9a168081 100644 --- a/tests/models/data/horovod/train_default_model.py +++ b/tests/models/data/horovod/train_default_model.py @@ -1,5 +1,4 @@ -""" -This script is meant to be executed from `../../test_horovod.py`. +"""This script is meant to be executed from `../../test_horovod.py`. Because Horovod uses a parallel programming model similar to MPI, unit tests for collective ops like allreduce need to be run in parallel. The most common approach for running parallel diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 015c79458e1aa..c9d05ae5f3f42 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -144,7 +144,10 @@ def test_multi_cpu_model_ddp(tmpdir): def test_lbfgs_cpu_model(tmpdir): - """Test each of the trainer options. Testing LBFGS optimizer""" + """Test each of the trainer options. + + Testing LBFGS optimizer + """ class ModelSpecifiedOptimizer(BoringModel): def __init__(self, optimizer_name, learning_rate): @@ -230,8 +233,10 @@ def test_step(self, *args, **kwargs): def test_running_test_no_val(tmpdir): - """Verify `test()` works on a model with no `val_dataloader`. It performs - train and test only""" + """Verify `test()` works on a model with no `val_dataloader`. + + It performs train and test only + """ class ModelTrainTest(BoringModel): def val_dataloader(self): diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 1d23ed2f76907..1faf4b820eae5 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -259,10 +259,8 @@ def test_parse_gpu_returns_none_when_no_devices_are_available(mocked_device_coun @mock.patch("torch.cuda.device_count", return_value=1) @pytest.mark.parametrize("gpus", [[0, 1, 2], 2, "0"]) def test_torchelastic_gpu_parsing(mocked_device_count, gpus): - """ - Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device - That we omit sanitizing the gpus as only one of the GPUs is visible. - """ + """Ensure when using torchelastic and nproc_per_node is set to the default of 1 per GPU device That we omit + sanitizing the gpus as only one of the GPUs is visible.""" trainer = Trainer(gpus=gpus) assert isinstance(trainer.accelerator_connector.cluster_environment, TorchElasticEnvironment) assert trainer.accelerator_connector.parallel_device_ids == device_parser.parse_gpu_ids(gpus) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 990178b09d07f..4a787a833dd1c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -169,9 +169,7 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx): @RunIf(min_gpus=2, special=True) def test_transfer_batch_hook_ddp(tmpdir): - """ - Test custom data are properly moved to the right device using ddp - """ + """Test custom data are properly moved to the right device using ddp.""" class CustomBatch: def __init__(self, data): @@ -767,9 +765,7 @@ def test_trainer_model_hook_system_predict(tmpdir): def test_hooks_with_different_argument_names(tmpdir): - """ - Test that argument names can be anything in the hooks - """ + """Test that argument names can be anything in the hooks.""" class CustomBoringModel(BoringModel): def assert_args(self, x, batch_nb): diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7e37fab22cd3f..3172f8504ade0 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -39,7 +39,7 @@ class SaveHparamsModel(BoringModel): - """Tests that a model can take an object""" + """Tests that a model can take an object.""" def __init__(self, hparams): super().__init__() @@ -55,7 +55,7 @@ def wrapper(*args, **kwargs): class SaveHparamsDecoratedModel(BoringModel): - """Tests that a model can take an object""" + """Tests that a model can take an object.""" @decorate @decorate @@ -68,9 +68,7 @@ def __init__(self, hparams, *my_args, **my_kwargs): # STANDARD TESTS # ------------------------- def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): - """ - Tests for the existence of an arg 'test_arg=14' - """ + """Tests for the existence of an arg 'test_arg=14'.""" hparam_type = type(model.hparams) # test proper property assignments assert model.hparams.test_arg == 14 @@ -135,9 +133,7 @@ def test_omega_conf_hparams(tmpdir, cls): def test_explicit_args_hparams(tmpdir): - """ - Tests that a model can take implicit args and assign - """ + """Tests that a model can take implicit args and assign.""" # define model class LocalModel(BoringModel): @@ -156,9 +152,7 @@ def __init__(self, test_arg, test_arg2): def test_implicit_args_hparams(tmpdir): - """ - Tests that a model can take regular args and assign - """ + """Tests that a model can take regular args and assign.""" # define model class LocalModel(BoringModel): @@ -177,9 +171,7 @@ def __init__(self, test_arg, test_arg2): def test_explicit_missing_args_hparams(tmpdir): - """ - Tests that a model can take regular args and assign - """ + """Tests that a model can take regular args and assign.""" # define model class LocalModel(BoringModel): @@ -292,7 +284,7 @@ def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")) ], ) def test_collect_init_arguments(tmpdir, cls): - """Test that the model automatically saves the arguments passed into the constructor""" + """Test that the model automatically saves the arguments passed into the constructor.""" extra_args = {} if cls is AggSubClassBoringModel: extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss()) @@ -542,7 +534,7 @@ def __init__(self, hparams): class SubClassVarArgs(SuperClassPositionalArgs): - """Loading this model should accept hparams and init in the super class""" + """Loading this model should accept hparams and init in the super class.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -568,7 +560,7 @@ def __init__(self, **kwargs): @pytest.mark.parametrize("cls", [RuntimeParamChangeModelSaving]) def test_init_arg_with_runtime_change(tmpdir, cls): - """Test that we save/export only the initial hparams, no other runtime change allowed""" + """Test that we save/export only the initial hparams, no other runtime change allowed.""" model = cls(running_arg=123) assert model.hparams.running_arg == 123 model.hparams.running_arg = -1 @@ -601,9 +593,7 @@ def test_model_with_fsspec_as_parameter(tmpdir): @pytest.mark.skipif(not _HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available") def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): - """ - This test relies on configuration saved under tests/models/conf/config.yaml - """ + """This test relies on configuration saved under tests/models/conf/config.yaml.""" class TestHydraModel(BoringModel): def __init__(self, args_0, args_1, args_2, kwarg_1=None): @@ -636,9 +626,7 @@ def __init__(self, args_0, args_1, args_2, kwarg_1=None): @pytest.mark.parametrize("ignore", ("arg2", ("arg2", "arg3"))) def test_ignore_args_list_hparams(tmpdir, ignore): - """ - Tests that args can be ignored in save_hyperparameters - """ + """Tests that args can be ignored in save_hyperparameters.""" class LocalModel(BoringModel): def __init__(self, arg1, arg2, arg3): diff --git a/tests/models/test_onnx.py b/tests/models/test_onnx.py index 7cd1d2776f43c..59af0ffa831d7 100644 --- a/tests/models/test_onnx.py +++ b/tests/models/test_onnx.py @@ -27,7 +27,7 @@ def test_model_saves_with_input_sample(tmpdir): - """Test that ONNX model saves with input sample and size is greater than 3 MB""" + """Test that ONNX model saves with input sample and size is greater than 3 MB.""" model = BoringModel() trainer = Trainer(fast_dev_run=True) trainer.fit(model) @@ -41,7 +41,7 @@ def test_model_saves_with_input_sample(tmpdir): @RunIf(min_gpus=1) def test_model_saves_on_gpu(tmpdir): - """Test that model saves on gpu""" + """Test that model saves on gpu.""" model = BoringModel() trainer = Trainer(gpus=1, fast_dev_run=True) trainer.fit(model) @@ -54,7 +54,7 @@ def test_model_saves_on_gpu(tmpdir): def test_model_saves_with_example_output(tmpdir): - """Test that ONNX model saves when provided with example output""" + """Test that ONNX model saves when provided with example output.""" model = BoringModel() trainer = Trainer(fast_dev_run=True) trainer.fit(model) @@ -75,7 +75,7 @@ def test_model_saves_with_example_output(tmpdir): ], ) def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample): - """Test that ONNX model saves with example_input_array and size is greater than 3 MB""" + """Test that ONNX model saves with example_input_array and size is greater than 3 MB.""" model = modelclass() model.example_input_array = input_sample @@ -87,7 +87,7 @@ def test_model_saves_with_example_input_array(tmpdir, modelclass, input_sample): @RunIf(min_gpus=2) def test_model_saves_on_multi_gpu(tmpdir): - """Test that ONNX model saves on a distributed backend""" + """Test that ONNX model saves on a distributed backend.""" tutils.set_random_master_port() trainer_options = dict( @@ -111,7 +111,7 @@ def test_model_saves_on_multi_gpu(tmpdir): def test_verbose_param(tmpdir, capsys): - """Test that output is present when verbose parameter is set""" + """Test that output is present when verbose parameter is set.""" model = BoringModel() model.example_input_array = torch.randn(5, 32) @@ -122,7 +122,7 @@ def test_verbose_param(tmpdir, capsys): def test_error_if_no_input(tmpdir): - """Test that an error is thrown when there is no input tensor""" + """Test that an error is thrown when there is no input tensor.""" model = BoringModel() model.example_input_array = None file_path = os.path.join(tmpdir, "model.onnx") @@ -135,7 +135,7 @@ def test_error_if_no_input(tmpdir): def test_if_inference_output_is_valid(tmpdir): - """Test that the output inferred from ONNX model is same as from PyTorch""" + """Test that the output inferred from ONNX model is same as from PyTorch.""" model = BoringModel() model.example_input_array = torch.randn(5, 32) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d1d870fa30116..c9a784ed0a0f5 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -105,10 +105,7 @@ def validation_step_end(self, outputs): def test_model_properties_resume_from_checkpoint(tmpdir): - """ - Test that properties like `current_epoch` and `global_step` - in model and trainer are always the same. - """ + """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same.""" model = BoringModel() checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) trainer_args = dict( diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index db7e3874f6ad3..c74fc12f7eb73 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -46,7 +46,7 @@ def test_torchscript_input_output(modelclass): @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN]) def test_torchscript_example_input_output_trace(modelclass): - """Test that traced LightningModule forward works with example_input_array""" + """Test that traced LightningModule forward works with example_input_array.""" model = modelclass() if isinstance(model, BoringModel): @@ -64,7 +64,7 @@ def test_torchscript_example_input_output_trace(modelclass): def test_torchscript_input_output_trace(): - """Test that traced LightningModule forward works with example_inputs""" + """Test that traced LightningModule forward works with example_inputs.""" model = BoringModel() example_inputs = torch.randn(1, 32) script = model.to_torchscript(example_inputs=example_inputs, method="trace") @@ -148,7 +148,7 @@ class DummyFileSystem(LocalFileSystem): def test_torchcript_invalid_method(tmpdir): - """Test that an error is thrown with invalid torchscript method""" + """Test that an error is thrown with invalid torchscript method.""" model = BoringModel() model.train(True) @@ -157,7 +157,7 @@ def test_torchcript_invalid_method(tmpdir): def test_torchscript_with_no_input(tmpdir): - """Test that an error is thrown when there is no input tensor""" + """Test that an error is thrown when there is no input tensor.""" model = BoringModel() model.example_input_array = None diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 5aa605cdf38bb..950c3577b89b9 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -177,7 +177,7 @@ def test_model_16bit_tpu_cores_8(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_model_tpu_early_stop(tmpdir): - """Test if single TPU core training works""" + """Test if single TPU core training works.""" class CustomBoringModel(BoringModel): def validation_step(self, *args, **kwargs): @@ -222,7 +222,7 @@ def test_tpu_grad_norm(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_tpu_clip_grad_by_value(tmpdir): - """Test if clip_gradients by value works on TPU""" + """Test if clip_gradients by value works on TPU.""" tutils.reset_seed() trainer_options = dict( default_root_dir=tmpdir, @@ -242,7 +242,7 @@ def test_tpu_clip_grad_by_value(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_dataloaders_passed_to_fit(tmpdir): - """Test if dataloaders passed to trainer works on TPU""" + """Test if dataloaders passed to trainer works on TPU.""" tutils.reset_seed() model = BoringModel() @@ -257,19 +257,19 @@ def test_dataloaders_passed_to_fit(tmpdir): ) @RunIf(tpu=True) def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id): - """Test if trainer.tpu_id is set as expected""" + """Test if trainer.tpu_id is set as expected.""" assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id def test_tpu_misconfiguration(): - """Test if trainer.tpu_id is set as expected""" + """Test if trainer.tpu_id is set as expected.""" with pytest.raises(MisconfigurationException, match="`tpu_cores` can only be"): Trainer(tpu_cores=[1, 8]) @pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU") def test_exception_when_no_tpu_found(tmpdir): - """Test if exception is thrown when xla devices are not available""" + """Test if exception is thrown when xla devices are not available.""" with pytest.raises(MisconfigurationException, match="No TPU devices were found."): Trainer(tpu_cores=8) @@ -278,14 +278,14 @@ def test_exception_when_no_tpu_found(tmpdir): @pytest.mark.parametrize("tpu_cores", [1, 8, [1]]) @RunIf(tpu=True) def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores): - """Test if distributed_backend is set to `tpu` when tpu_cores is not None""" + """Test if distributed_backend is set to `tpu` when tpu_cores is not None.""" assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu" @RunIf(tpu=True) @pl_multi_process_test def test_broadcast_on_tpu(): - """Checks if an object from the master process is broadcasted to other processes correctly""" + """Checks if an object from the master process is broadcasted to other processes correctly.""" def test_broadcast(rank): trainer = Trainer(tpu_cores=8) @@ -332,7 +332,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected): @RunIf(tpu=True) @pl_multi_process_test def test_tpu_cores_with_argparse(cli_args, expected): - """Test passing tpu_cores in command line""" + """Test passing tpu_cores in command line.""" cli_args = cli_args.split(" ") if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): parser = ArgumentParser(add_help=False) @@ -347,7 +347,7 @@ def test_tpu_cores_with_argparse(cli_args, expected): @RunIf(tpu=True) @pl_multi_process_test def test_tpu_reduce(): - """Test tpu spawn reduce operation""" + """Test tpu spawn reduce operation.""" def test_reduce(rank): trainer = Trainer(tpu_cores=8) @@ -372,8 +372,8 @@ def test_reduce(rank): @pytest.mark.parametrize("clip_val", [10]) @mock.patch("torch.nn.utils.clip_grad_norm_") def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir): - """ - Ensure that clip gradients is only called if the value is greater than 0. + """Ensure that clip gradients is only called if the value is greater than 0. + TODO: Fix (test fails with parametrize) """ tutils.reset_seed() @@ -411,7 +411,7 @@ def test_if_test_works_with_checkpoint_false(tmpdir): @RunIf(tpu=True) @pl_multi_process_test def test_tpu_sync_dist(): - """Test tpu spawn sync dist operation""" + """Test tpu spawn sync dist operation.""" def test_sync_dist(_): sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM) diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 47852bba47c97..f567893c1576f 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -106,9 +106,7 @@ def test_amp_gradient_unscale(tmpdir, accum: int): @RunIf(min_gpus=1, amp_native=True) def test_amp_skip_optimizer(tmpdir): - """ - Test that optimizers can be skipped when using amp - """ + """Test that optimizers can be skipped when using amp.""" class CustomBoringModel(BoringModel): def __init__(self): @@ -208,9 +206,7 @@ def test_cpu_amp_precision_throws_error(tmpdir): amp_native=True, ) def test_cpu_amp_precision_context_manager(tmpdir): - """ - Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set. - """ + """Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set.""" plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True) assert plugin.use_cpu @@ -226,9 +222,7 @@ def test_cpu_amp_precision_context_manager(tmpdir): amp_native=True, ) def test_cpu_amp_precision_16_throws_error(tmpdir): - """ - Throw error when using 16 as Native CPU AMP only supports bfloat16. - """ + """Throw error when using 16 as Native CPU AMP only supports bfloat16.""" with pytest.raises( MisconfigurationException, diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py index 810127a03f361..2b0195d584de7 100644 --- a/tests/plugins/test_checkpoint_io_plugin.py +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -35,9 +35,7 @@ def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> def test_checkpoint_plugin_called(tmpdir): - """ - Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading. - """ + """Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading.""" checkpoint_plugin = CustomCheckpointIO() checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointIO) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 3dd4897864947..c9c29d31c42ae 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -18,9 +18,7 @@ def test_invalid_on_cpu(tmpdir): - """ - Test to ensure that to raise Misconfiguration for FSDP on CPU. - """ + """Test to ensure that to raise Misconfiguration for FSDP on CPU.""" with pytest.raises( MisconfigurationException, match="You selected accelerator to be `ddp_fully_sharded`, but GPU is not available." ): @@ -34,9 +32,7 @@ def test_invalid_on_cpu(tmpdir): @mock.patch("torch.cuda.is_available", return_value=True) @RunIf(amp_apex=True, fairscale_fully_sharded=True) def test_invalid_apex_sharded(device_count_mock, mock_cuda_available, tmpdir): - """ - Test to ensure that we raise an error when we try to use apex and fully sharded - """ + """Test to ensure that we raise an error when we try to use apex and fully sharded.""" with pytest.raises(MisconfigurationException, match="Sharded Plugin is not supported with Apex AMP"): Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp", gpus=1, precision=16, amp_backend="apex") @@ -46,9 +42,7 @@ def test_invalid_apex_sharded(device_count_mock, mock_cuda_available, tmpdir): @mock.patch("torch.cuda.is_available", return_value=True) @RunIf(amp_native=True, fairscale_fully_sharded=True) def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): - """ - Test to ensure that plugin native amp plugin is correctly chosen when using sharded - """ + """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp", gpus=1, precision=16) assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) @@ -105,9 +99,7 @@ def _assert_layer_fsdp_instance(self) -> None: @RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, amp_native=True, special=True) def test_fully_sharded_plugin_checkpoint(tmpdir): - """ - Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run. - """ + """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run.""" model = TestFSDPModel() trainer = Trainer(default_root_dir=tmpdir, gpus=1, plugins="fsdp", precision=16, max_epochs=1) @@ -116,9 +108,7 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): @RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, amp_native=True, special=True) def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir): - """ - Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. - """ + """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run.""" model = TestFSDPModel() ck = ModelCheckpoint(save_last=True) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 7400569bd3f99..9b4a1f8a4ba99 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -67,9 +67,7 @@ def automatic_optimization(self) -> bool: def test_deepspeed_lightning_module(tmpdir): - """ - Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly. - """ + """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.""" model = BoringModel() module = LightningDeepSpeedModule(model, precision=16) @@ -85,9 +83,8 @@ def test_deepspeed_lightning_module(tmpdir): @RunIf(min_gpus=1) def test_deepspeed_lightning_module_precision(tmpdir): - """ - Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16. - """ + """Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision + 16.""" model = BoringModel() module = LightningDeepSpeedModule(model, precision=16) @@ -125,9 +122,8 @@ def deepspeed_zero_config(deepspeed_config): @RunIf(deepspeed=True) @pytest.mark.parametrize("input", ("deepspeed", DeepSpeedPlugin)) def test_deepspeed_plugin_string(tmpdir, input): - """ - Test to ensure that the plugin can be passed via string or instance, and parallel devices is correctly set. - """ + """Test to ensure that the plugin can be passed via string or instance, and parallel devices is correctly + set.""" trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=input if isinstance(input, str) else input()) @@ -137,9 +133,7 @@ def test_deepspeed_plugin_string(tmpdir, input): @RunIf(deepspeed=True) def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): - """ - Test to ensure that the plugin can be passed via a string with an environment variable. - """ + """Test to ensure that the plugin can be passed via a string with an environment variable.""" config_path = os.path.join(tmpdir, "temp.json") with open(config_path, "w") as f: f.write(json.dumps(deepspeed_config)) @@ -160,8 +154,8 @@ def test_deepspeed_plugin_env(tmpdir, monkeypatch, deepspeed_config): [pytest.param("native", marks=RunIf(amp_native=True)), pytest.param("apex", marks=RunIf(amp_apex=True))], ) def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): - """ - Test to ensure precision plugin is also correctly chosen. + """Test to ensure precision plugin is also correctly chosen. + DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin """ @@ -176,9 +170,7 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): @RunIf(deepspeed=True) def test_deepspeed_with_invalid_config_path(tmpdir): - """ - Test to ensure if we pass an invalid config path we throw an exception. - """ + """Test to ensure if we pass an invalid config path we throw an exception.""" with pytest.raises( MisconfigurationException, match="You passed in a path to a DeepSpeed config but the path does not exist" @@ -188,9 +180,7 @@ def test_deepspeed_with_invalid_config_path(tmpdir): @RunIf(deepspeed=True) def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): - """ - Test to ensure if we pass an env variable, we load the config from the path. - """ + """Test to ensure if we pass an env variable, we load the config from the path.""" config_path = os.path.join(tmpdir, "temp.json") with open(config_path, "w") as f: f.write(json.dumps(deepspeed_config)) @@ -201,9 +191,7 @@ def test_deepspeed_with_env_path(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) def test_deepspeed_defaults(tmpdir): - """ - Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed. - """ + """Ensure that defaults are correctly set as a config for DeepSpeed if no arguments are passed.""" plugin = DeepSpeedPlugin() assert plugin.config is not None assert isinstance(plugin.config["zero_optimization"], dict) @@ -263,10 +251,8 @@ def on_train_start(self, trainer, pl_module) -> None: @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_run_configure_optimizers(tmpdir): - """ - Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), - whilst using configure_optimizers for optimizers and schedulers. - """ + """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using + configure_optimizers for optimizers and schedulers.""" class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -304,10 +290,8 @@ def configure_optimizers(self): @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_config(tmpdir, deepspeed_zero_config): - """ - Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers - and saves the model weights to load correctly. - """ + """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including + optimizers/schedulers and saves the model weights to load correctly.""" class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -336,7 +320,8 @@ def on_train_start(self, trainer, pl_module) -> None: @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_custom_precision_params(tmpdir): - """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these changes.""" + """Ensure if we modify the FP16 parameters via the DeepSpeedPlugin, the deepspeed config contains these + changes.""" class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -397,9 +382,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module) -> None: @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu(tmpdir): - """ - Test to ensure that DeepSpeed with multiple GPUs works. - """ + """Test to ensure that DeepSpeed with multiple GPUs works.""" model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 @@ -419,9 +402,7 @@ def test_deepspeed_fp32_works(tmpdir): @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_stage_3_save_warning(tmpdir): - """ - Test to ensure that DeepSpeed Stage 3 gives a warning when saving. - """ + """Test to ensure that DeepSpeed Stage 3 gives a warning when saving.""" model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 @@ -434,9 +415,7 @@ def test_deepspeed_stage_3_save_warning(tmpdir): @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_single_file(tmpdir): - """ - Test to ensure that DeepSpeed loads from a single file checkpoint. - """ + """Test to ensure that DeepSpeed loads from a single file checkpoint.""" model = BoringModel() checkpoint_path = os.path.join(tmpdir, "model.pt") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -545,9 +524,7 @@ def training_step(self, batch, batch_idx): @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): - """ - Test to ensure ZeRO Stage 3 works with a parallel model. - """ + """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModel() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 @@ -560,9 +537,7 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config): - """ - Test to ensure ZeRO Stage 3 works with a parallel model. - """ + """Test to ensure ZeRO Stage 3 works with a parallel model.""" model = ModelParallelBoringModelManualOptim() model.training_epoch_end = None trainer = Trainer( @@ -611,19 +586,15 @@ def run_checkpoint_test(tmpdir: str, automatic_optimization: bool = True, accumu @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, - and see convergence. - """ + """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, and + see convergence.""" run_checkpoint_test(tmpdir) @RunIf(min_gpus=1, deepspeed=True, special=False) def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning - that the optimizer state and scheduler states cannot be restored. - """ + """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the + optimizer state and scheduler states cannot be restored.""" dm = ClassifDataModule() model = BoringModel() checkpoint_path = os.path.join(tmpdir, "model.pt") @@ -650,9 +621,7 @@ def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we can resume training. - """ + """Test to ensure with Stage 3 and multiple GPUs that we can resume training.""" initial_model = ModelParallelClassificationModel() dm = ClassifDataModule() @@ -707,10 +676,8 @@ def on_train_batch_start( @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): - """ - Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, - where we save the full weights to one file. - """ + """Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, + where we save the full weights to one file.""" run_checkpoint_test(tmpdir, automatic_optimization=False, accumulate_grad_batches=1) @@ -725,9 +692,7 @@ def test_deepspeed_multigpu_stage_2_accumulated_grad_batches_offload_optimizer(t def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer): - """ - Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works. - """ + """Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works.""" seed_everything(42) class VerificationCallback(Callback): @@ -761,9 +726,7 @@ def on_train_batch_start( @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_test(tmpdir, deepspeed_config): - """ - Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3. - """ + """Test to ensure we can use DeepSpeed with just test using ZeRO Stage 3.""" model = ModelParallelBoringModel() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 @@ -775,8 +738,8 @@ def test_deepspeed_multigpu_test(tmpdir, deepspeed_config): @mock.patch("deepspeed.init_distributed", autospec=True) @pytest.mark.parametrize("platform", ["Linux", "Windows"]) def test_deepspeed_plugin_env_variables(mock_deepspeed_distributed, tmpdir, platform): - """ - Test to ensure that we setup distributed communication using correctly. + """Test to ensure that we setup distributed communication using correctly. + When using windows, ranks environment variables should not be set, and deepspeed should handle this. """ trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)]) @@ -820,9 +783,7 @@ def _assert_save_model_is_equal(model, tmpdir, trainer): @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_no_schedulers(tmpdir): - """ - Test to ensure ZeRO Stage 3 works with a parallel model and no schedulers. - """ + """Test to ensure ZeRO Stage 3 works with a parallel model and no schedulers.""" model = ModelParallelBoringModelNoSchedulers() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 3bcdc357fc351..ac2fccbb449c1 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -21,9 +21,7 @@ @RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True) @mock.patch("fairscale.optim.oss.OSS.clip_grad_norm") def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir): - """ - Ensure that clip gradients is only called if the value is greater than 0. - """ + """Ensure that clip gradients is only called if the value is greater than 0.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded", gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val) trainer.fit(model) @@ -36,9 +34,7 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v @RunIf(fairscale=True) @pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)]) def test_sharded_ddp_choice(tmpdir, accelerator): - """ - Test to ensure that plugin is correctly chosen - """ + """Test to ensure that plugin is correctly chosen.""" class CB(Callback): def on_fit_start(self, trainer, pl_module): @@ -57,9 +53,7 @@ def on_fit_start(self, trainer, pl_module): @RunIf(amp_apex=True, fairscale=True) def test_invalid_apex_sharded(tmpdir): - """ - Test to ensure that we raise an error when we try to use apex and sharded - """ + """Test to ensure that we raise an error when we try to use apex and sharded.""" model = BoringModel() with pytest.raises(MisconfigurationException, match="Sharded Plugin is not supported with Apex AMP"): @@ -71,9 +65,7 @@ def test_invalid_apex_sharded(tmpdir): @RunIf(min_gpus=2, amp_native=True, fairscale=True) @pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)]) def test_ddp_choice_sharded_amp(tmpdir, accelerator): - """ - Test to ensure that plugin native amp plugin is correctly chosen when using sharded - """ + """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" class CB(Callback): def on_fit_start(self, trainer, pl_module): @@ -92,9 +84,7 @@ def on_fit_start(self, trainer, pl_module): @RunIf(skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): - """ - Test to ensure that checkpoint is saved correctly - """ + """Test to ensure that checkpoint is saved correctly.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) @@ -111,9 +101,7 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir): @RunIf(min_gpus=2, skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): - """ - Test to ensure that checkpoint is saved correctly when using multiple GPUs - """ + """Test to ensure that checkpoint is saved correctly when using multiple GPUs.""" model = BoringModel() trainer = Trainer(gpus=2, accelerator="ddp_sharded_spawn", fast_dev_run=True) @@ -130,9 +118,7 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir): @RunIf(min_gpus=2, skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_finetune(tmpdir): - """ - Test to ensure that we can save and restart training (simulate fine-tuning) - """ + """Test to ensure that we can save and restart training (simulate fine-tuning)""" model = BoringModel() trainer = Trainer(gpus=2, accelerator="ddp_sharded_spawn", fast_dev_run=True) trainer.fit(model) @@ -147,9 +133,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir): @RunIf(skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): - """ - Test to ensure that resuming from checkpoint works - """ + """Test to ensure that resuming from checkpoint works.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", num_processes=2, fast_dev_run=True) @@ -171,9 +155,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir): @pytest.mark.skip(reason="Currently unsupported restarting training on different number of devices.") @RunIf(min_gpus=2, skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): - """ - Test to ensure that resuming from checkpoint works when downsizing number of GPUS - """ + """Test to ensure that resuming from checkpoint works when downsizing number of GPUS.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, gpus=2) @@ -193,9 +175,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir): @RunIf(min_gpus=1, skip_windows=True, fairscale=True) def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): - """ - Test to ensure that resuming from checkpoint works when going from GPUs- > CPU - """ + """Test to ensure that resuming from checkpoint works when going from GPUs- > CPU.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", gpus=1, fast_dev_run=True) @@ -216,9 +196,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): @RunIf(skip_windows=True, special=True, fairscale=True) @pytest.mark.parametrize("trainer_kwargs", (dict(num_processes=2), pytest.param(dict(gpus=2), marks=RunIf(min_gpus=2)))) def test_ddp_sharded_plugin_test_multigpu(tmpdir, trainer_kwargs): - """ - Test to ensure we can use validate and test without fit - """ + """Test to ensure we can use validate and test without fit.""" model = BoringModel() trainer = Trainer(accelerator="ddp_sharded_spawn", fast_dev_run=True, **trainer_kwargs) diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 08511d2e544c5..8210804a46ddc 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -40,10 +40,8 @@ def _get_python_cprofile_total_duration(profile): def _sleep_generator(durations): - """ - the profile_iterable method needs an iterable in which we can ensure that we're - properly timing how long it takes to call __next__ - """ + """the profile_iterable method needs an iterable in which we can ensure that we're properly timing how long it + takes to call __next__""" for duration in durations: time.sleep(duration) yield duration @@ -115,7 +113,7 @@ def test_simple_profiler_deepcopy(tmpdir): def test_simple_profiler_log_dir(tmpdir): - """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present""" + """Ensure the profiler dirpath defaults to `trainer.log_dir` when not present.""" profiler = SimpleProfiler(filename="profiler") assert profiler._log_dir is None @@ -131,7 +129,7 @@ def test_simple_profiler_log_dir(tmpdir): @RunIf(skip_windows=True) def test_simple_profiler_distributed_files(tmpdir): - """Ensure the proper files are saved in distributed""" + """Ensure the proper files are saved in distributed.""" profiler = SimpleProfiler(dirpath=tmpdir, filename="profiler") model = BoringModel() trainer = Trainer( @@ -150,7 +148,7 @@ def test_simple_profiler_distributed_files(tmpdir): def test_simple_profiler_logs(tmpdir, caplog, simple_profiler): - """Ensure that the number of printed logs is correct""" + """Ensure that the number of printed logs is correct.""" model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, profiler=simple_profiler, logger=False) with caplog.at_level(logging.INFO, logger="pytorch_lightning.profiler.profilers"): @@ -197,9 +195,7 @@ def test_advanced_profiler_iterable_durations(advanced_profiler, action: str, ex def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): - """ - ensure that the profiler doesn't introduce too much overhead during training - """ + """ensure that the profiler doesn't introduce too much overhead during training.""" for _ in range(n_iter): with advanced_profiler.profile("no-op"): pass @@ -211,9 +207,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): def test_advanced_profiler_describe(tmpdir, advanced_profiler): - """ - ensure the profiler won't fail when reporting the summary - """ + """ensure the profiler won't fail when reporting the summary.""" # record at least one event with advanced_profiler.profile("test"): pass @@ -259,7 +253,7 @@ def test_pytorch_profiler_describe(pytorch_profiler): def test_advanced_profiler_cprofile_deepcopy(tmpdir): - """Checks for pickle issue reported in #6522""" + """Checks for pickle issue reported in #6522.""" model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler="advanced", stochastic_weight_avg=True) trainer.fit(model) @@ -347,7 +341,7 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmpdir): def test_pytorch_profiler_nested(tmpdir): - """Ensure that the profiler handles nested context""" + """Ensure that the profiler handles nested context.""" pytorch_profiler = PyTorchProfiler( record_functions={"a", "b", "c"}, use_cuda=False, dirpath=tmpdir, filename="profiler", schedule=None @@ -374,13 +368,14 @@ def test_pytorch_profiler_nested(tmpdir): def test_pytorch_profiler_logger_collection(tmpdir): - """ - Tests whether the PyTorch profiler is able to write its trace locally when - the Trainer's logger is an instance of LoggerCollection. See issue #8157. + """Tests whether the PyTorch profiler is able to write its trace locally when the Trainer's logger is an + instance of LoggerCollection. + + See issue #8157. """ def look_for_trace(trace_dir): - """Determines if a directory contains a PyTorch trace""" + """Determines if a directory contains a PyTorch trace.""" return any("trace.json" in filename for filename in os.listdir(trace_dir)) # Sanity check @@ -399,9 +394,7 @@ def look_for_trace(trace_dir): @RunIf(min_gpus=1, special=True) def test_pytorch_profiler_nested_emit_nvtx(tmpdir): - """ - This test check emit_nvtx is correctly supported - """ + """This test check emit_nvtx is correctly supported.""" profiler = PyTorchProfiler(use_cuda=True, emit_nvtx=True) model = BoringModel() @@ -448,9 +441,7 @@ def __init__(self): @pytest.mark.parametrize("cls", (SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) def test_profiler_teardown(tmpdir, cls): - """ - This test checks if profiler teardown method is called when trainer is exiting. - """ + """This test checks if profiler teardown method is called when trainer is exiting.""" class TestCallback(Callback): def on_fit_end(self, trainer, *args, **kwargs) -> None: diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 455e08dc10ad5..1973b2f2c7369 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -77,10 +77,8 @@ def on_save_checkpoint(self, *args): def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): - """ - Test that all callback states get saved even if the ModelCheckpoint is not given as last - and when there are multiple callbacks of the same type. - """ + """Test that all callback states get saved even if the ModelCheckpoint is not given as last and when there are + multiple callbacks of the same type.""" callback0 = StatefulCallback0() callback1 = StatefulCallback1(unique="one") diff --git a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py index 15e817da975be..7c49b4b12534b 100644 --- a/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -107,9 +107,7 @@ def val_dataloader(self): def test_multiple_optimizers_multiple_dataloaders(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(BoringModel): def on_train_epoch_start(self) -> None: diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index c0f7983a134f4..836c80a49821e 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -18,7 +18,7 @@ def test_passing_no_env_variables(): - """Testing overwriting trainer arguments""" + """Testing overwriting trainer arguments.""" trainer = Trainer() assert trainer.logger is not None assert trainer.max_steps is None @@ -29,7 +29,7 @@ def test_passing_no_env_variables(): @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) def test_passing_env_variables_only(): - """Testing overwriting trainer arguments""" + """Testing overwriting trainer arguments.""" trainer = Trainer() assert trainer.logger is None assert trainer.max_steps == 7 @@ -37,7 +37,7 @@ def test_passing_env_variables_only(): @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) def test_passing_env_variables_defaults(): - """Testing overwriting trainer arguments""" + """Testing overwriting trainer arguments.""" trainer = Trainer(False, max_steps=42) assert trainer.logger is None assert trainer.max_steps == 42 @@ -47,7 +47,7 @@ def test_passing_env_variables_defaults(): @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("torch.cuda.is_available", return_value=True) def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): - """Testing overwriting trainer arguments""" + """Testing overwriting trainer arguments.""" trainer = Trainer() assert trainer.gpus == 2 trainer = Trainer(gpus=1) diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py index a6e54d2fd1738..f6c9ea01987b6 100644 --- a/tests/trainer/flags/test_fast_dev_run.py +++ b/tests/trainer/flags/test_fast_dev_run.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("tuner_alg", ["batch size scaler", "learning rate finder"]) def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): - """Test that tuner algorithms are skipped if fast dev run is enabled""" + """Test that tuner algorithms are skipped if fast dev run is enabled.""" model = BoringModel() model.lr = 0.1 # avoid no-lr-found exception @@ -30,9 +30,7 @@ def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg): @pytest.mark.parametrize("fast_dev_run", [1, 4]) def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run): - """ - Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run - """ + """Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run.""" class FastDevRunModel(BoringModel): def __init__(self): diff --git a/tests/trainer/flags/test_min_max_epochs.py b/tests/trainer/flags/test_min_max_epochs.py index 059a447e10edb..ecfdb8bedf6e6 100644 --- a/tests/trainer/flags/test_min_max_epochs.py +++ b/tests/trainer/flags/test_min_max_epochs.py @@ -17,9 +17,7 @@ ], ) def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps): - """ - Tests that max_steps can be used without max_epochs - """ + """Tests that max_steps can be used without max_epochs.""" model = BoringModel() trainer = Trainer( diff --git a/tests/trainer/flags/test_overfit_batches.py b/tests/trainer/flags/test_overfit_batches.py index 798b9988469df..f9411c1eeea00 100644 --- a/tests/trainer/flags/test_overfit_batches.py +++ b/tests/trainer/flags/test_overfit_batches.py @@ -19,9 +19,7 @@ def test_overfit_multiple_val_loaders(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(BoringModel): def validation_step(self, batch, batch_idx, dataloader_idx): @@ -48,9 +46,7 @@ def val_dataloader(self): @pytest.mark.parametrize("overfit", [1, 2, 0.1, 0.25, 1.0]) def test_overfit_basic(tmpdir, overfit): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" model = BoringModel() diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py index 727e95b894060..ac71de39757b9 100644 --- a/tests/trainer/logging_/test_distributed_logging.py +++ b/tests/trainer/logging_/test_distributed_logging.py @@ -23,8 +23,8 @@ class AllRankLogger(LightningLoggerBase): - """ - Logger to test all-rank logging (i.e. not just rank 0). + """Logger to test all-rank logging (i.e. not just rank 0). + Logs are saved to local variable `logs`. """ @@ -61,9 +61,7 @@ def on_train_end(self): @RunIf(skip_windows=True) def test_all_rank_logging_ddp_cpu(tmpdir): - """ - Check that all ranks can be logged from - """ + """Check that all ranks can be logged from.""" model = TestModel() all_rank_logger = AllRankLogger() trainer = Trainer( @@ -82,9 +80,7 @@ def test_all_rank_logging_ddp_cpu(tmpdir): @RunIf(min_gpus=2) def test_all_rank_logging_ddp_spawn(tmpdir): - """ - Check that all ranks can be logged from - """ + """Check that all ranks can be logged from.""" model = TestModel() all_rank_logger = AllRankLogger() model.training_epoch_end = None @@ -102,9 +98,9 @@ def test_all_rank_logging_ddp_spawn(tmpdir): def test_first_logger_call_in_subprocess(tmpdir): - """ - Test that the Trainer does not call the logger too early. Only when the worker processes are initialized - do we have access to the rank and know which one is the main process. + """Test that the Trainer does not call the logger too early. + + Only when the worker processes are initialized do we have access to the rank and know which one is the main process. """ class LoggerCallsObserver(Callback): @@ -137,9 +133,7 @@ def on_train_start(self, trainer, pl_module): def test_logger_after_fit_predict_test_calls(tmpdir): - """ - Make sure logger outputs are finalized after fit, prediction, and test calls. - """ + """Make sure logger outputs are finalized after fit, prediction, and test calls.""" class BufferLogger(LightningLoggerBase): def __init__(self): diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 8579bc044734a..e8b398bee8872 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Test logging in the evaluation loop -""" +"""Test logging in the evaluation loop.""" import collections import itertools from unittest import mock @@ -29,9 +27,7 @@ def test__validation_step__log(tmpdir): - """ - Tests that validation_step can log - """ + """Tests that validation_step can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -67,9 +63,7 @@ def validation_step(self, batch, batch_idx): def test__validation_step__epoch_end__log(tmpdir): - """ - Tests that validation_epoch_end can log - """ + """Tests that validation_epoch_end can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -203,9 +197,7 @@ def validation_epoch_end(self, outputs) -> None: @pytest.mark.parametrize(["batches", "log_interval", "max_epochs"], [(1, 1, 1), (64, 32, 2)]) def test_eval_epoch_only_logging(tmpdir, batches, log_interval, max_epochs): - """ - Tests that test_epoch_end can be used to log, and we return them in the results. - """ + """Tests that test_epoch_end can be used to log, and we return them in the results.""" class TestModel(BoringModel): def test_epoch_end(self, outputs): @@ -264,9 +256,7 @@ def test_dataloader(self): def test_log_works_in_val_callback(tmpdir): - """ - Tests that log can be called within callback - """ + """Tests that log can be called within callback.""" class TestCallback(callbacks.Callback): @@ -369,9 +359,7 @@ def get_expected(on_epoch, values): def test_log_works_in_test_callback(tmpdir): - """ - Tests that log can be called within callback - """ + """Tests that log can be called within callback.""" class TestCallback(callbacks.Callback): @@ -500,9 +488,7 @@ def get_expected(on_epoch, values): @mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir): - """ - This tests make sure we properly log_metrics to loggers - """ + """This tests make sure we properly log_metrics to loggers.""" class ExtendedModel(BoringModel): diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 59bbd38d6bd21..d26471a715c2b 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -209,7 +209,7 @@ def call(hook, fn, *args, **kwargs): def test_fx_validator_integration(tmpdir): - """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" + """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors.""" not_supported = { None: "`self.trainer` reference is not registered", "on_before_accelerator_backend_setup": "You can't", @@ -337,7 +337,10 @@ def test_dataloader(self): def test_can_return_tensor_with_more_than_one_element(tmpdir): - """Ensure {validation,test}_step return values are not included as callback metrics. #6623""" + """Ensure {validation,test}_step return values are not included as callback metrics. + + #6623 + """ class TestModel(BoringModel): def validation_step(self, batch, *args, **kwargs): @@ -381,7 +384,7 @@ def training_step(self, *args, **kwargs): @pytest.mark.parametrize("add_dataloader_idx", [False, True]) def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx): - """test that auto_add_dataloader_idx argument works""" + """test that auto_add_dataloader_idx argument works.""" class TestModel(BoringModel): def val_dataloader(self): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 67df980aa05d6..7bc10d564fd07 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Test logging in the training loop -""" +"""Test logging in the training loop.""" import collections import itertools @@ -32,9 +30,7 @@ def test__training_step__log(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -96,9 +92,7 @@ def training_step(self, batch, batch_idx): def test__training_step__epoch_end__log(tmpdir): - """ - Tests that training_epoch_end can log - """ + """Tests that training_epoch_end can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -136,9 +130,7 @@ def training_epoch_end(self, outputs): @pytest.mark.parametrize(["batches", "log_interval", "max_epochs"], [(1, 1, 1), (64, 32, 2)]) def test__training_step__step_end__epoch_end__log(tmpdir, batches, log_interval, max_epochs): - """ - Tests that training_step_end and training_epoch_end can log - """ + """Tests that training_step_end and training_epoch_end can log.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -181,9 +173,7 @@ def training_epoch_end(self, outputs): ["batches", "fx", "result"], [(3, min, 0), (3, torch.max, 2), (11, max, 10), (5, "avg", 2), (5, "SUM", 10)] ) def test__training_step__log_max_reduce_fx(tmpdir, batches, fx, result): - """ - Tests that log works correctly with different tensor types - """ + """Tests that log works correctly with different tensor types.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): @@ -250,9 +240,7 @@ def val_dataloader(self): def test_log_works_in_train_callback(tmpdir): - """ - Tests that log can be called within callback - """ + """Tests that log can be called within callback.""" class TestCallback(callbacks.Callback): @@ -394,9 +382,7 @@ def validation_step(self, batch, batch_idx): "gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))] ) def test_logging_sync_dist_true(tmpdir, gpus): - """ - Tests to ensure that the sync_dist flag works (should just return the original value) - """ + """Tests to ensure that the sync_dist flag works (should just return the original value)""" fake_result = 1 model = LoggingSyncDistModel(fake_result) trainer = Trainer( @@ -431,9 +417,7 @@ def test_logging_sync_dist_true(tmpdir, gpus): @RunIf(min_gpus=2, special=True) def test_logging_sync_dist_true_ddp(tmpdir): - """ - Tests to ensure that the sync_dist flag works with ddp - """ + """Tests to ensure that the sync_dist flag works with ddp.""" class TestLoggingSyncDistModel(BoringModel): def training_step(self, batch, batch_idx): @@ -499,9 +483,7 @@ def on_epoch_end(self): def test_logging_in_callbacks_with_log_function(tmpdir): - """ - Tests ensure self.log can be used directly in callbacks. - """ + """Tests ensure self.log can be used directly in callbacks.""" class LoggingCallback(callbacks.Callback): def on_train_start(self, trainer, pl_module): diff --git a/tests/trainer/loops/test_evaluation_loop.py b/tests/trainer/loops/test_evaluation_loop.py index 9d4af2f393aee..218af13ea73db 100644 --- a/tests/trainer/loops/test_evaluation_loop.py +++ b/tests/trainer/loops/test_evaluation_loop.py @@ -24,10 +24,8 @@ @mock.patch("pytorch_lightning.loops.dataloader.evaluation_loop.EvaluationLoop.on_evaluation_epoch_end") def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): - """ - Tests that `on_evaluation_epoch_end` is called - for `on_validation_epoch_end` and `on_test_epoch_end` hooks - """ + """Tests that `on_evaluation_epoch_end` is called for `on_validation_epoch_end` and `on_test_epoch_end` + hooks.""" model = BoringModel() trainer = Trainer( @@ -47,7 +45,7 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir): "pytorch_lightning.trainer.connectors.logger_connector.logger_connector.LoggerConnector.update_eval_epoch_metrics" ) def test_log_epoch_metrics_before_on_evaluation_end(update_eval_epoch_metrics_mock, tmpdir): - """Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired""" + """Test that the epoch metrics are logged before the `on_evalutaion_end` hook is fired.""" order = [] update_eval_epoch_metrics_mock.side_effect = lambda: order.append("log_epoch_metrics") @@ -64,7 +62,7 @@ def on_validation_end(self): @RunIf(min_gpus=1) def test_memory_consumption_validation(tmpdir): - """Test that the training batch is no longer in GPU memory when running validation""" + """Test that the training batch is no longer in GPU memory when running validation.""" initial_memory = torch.cuda.memory_allocated(0) diff --git a/tests/trainer/loops/test_evaluation_loop_flow.py b/tests/trainer/loops/test_evaluation_loop_flow.py index 6af5df081c102..916da177c75d6 100644 --- a/tests/trainer/loops/test_evaluation_loop_flow.py +++ b/tests/trainer/loops/test_evaluation_loop_flow.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests the evaluation loop -""" +"""Tests the evaluation loop.""" import torch @@ -24,9 +22,7 @@ def test__eval_step__flow(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -86,9 +82,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__eval_step__eval_step_end__flow(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -153,9 +147,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__eval_step__epoch_end__flow(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -208,9 +200,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__validation_step__step_end__epoch_end__flow(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): diff --git a/tests/trainer/loops/test_flow_warnings.py b/tests/trainer/loops/test_flow_warnings.py index e14bd8825510a..d7860c807ac5b 100644 --- a/tests/trainer/loops/test_flow_warnings.py +++ b/tests/trainer/loops/test_flow_warnings.py @@ -24,9 +24,7 @@ def training_step(self, batch, batch_idx): def test_no_depre_without_epoch_end(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" model = TestModel() model.validation_epoch_end = None diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index 22258b8e52eea..d4652fe0d2c7f 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -104,7 +104,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_should_stop_mid_epoch(tmpdir): - """Test that training correctly stops mid epoch and that validation is still called at the right time""" + """Test that training correctly stops mid epoch and that validation is still called at the right time.""" class TestModel(BoringModel): def __init__(self): @@ -166,10 +166,8 @@ def training_step_end(self, outputs): def test_prepare_outputs(tmpdir): - """ - Test that the `extra` field of the saved `ResultCollection` objects for - `training_epoch_end` doesn't get accidentally modified by reference. - """ + """Test that the `extra` field of the saved `ResultCollection` objects for `training_epoch_end` doesn't get + accidentally modified by reference.""" class TestModel(BoringModel): on_train_batch_end_called = 0 diff --git a/tests/trainer/loops/test_training_loop_flow_dict.py b/tests/trainer/loops/test_training_loop_flow_dict.py index ab4d7979bbf39..b8061610a49bb 100644 --- a/tests/trainer/loops/test_training_loop_flow_dict.py +++ b/tests/trainer/loops/test_training_loop_flow_dict.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests to ensure that the training loop works with a dict (1.0) -""" +"""Tests to ensure that the training loop works with a dict (1.0)""" import torch @@ -23,9 +21,7 @@ def test__training_step__flow_dict(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -57,9 +53,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__training_step__tr_step_end__flow_dict(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -98,9 +92,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__training_step__epoch_end__flow_dict(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): @@ -147,9 +139,7 @@ def backward(self, loss, optimizer, optimizer_idx): def test__training_step__step_end__epoch_end__flow_dict(tmpdir): - """ - Tests that only training_step can be used - """ + """Tests that only training_step can be used.""" class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): diff --git a/tests/trainer/loops/test_training_loop_flow_scalar.py b/tests/trainer/loops/test_training_loop_flow_scalar.py index 43a8e561f8373..56674b0ff8e95 100644 --- a/tests/trainer/loops/test_training_loop_flow_scalar.py +++ b/tests/trainer/loops/test_training_loop_flow_scalar.py @@ -239,7 +239,8 @@ def backward(self, loss, optimizer, optimizer_idx): def test_train_step_no_return(tmpdir): - """Tests that only training_step raises a warning when nothing is returned in case of automatic_optimization.""" + """Tests that only training_step raises a warning when nothing is returned in case of + automatic_optimization.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 80896f6fa450c..72ebd62ae499e 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -268,10 +268,8 @@ def on_train_end(self): @RunIf(min_gpus=2) def test_manual_optimization_and_return_tensor(tmpdir): - """ - This test verify that in `manual_optimization` - we don't add gradient when the user return loss in `training_step` - """ + """This test verify that in `manual_optimization` we don't add gradient when the user return loss in + `training_step`""" model = ManualOptimizationExtendedModel() model.training_step_end = None @@ -293,11 +291,8 @@ def test_manual_optimization_and_return_tensor(tmpdir): @RunIf(min_gpus=2) def test_manual_optimization_and_return_detached_tensor(tmpdir): - """ - This test verify that in `manual_optimization` - we don't add gradient when the user return loss in `training_step` - When the tensor is detached, return MisConfiguration Error. - """ + """This test verify that in `manual_optimization` we don't add gradient when the user return loss in + `training_step` When the tensor is detached, return MisConfiguration Error.""" model = ManualOptimizationExtendedModel() model.detach = True @@ -322,10 +317,8 @@ def test_manual_optimization_and_return_detached_tensor(tmpdir): @RunIf(min_gpus=1) def test_manual_optimization_and_accumulated_gradient(tmpdir): - """ - This test verify that in `automatic_optimization=False`, - step is being called only when we shouldn't accumulate. - """ + """This test verify that in `automatic_optimization=False`, step is being called only when we shouldn't + accumulate.""" seed_everything(234) class ExtendedModel(BoringModel): @@ -411,9 +404,7 @@ def on_train_epoch_end(self, *_, **__): @RunIf(min_gpus=1) def test_multiple_optimizers_step(tmpdir): - """ - Tests that `step` works with several optimizers - """ + """Tests that `step` works with several optimizers.""" class TestModel(ManualOptModel): @@ -482,9 +473,7 @@ def training_epoch_end(self, outputs) -> None: def test_step_with_optimizer_closure(tmpdir): - """ - Tests that `step` works with optimizer_closure - """ + """Tests that `step` works with optimizer_closure.""" class TestModel(BoringModel): @@ -556,9 +545,7 @@ def configure_optimizers(self): def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir): - """ - Tests that `step` works with optimizer_closure and accumulated_grad - """ + """Tests that `step` works with optimizer_closure and accumulated_grad.""" class TestModel(BoringModel): def __init__(self): @@ -613,9 +600,7 @@ def configure_optimizers(self): @patch("torch.optim.SGD.step") def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir): - """ - Tests that `step` works with optimizer_closure and extra arguments - """ + """Tests that `step` works with optimizer_closure and extra arguments.""" class TestModel(BoringModel): def __init__(self): @@ -664,9 +649,7 @@ def configure_optimizers(self): @patch("torch.optim.Adam.step") @patch("torch.optim.SGD.step") def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir): - """ - Tests that `step` works with optimizer_closure and different accumulated_gradient frequency - """ + """Tests that `step` works with optimizer_closure and different accumulated_gradient frequency.""" class TestModel(BoringModel): def __init__(self): @@ -847,18 +830,14 @@ def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizati @RunIf(min_gpus=2, special=True) def test_step_with_optimizer_closure_with_different_frequencies_ddp(tmpdir): - """ - Tests that `step` works with optimizer_closure and different accumulated_gradient frequency - """ + """Tests that `step` works with optimizer_closure and different accumulated_gradient frequency.""" train_manual_optimization(tmpdir, "ddp") @RunIf(min_gpus=2) def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir): - """ - Tests that `step` works with optimizer_closure and different accumulated_gradient frequency - """ + """Tests that `step` works with optimizer_closure and different accumulated_gradient frequency.""" train_manual_optimization(tmpdir, "ddp_spawn") @@ -925,10 +904,7 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_ def test_lr_schedulers(tmpdir): - """ - Test `lr_schedulers()` returns the same objects - in the same order as `configure_optimizers()` returns. - """ + """Test `lr_schedulers()` returns the same objects in the same order as `configure_optimizers()` returns.""" class TestModel(BoringModel): def __init__(self): @@ -1001,9 +977,7 @@ def configure_optimizers(self): def test_lr_scheduler_step_not_called(tmpdir): - """ - Test `lr_scheduler.step()` is not called in manual optimization. - """ + """Test `lr_scheduler.step()` is not called in manual optimization.""" class TestModel(BoringModel): def __init__(self): @@ -1038,9 +1012,7 @@ def training_step(self, batch, batch_idx): @RunIf(min_gpus=1) @pytest.mark.parametrize("precision", [16, 32]) def test_multiple_optimizers_logging(precision, tmpdir): - """ - Tests that metrics are properly being logged. - """ + """Tests that metrics are properly being logged.""" class TestModel(BoringModel): def __init__(self): diff --git a/tests/trainer/optimization/test_multiple_optimizers.py b/tests/trainer/optimization/test_multiple_optimizers.py index fccb4e60657d9..603adb36d6981 100644 --- a/tests/trainer/optimization/test_multiple_optimizers.py +++ b/tests/trainer/optimization/test_multiple_optimizers.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests to ensure that the behaviours related to multiple optimizers works -""" +"""Tests to ensure that the behaviours related to multiple optimizers works.""" import pytest import torch @@ -29,7 +27,7 @@ def configure_optimizers(self): def test_unbalanced_logging_with_multiple_optimizers(tmpdir): - """This tests ensures reduction works in unbalanced logging settings""" + """This tests ensures reduction works in unbalanced logging settings.""" class TestModel(MultiOptModel): @@ -127,10 +125,8 @@ def training_epoch_end(self, outputs) -> None: def test_multiple_optimizers_no_opt_idx_argument(tmpdir): - """ - Test that an error is raised if no optimizer_idx is present when - multiple optimizeres are passed in case of automatic_optimization - """ + """Test that an error is raised if no optimizer_idx is present when multiple optimizeres are passed in case of + automatic_optimization.""" class TestModel(MultiOptModel): def training_step(self, batch, batch_idx): @@ -143,10 +139,8 @@ def training_step(self, batch, batch_idx): def test_custom_optimizer_step_with_multiple_optimizers(tmpdir): - """ - This tests ensures custom optimizer_step works, - even when optimizer.step is not called for a particular optimizer - """ + """This tests ensures custom optimizer_step works, even when optimizer.step is not called for a particular + optimizer.""" class TestModel(BoringModel): training_step_called = [0, 0] diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 848a7ac56f6b2..deb7dc0412968 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -25,7 +25,7 @@ def test_optimizer_with_scheduling(tmpdir): - """Verify that learning rate scheduling is working""" + """Verify that learning rate scheduling is working.""" model = BoringModel() trainer = Trainer( @@ -43,7 +43,7 @@ def test_optimizer_with_scheduling(tmpdir): def test_multi_optimizer_with_scheduling(tmpdir): - """Verify that learning rate scheduling is working""" + """Verify that learning rate scheduling is working.""" class TestModel(BoringModel): init_lr = 5e-4 @@ -75,9 +75,7 @@ def configure_optimizers(self): def test_reducelronplateau_with_no_monitor_raises(tmpdir): - """ - Test exception when a ReduceLROnPlateau is used with no monitor - """ + """Test exception when a ReduceLROnPlateau is used with no monitor.""" model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) @@ -89,9 +87,7 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir): def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): - """ - Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor - """ + """Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor.""" model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { @@ -268,10 +264,8 @@ def configure_optimizers(self): def test_step_scheduling_for_multiple_optimizers_with_frequency( tmpdir, schedulers, kwargs, intervals, frequencies, expected_steps, max_epochs ): - """ - Test that step LR schedulers for multiple optimizers follow - the optimizer frequencies when corresponding frequency is set. - """ + """Test that step LR schedulers for multiple optimizers follow the optimizer frequencies when corresponding + frequency is set.""" class DummyModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx): @@ -307,9 +301,7 @@ def configure_optimizers(self): @pytest.mark.parametrize("fn", ("validate", "test")) def test_init_optimizers_during_evaluation(tmpdir, fn): - """ - Test that optimizers is an empty list during evaluation - """ + """Test that optimizers is an empty list during evaluation.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -329,9 +321,7 @@ def configure_optimizers(self): def test_multiple_optimizers_callbacks(tmpdir): - """ - Tests that multiple optimizers can be used with callbacks - """ + """Tests that multiple optimizers can be used with callbacks.""" class CB(Callback): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -378,9 +368,7 @@ def configure_optimizers(self): @pytest.mark.parametrize("complete_epoch", [True, False]) @mock.patch("torch.optim.lr_scheduler.ReduceLROnPlateau.step") def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch): - """ - Test "strict" support in lr_scheduler dict - """ + """Test "strict" support in lr_scheduler dict.""" model = BoringModel() optimizer = optim.Adam(model.parameters()) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) @@ -419,9 +407,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch): def test_unknown_configure_optimizers_raises(tmpdir): - """ - Test exception with an unsupported configure_optimizers return - """ + """Test exception with an unsupported configure_optimizers return.""" model = BoringModel() model.configure_optimizers = lambda: 1 trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -430,9 +416,7 @@ def test_unknown_configure_optimizers_raises(tmpdir): def test_lr_scheduler_with_unknown_interval_raises(tmpdir): - """ - Test exception when lr_scheduler dict has unknown interval param value - """ + """Test exception when lr_scheduler dict has unknown interval param value.""" model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { @@ -445,9 +429,7 @@ def test_lr_scheduler_with_unknown_interval_raises(tmpdir): def test_lr_scheduler_with_extra_keys_warns(tmpdir): - """ - Test warning when lr_scheduler dict has extra keys - """ + """Test warning when lr_scheduler dict has extra keys.""" model = BoringModel() optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { @@ -460,9 +442,7 @@ def test_lr_scheduler_with_extra_keys_warns(tmpdir): def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir): - """ - Test exception when lr_scheduler dict has no scheduler - """ + """Test exception when lr_scheduler dict has no scheduler.""" model = BoringModel() model.configure_optimizers = lambda: {"optimizer": optim.Adam(model.parameters()), "lr_scheduler": {}} trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -471,9 +451,7 @@ def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir): def test_invalid_optimizer_in_scheduler(tmpdir): - """ - Test exception when optimizer attatched to lr_schedulers wasn't returned - """ + """Test exception when optimizer attatched to lr_schedulers wasn't returned.""" class InvalidOptimizerModel(BoringModel): def configure_optimizers(self): @@ -489,9 +467,7 @@ def configure_optimizers(self): def test_invalid_optimizer_dict_raises(tmpdir): - """ - Test exception when lr_scheduler dict has no scheduler - """ + """Test exception when lr_scheduler dict has no scheduler.""" class DummyModel(BoringModel): def configure_optimizers(self): @@ -504,9 +480,7 @@ def configure_optimizers(self): def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir): - """ - Test warning when invalid scheduler keys are provided in manual optimization. - """ + """Test warning when invalid scheduler keys are provided in manual optimization.""" class TestModel(BoringModel): def __init__(self): diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 9a0527d46330c..13a8c617975ea 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -26,9 +26,7 @@ def on_fit_end(self): def test_get_model(tmpdir): - """ - Tests that `trainer.lightning_module` extracts the model correctly - """ + """Tests that `trainer.lightning_module` extracts the model correctly.""" model = TrainerGetModel() @@ -41,9 +39,7 @@ def test_get_model(tmpdir): @RunIf(skip_windows=True) def test_get_model_ddp_cpu(tmpdir): - """ - Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu - """ + """Tests that `trainer.lightning_module` extracts the model correctly when using ddp on cpu.""" model = TrainerGetModel() @@ -61,9 +57,7 @@ def test_get_model_ddp_cpu(tmpdir): @RunIf(min_gpus=1) def test_get_model_gpu(tmpdir): - """ - Tests that `trainer.lightning_module` extracts the model correctly when using GPU - """ + """Tests that `trainer.lightning_module` extracts the model correctly when using GPU.""" model = TrainerGetModel() diff --git a/tests/trainer/properties/test_log_dir.py b/tests/trainer/properties/test_log_dir.py index d940dabd99c09..0623cf0097280 100644 --- a/tests/trainer/properties/test_log_dir.py +++ b/tests/trainer/properties/test_log_dir.py @@ -30,9 +30,7 @@ def training_step(self, *args, **kwargs): def test_logdir(tmpdir): - """ - Tests that the path is correct when checkpoint and loggers are used - """ + """Tests that the path is correct when checkpoint and loggers are used.""" expected = os.path.join(tmpdir, "lightning_logs", "version_0") model = TestModel(expected) @@ -45,9 +43,7 @@ def test_logdir(tmpdir): def test_logdir_no_checkpoint_cb(tmpdir): - """ - Tests that the path is correct with no checkpoint - """ + """Tests that the path is correct with no checkpoint.""" expected = os.path.join(tmpdir, "lightning_logs", "version_0") model = TestModel(expected) @@ -59,9 +55,7 @@ def test_logdir_no_checkpoint_cb(tmpdir): def test_logdir_no_logger(tmpdir): - """ - Tests that the path is correct even when there is no logger - """ + """Tests that the path is correct even when there is no logger.""" expected = os.path.join(tmpdir) model = TestModel(expected) @@ -73,9 +67,7 @@ def test_logdir_no_logger(tmpdir): def test_logdir_no_logger_no_checkpoint(tmpdir): - """ - Tests that the path is correct even when there is no logger - """ + """Tests that the path is correct even when there is no logger.""" expected = os.path.join(tmpdir) model = TestModel(expected) @@ -87,9 +79,7 @@ def test_logdir_no_logger_no_checkpoint(tmpdir): def test_logdir_custom_callback(tmpdir): - """ - Tests that the path is correct even when there is a custom callback - """ + """Tests that the path is correct even when there is a custom callback.""" expected = os.path.join(tmpdir, "lightning_logs", "version_0") model = TestModel(expected) @@ -103,9 +93,7 @@ def test_logdir_custom_callback(tmpdir): def test_logdir_custom_logger(tmpdir): - """ - Tests that the path is correct even when there is a custom logger - """ + """Tests that the path is correct even when there is a custom logger.""" expected = os.path.join(tmpdir, "custom_logs", "version_0") model = TestModel(expected) @@ -122,7 +110,7 @@ def test_logdir_custom_logger(tmpdir): def test_logdir_logger_collection(tmpdir): - """Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection""" + """Tests that the logdir equals the default_root_dir when the logger is a LoggerCollection.""" default_root_dir = tmpdir / "default_root_dir" save_dir = tmpdir / "save_dir" model = TestModel(default_root_dir) diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index daa01b5abe7b5..9c452052d73e7 100644 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -38,7 +38,7 @@ def test_wrong_train_setting(tmpdir): def test_wrong_configure_optimizers(tmpdir): - """Test that an error is thrown when no `configure_optimizers()` is defined""" + """Test that an error is thrown when no `configure_optimizers()` is defined.""" trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) with pytest.raises(MisconfigurationException, match=r"No `configure_optimizers\(\)` method defined."): @@ -48,9 +48,7 @@ def test_wrong_configure_optimizers(tmpdir): def test_fit_val_loop_config(tmpdir): - """ - When either val loop or val data are missing raise warning - """ + """When either val loop or val data are missing raise warning.""" trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop @@ -67,9 +65,7 @@ def test_fit_val_loop_config(tmpdir): def test_test_loop_config(tmpdir): - """ - When either test loop or test data are missing - """ + """When either test loop or test data are missing.""" trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data @@ -86,9 +82,7 @@ def test_test_loop_config(tmpdir): def test_val_loop_config(tmpdir): - """ - When either validation loop or validation data are missing - """ + """When either validation loop or validation data are missing.""" trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has val loop but no val data @@ -143,7 +137,7 @@ def predict_dataloader(self): def test_trainer_manual_optimization_config(tmpdir): - """Test error message when requesting Trainer features unsupported with manual optimization""" + """Test error message when requesting Trainer features unsupported with manual optimization.""" model = BoringModel() model.automatic_optimization = False diff --git a/tests/trainer/test_data_loading.py b/tests/trainer/test_data_loading.py index e9d5d3cc047cb..437acd86a9024 100644 --- a/tests/trainer/test_data_loading.py +++ b/tests/trainer/test_data_loading.py @@ -185,7 +185,7 @@ def __init__(self, num_feat, dataset, **kwargs): def test_replace_sampler_with_multiprocessing_context(): - """This test verifies that replace_sampler conserves multiprocessing context""" + """This test verifies that replace_sampler conserves multiprocessing context.""" train = RandomDataset(32, 64) context = "spawn" train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True) @@ -257,7 +257,7 @@ class CustomSampler(Sampler): def test_loader_detaching(): - """Checks that the loader has been resetted after the entrypoint""" + """Checks that the loader has been resetted after the entrypoint.""" class LoaderTestModel(BoringModel): def training_step(self, batch, batch_idx): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 3db27591fb24a..223a1a1be8db0 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -158,7 +158,7 @@ def validation_step(self, *args, **kwargs): def test_train_dataloader_passed_to_fit(tmpdir): - """Verify that train dataloader can be passed to fit""" + """Verify that train dataloader can be passed to fit.""" # only train passed to fit model = EvalModelTemplate() @@ -242,7 +242,7 @@ def on_test_epoch_start(self, trainer, pl_module): ["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)] ) def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent""" + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent.""" ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = Counter() @@ -293,7 +293,7 @@ def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, ], ) def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = Counter() @@ -330,7 +330,7 @@ def test_dataloaders_with_limit_train_batches(tmpdir, dataset, limit_train_batch ], ) def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" epoch_cb = Counter() callbacks = [epoch_cb] @@ -375,7 +375,7 @@ def test_dataloaders_with_limit_val_batches(tmpdir, dataset, limit_val_batches): def test_datasets_dataloaders_with_limit_num_batches( tmpdir, dataset, limit_train_batches, limit_val_batches, limit_test_batches ): - """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number""" + """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number.""" ckpt_callback = ModelCheckpoint(monitor="val_log", save_top_k=1, mode="max", verbose=False) epoch_cb = Counter() @@ -415,7 +415,7 @@ def test_datasets_dataloaders_with_limit_num_batches( [(0.0, 0.0, 0.0), (0, 0, 0.5), (1.0, 1.0, 1.0), (0.2, 0.4, 0.4)], ) def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify num_batches for train, val & test dataloaders passed with batch limit in percent""" + """Verify num_batches for train, val & test dataloaders passed with batch limit in percent.""" model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length model.test_dataloader = model.test_dataloader__multiple_mixed_length @@ -447,7 +447,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim ["limit_train_batches", "limit_val_batches", "limit_test_batches"], [(0, 0, 0), (1, 2, 3), (1, 2, 1e50)] ) def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches): - """Verify num_batches for train, val & test dataloaders passed with batch limit as number""" + """Verify num_batches for train, val & test dataloaders passed with batch limit as number.""" model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length @@ -495,7 +495,7 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v @pytest.mark.parametrize("fast_dev_run", [True, 1, 3, -1, "temp"]) def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): - """Verify num_batches for train, val & test dataloaders passed with fast_dev_run""" + """Verify num_batches for train, val & test dataloaders passed with fast_dev_run.""" model = EvalModelTemplate() model.val_dataloader = model.val_dataloader__multiple_mixed_length model.test_dataloader = model.test_dataloader__multiple_mixed_length @@ -541,7 +541,7 @@ def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run): @pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) def test_mixing_of_dataloader_options(tmpdir, ckpt_path): - """Verify that dataloaders can be passed to fit""" + """Verify that dataloaders can be passed to fit.""" model = EvalModelTemplate() @@ -628,7 +628,7 @@ def test_inf_val_dataloader(tmpdir, check_interval): def test_error_on_zero_len_dataloader(tmpdir): - """Test that error is raised if a zero-length dataloader is defined""" + """Test that error is raised if a zero-length dataloader is defined.""" model = EvalModelTemplate() model.train_dataloader = model.train_dataloader__zero_length @@ -650,7 +650,7 @@ def test_error_on_zero_len_dataloader(tmpdir): @pytest.mark.parametrize("stage", ("train", "test", "val")) @patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4) def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): - """Test that error is raised if dataloader with only a few workers is used""" + """Test that error is raised if dataloader with only a few workers is used.""" model = BoringModel() @@ -680,7 +680,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage): @pytest.mark.parametrize("stage", ("train", "test", "val")) @patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4) def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage): - """Test that error is raised if dataloader with only a few workers is used""" + """Test that error is raised if dataloader with only a few workers is used.""" model = EvalModelTemplate() model.training_step = model.training_step__multiple_dataloaders @@ -731,8 +731,8 @@ def _user_worker_init_fn(_): @RunIf(max_torch="1.8.9") def test_missing_worker_init_fn(): - """ - Test that naive worker seed initialization leads to undesired random state in subprocesses. + """Test that naive worker seed initialization leads to undesired random state in subprocesses. + PyTorch 1.9+ does not have this issue. """ dataset = NumpyRandomDataset() @@ -816,7 +816,8 @@ def test_auto_add_worker_init_fn_distributed(tmpdir, monkeypatch): def test_warning_with_small_dataloader_and_logging_interval(tmpdir): - """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval.""" + """Test that a warning message is shown if the dataloader length is too short for the chosen logging + interval.""" model = BoringModel() dataloader = DataLoader(RandomDataset(32, length=10)) model.train_dataloader = lambda: dataloader @@ -922,7 +923,7 @@ def on_test_start(self, trainer, pl_module): @RunIf(min_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler(tmpdir): - """Test DistributedSampler and it's arguments for DDP backend""" + """Test DistributedSampler and it's arguments for DDP backend.""" seed_everything(123) model = EvalModelTemplate() trainer = Trainer( @@ -948,7 +949,8 @@ def train_dataloader(self): @RunIf(min_gpus=2, skip_windows=True) def test_dataloader_distributed_sampler_already_attached(tmpdir): - """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on dataloader""" + """Test DistributedSampler and it's arguments for DDP backend when DistSampler already included on + dataloader.""" seed_everything(123) model = ModelWithDataLoaderDistributedSampler() trainer = Trainer( @@ -1012,7 +1014,7 @@ def train_dataloader(self): [pytest.param("min_size", 5), pytest.param("max_size_cycle", 10)], ) def test_fit_multiple_train_loaders(tmpdir, multiple_trainloader_mode, num_training_batches): - """Integration test for multple train loaders""" + """Integration test for multple train loaders.""" model = EvalModelTemplate() model.train_dataloader = model.train_dataloader__multiple_mapping @@ -1298,10 +1300,8 @@ def test_dataloaders_load_only_once_passed_loaders(tmpdir): def test_dataloaders_reset_and_attach(tmpdir): - """ - Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before - attaching the new one. - """ + """Test that repeated calls to Trainer.{fit,validate,test,predict} properly reset and dataloaders before + attaching the new one.""" # the assertions compare the datasets and not dataloaders since we patch and replace the samplers dataloader_0 = DataLoader(dataset=RandomDataset(32, 64)) dataloader_1 = DataLoader(dataset=RandomDataset(32, 64)) @@ -1343,9 +1343,7 @@ def test_dataloaders_reset_and_attach(tmpdir): @pytest.mark.parametrize("multiple_trainloader_mode", ["min_size", "max_size_cycle"]) def test_correct_dataloader_idx_in_hooks(tmpdir, multiple_trainloader_mode): - """ - Check the correct dataloader_idx inside hooks - """ + """Check the correct dataloader_idx inside hooks.""" class CustomBoringModel(BoringModel): def __init__(self): diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index ef2b8d2888573..9cf90640256ba 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -90,9 +90,9 @@ def test_progress_raises(): def test_optimizer_progress_default_factory(): - """ - Ensure that the defaults are created appropiately. If `default_factory` was not used, the default would - be shared between instances. + """Ensure that the defaults are created appropiately. + + If `default_factory` was not used, the default would be shared between instances. """ p1 = OptimizerProgress() p2 = OptimizerProgress() diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 861885b4c052b..ab5e3c8da2680 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -19,7 +19,7 @@ def test_initialize_state(): - """Tests that state is INITIALIZING after Trainer creation""" + """Tests that state is INITIALIZING after Trainer creation.""" trainer = Trainer() assert trainer.state == TrainerState(status=TrainerStatus.INITIALIZING, fn=None, stage=None) @@ -80,7 +80,7 @@ def on_test_batch_start(self, *_): [pytest.param(dict(fast_dev_run=True), id="Fast-Run"), pytest.param(dict(max_steps=1), id="Single-Step")], ) def test_interrupt_state_on_keyboard_interrupt(tmpdir, extra_params): - """Tests that state is set to INTERRUPTED on KeyboardInterrupt""" + """Tests that state is set to INTERRUPTED on KeyboardInterrupt.""" model = BoringModel() class InterruptCallback(Callback): diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py index 4375bf7f2505e..f7b1b552ccda0 100644 --- a/tests/trainer/test_supporters.py +++ b/tests/trainer/test_supporters.py @@ -36,7 +36,7 @@ def test_tensor_running_accum_reset(): - """Test that reset would set all attributes to the initialization state""" + """Test that reset would set all attributes to the initialization state.""" window_length = 10 @@ -89,7 +89,7 @@ def test_none_length_cycle_iterator(): ], ) def test_combined_dataset(dataset_1, dataset_2): - """Verify the length of the CombinedDataset""" + """Verify the length of the CombinedDataset.""" datasets = [dataset_1, dataset_2] combined_dataset = CombinedDataset(datasets) @@ -104,7 +104,7 @@ def test_combined_dataset_length_mode_error(): def test_combined_loader_iterator_dict_min_size(): - """Test `CombinedLoaderIterator` given mapping loaders""" + """Test `CombinedLoaderIterator` given mapping loaders.""" loaders = { "a": torch.utils.data.DataLoader(range(10), batch_size=4), "b": torch.utils.data.DataLoader(range(20), batch_size=5), @@ -127,19 +127,19 @@ def test_combined_loader_init_mode_error(): def test_combined_loader_loader_type_error(): - """Test the ValueError when wrapping the loaders""" + """Test the ValueError when wrapping the loaders.""" with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): CombinedLoader(None, "max_size_cycle") def test_combined_loader_calc_length_mode_error(): - """Test the ValueError when calculating the number of batches""" + """Test the ValueError when calculating the number of batches.""" with pytest.raises(TypeError, match="Expected data to be int, Sequence or Mapping, but got NoneType"): CombinedLoader._calc_num_batches(None) def test_combined_loader_dict_min_size(): - """Test `CombinedLoader` of mode 'min_size' given mapping loaders""" + """Test `CombinedLoader` of mode 'min_size' given mapping loaders.""" loaders = { "a": torch.utils.data.DataLoader(range(10), batch_size=4), "b": torch.utils.data.DataLoader(range(20), batch_size=5), @@ -158,7 +158,7 @@ def test_combined_loader_dict_min_size(): def test_combined_loader_dict_max_size_cycle(): - """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders""" + """Test `CombinedLoader` of mode 'max_size_cycle' given mapping loaders.""" loaders = { "a": torch.utils.data.DataLoader(range(10), batch_size=4), "b": torch.utils.data.DataLoader(range(20), batch_size=5), @@ -177,7 +177,7 @@ def test_combined_loader_dict_max_size_cycle(): def test_combined_loader_sequence_min_size(): - """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" + """Test `CombinedLoader` of mode 'min_size' given sequence loaders.""" loaders = [ torch.utils.data.DataLoader(range(10), batch_size=4), torch.utils.data.DataLoader(range(20), batch_size=5), @@ -210,7 +210,7 @@ def __next__(self): @pytest.mark.parametrize("mode", ["min_size", "max_size_cycle"]) @pytest.mark.parametrize("use_multiple_dataloaders", [False, True]) def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders): - """Test `CombinedLoader` of mode 'min_size' given sequence loaders""" + """Test `CombinedLoader` of mode 'min_size' given sequence loaders.""" if use_multiple_dataloaders: loaders = [ torch.utils.data.DataLoader(TestIterableDataset(10), batch_size=2), @@ -272,7 +272,7 @@ def __len__(self): def test_combined_loader_sequence_max_size_cycle(): - """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders""" + """Test `CombinedLoader` of mode 'max_size_cycle' given sequence loaders.""" loaders = [ torch.utils.data.DataLoader(range(10), batch_size=4), torch.utils.data.DataLoader(range(20), batch_size=5), @@ -312,10 +312,8 @@ def test_nested_calc_num_data(input_data, compute_func, expected_length): @mock.patch("torch.cuda.device_count", return_value=2) @mock.patch("torch.cuda.is_available", return_value=True) def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir): - """ - This test makes sure distributed sampler has been properly injected in dataloaders - when using CombinedLoader - """ + """This test makes sure distributed sampler has been properly injected in dataloaders when using + CombinedLoader.""" class CustomDataset(Dataset): def __init__(self, data): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index ded68e5e068c6..674ef851c8419 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -202,7 +202,7 @@ def test_trainer_accumulate_grad_batches_zero_grad(tmpdir, accumulate_grad_batch ], ) def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches): - """Verify optimizer.step() applied to last batch while grad accumulation""" + """Verify optimizer.step() applied to last batch while grad accumulation.""" class TestModel(BoringModel): def state_dict(self, *args, **kwargs): @@ -264,7 +264,7 @@ def on_train_batch_end(self, outputs, batch, batch_idx, *_): def test_loading_meta_tags(tmpdir): - """test for backward compatibility to meta_tags.csv""" + """test for backward compatibility to meta_tags.csv.""" tutils.reset_seed() hparams = EvalModelTemplate.get_default_hparams() @@ -359,9 +359,8 @@ def mock_save_function(filepath, *args): def test_model_checkpoint_only_weights(tmpdir): - """Tests use case where ModelCheckpoint is configured to save only model weights, and - user tries to load checkpoint to resume training. - """ + """Tests use case where ModelCheckpoint is configured to save only model weights, and user tries to load + checkpoint to resume training.""" model = EvalModelTemplate() trainer = Trainer( @@ -405,7 +404,7 @@ def test_model_freeze_unfreeze(): @pytest.mark.parametrize("url_ckpt", [True, False]) def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt): - """Verify resuming from checkpoint runs the right number of epochs""" + """Verify resuming from checkpoint runs the right number of epochs.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir monkeypatch.setenv("TORCH_HOME", tmpdir) @@ -462,7 +461,7 @@ def on_load_checkpoint(self, _): def test_trainer_max_steps_and_epochs(tmpdir): - """Verify model trains according to specified max steps""" + """Verify model trains according to specified max steps.""" model = BoringModel() num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) @@ -510,7 +509,7 @@ def test_trainer_max_steps_and_epochs(tmpdir): ], ) def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value): - """Don't allow max_epochs or max_steps to be less than -1 or a float""" + """Don't allow max_epochs or max_steps to be less than -1 or a float.""" with pytest.raises( MisconfigurationException, match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}", @@ -545,7 +544,7 @@ def test_trainer_max_steps_and_epochs_fit_loop_done(max_epochs, max_steps, is_do def test_trainer_min_steps_and_epochs(tmpdir): - """Verify model trains according to specified min steps""" + """Verify model trains according to specified min steps.""" model = EvalModelTemplate() num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) @@ -615,7 +614,7 @@ def training_step(self, batch, batch_idx): def test_trainer_max_steps_accumulate_batches(tmpdir): - """Verify model trains according to specified max steps with grad accumulated batches""" + """Verify model trains according to specified max steps with grad accumulated batches.""" model = BoringModel() num_train_samples = math.floor(len(model.train_dataloader()) * 0.5) @@ -947,7 +946,7 @@ def on_keyboard_interrupt(self, trainer, pl_module): [32, pytest.param(16, marks=RunIf(min_gpus=1, amp_native=True))], ) def test_gradient_clipping_by_norm(tmpdir, precision): - """Test gradient clipping by norm""" + """Test gradient clipping by norm.""" tutils.reset_seed() model = EvalModelTemplate() # TODO: when precision=16, BoringModel produces NaN, but EvalModelTemplate not @@ -980,7 +979,7 @@ def backward(*args, **kwargs): [32, pytest.param(16, marks=RunIf(min_gpus=1, amp_native=True))], ) def test_gradient_clipping_by_value(tmpdir, precision): - """Test gradient clipping by value""" + """Test gradient clipping by value.""" tutils.reset_seed() model = BoringModel() @@ -1028,9 +1027,7 @@ def test_gpu_choice(tmpdir): @pytest.mark.parametrize("limit_val_batches", [0.0, 1, 1.0, 0.5, 5]) def test_num_sanity_val_steps(tmpdir, limit_val_batches): - """ - Test that the number of sanity check batches is clipped to `limit_val_batches`. - """ + """Test that the number of sanity check batches is clipped to `limit_val_batches`.""" model = EvalModelTemplate() model.validation_step = model.validation_step__multiple_dataloaders model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders @@ -1059,10 +1056,8 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches): @pytest.mark.parametrize("limit_val_batches", [0.0, 1, 1.0, 0.3]) def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches): - """ - Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as - limited by `limit_val_batches` Trainer argument. - """ + """Test that `num_sanity_val_steps=-1` runs through all validation data once, and as many batches as limited by + `limit_val_batches` Trainer argument.""" model = EvalModelTemplate() model.validation_step = model.validation_step__multiple_dataloaders model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders @@ -1221,7 +1216,7 @@ def test_trainer_pickle(tmpdir): @pytest.mark.parametrize("stage", ("fit", "validate", "test")) def test_trainer_setup_call(tmpdir, stage): - """Test setup call gets the correct stage""" + """Test setup call gets the correct stage.""" class CurrentModel(BoringModel): def setup(self, stage): @@ -1364,9 +1359,7 @@ def predict( def test_trainer_predict_no_return(tmpdir): - """ - Test trainer.predict warns when nothing is returned - """ + """Test trainer.predict warns when nothing is returned.""" class CustomBoringModel(BoringModel): def predict_step(self, batch, batch_idx, dataloader_idx=None): @@ -1426,9 +1419,7 @@ def test_trainer_predict_ddp_cpu(tmpdir): @patch("torch.cuda.device_count", return_value=2) @patch("torch.cuda.is_available", return_value=True) def test_spawn_predict_return_predictions(*_): - """ - Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins. - """ + """Test that `return_predictions=True` raise a MisconfigurationException with spawn training type plugins.""" model = BoringModel() def run(expected_plugin, **trainer_kwargs): @@ -1444,9 +1435,7 @@ def run(expected_plugin, **trainer_kwargs): @pytest.mark.parametrize("return_predictions", [None, False, True]) @pytest.mark.parametrize("precision", [32, 64]) def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): - """ - Test that `return_predictions=True`. - """ + """Test that `return_predictions=True`.""" seed_everything(42) model = BoringModel() @@ -1465,8 +1454,8 @@ def test_predict_return_predictions_cpu(return_predictions, precision, tmpdir): def test_disabled_training_for_insufficient_limit_train_batches( tmpdir, limit_train_batches, global_step, num_training_batches, current_epoch, should_train ): - """ - Verify when `limit_train_batches` is float & between [0.0, 1.0] and + """Verify when `limit_train_batches` is float & between [0.0, 1.0] and. + `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled. """ @@ -1512,11 +1501,8 @@ def training_epoch_end(self, *args, **kwargs): @pytest.mark.parametrize(["max_steps", "max_epochs", "global_step"], [(10, 5, 10), (20, None, 20)]) def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epochs, global_step): - """ - Ensure that the training loop is bound by `max_steps` and - `max_epochs` for repeated calls of `trainer.fit`, and - disabled if the limit is reached - """ + """Ensure that the training loop is bound by `max_steps` and `max_epochs` for repeated calls of `trainer.fit`, + and disabled if the limit is reached.""" dataset_len = 200 batch_size = 10 @@ -1533,9 +1519,7 @@ def test_repeated_fit_calls_with_max_epochs_and_steps(tmpdir, max_steps, max_epo def test_trainer_access_in_configure_optimizers(tmpdir): - """ - Verify that the configure optimizer function can reference the trainer. - """ + """Verify that the configure optimizer function can reference the trainer.""" class TestModel(BoringModel): def configure_optimizers(self): @@ -1550,9 +1534,7 @@ def configure_optimizers(self): @RunIf(min_gpus=1) def test_setup_hook_move_to_device_correctly(tmpdir): - """ - Verify that if a user defines a layer in the setup hook function, this is moved to the correct device. - """ + """Verify that if a user defines a layer in the setup hook function, this is moved to the correct device.""" class TestModel(BoringModel): def setup(self, stage: str) -> None: @@ -1704,7 +1686,7 @@ def test_dataloader(self): class TestCallback(Callback): def on_fit_start(self, trainer, pl_module: LightningModule) -> None: - """Called when fit begins""" + """Called when fit begins.""" assert isinstance(pl_module.data_pipeline, DataPipeline) model = BoringModel() @@ -1853,9 +1835,7 @@ def training_step(self, batch, batch_idx): @RunIf(min_gpus=1) def test_multiple_trainer_constant_memory_allocated(tmpdir): - """ - This tests ensures calling the trainer several times reset the memory back to 0. - """ + """This tests ensures calling the trainer several times reset the memory back to 0.""" class TestModel(BoringModel): def training_step(self, batch, batch_idx): diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 08cdafefaf07a..b2bdd7f46d9a4 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -26,7 +26,7 @@ @mock.patch("argparse.ArgumentParser.parse_args") def test_default_args(mock_argparse, tmpdir): - """Tests default argument parser for Trainer""" + """Tests default argument parser for Trainer.""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) # logger file to get meta @@ -45,9 +45,7 @@ def test_default_args(mock_argparse, tmpdir): @pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) def test_add_argparse_args_redefined(cli_args: list): - """Redefines some default Trainer arguments via the cli and - tests the Trainer initialization correctness. - """ + """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" parser = ArgumentParser(add_help=False) parser = Trainer.add_argparse_args(parent_parser=parser) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 59dad348cebeb..922dbdd13ab41 100644 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -20,9 +20,7 @@ def test_num_training_batches(tmpdir): - """ - Tests that the correct number of batches are allocated - """ + """Tests that the correct number of batches are allocated.""" # when we have fewer batches in the dataloader we should use those instead of the limit model = EvalModelTemplate() trainer = Trainer(limit_val_batches=100, limit_train_batches=100, max_epochs=1, default_root_dir=tmpdir) diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index de1873ee391d8..d764afba5d3c3 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -26,7 +26,7 @@ def test_error_on_more_than_1_optimizer(tmpdir): - """Check that error is thrown when more than 1 optimizer is passed""" + """Check that error is thrown when more than 1 optimizer is passed.""" model = EvalModelTemplate() model.configure_optimizers = model.configure_optimizers__multiple_schedulers @@ -87,7 +87,7 @@ def test_trainer_reset_correctly(tmpdir): @pytest.mark.parametrize("use_hparams", [False, True]) def test_trainer_arg_bool(tmpdir, use_hparams): - """Test that setting trainer arg to bool works""" + """Test that setting trainer arg to bool works.""" hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) before_lr = hparams.get("learning_rate") @@ -109,7 +109,7 @@ def test_trainer_arg_bool(tmpdir, use_hparams): @pytest.mark.parametrize("use_hparams", [False, True]) def test_trainer_arg_str(tmpdir, use_hparams): - """Test that setting trainer arg to string works""" + """Test that setting trainer arg to string works.""" hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) model.my_fancy_lr = 1.0 # update with non-standard field @@ -133,7 +133,7 @@ def test_trainer_arg_str(tmpdir, use_hparams): @pytest.mark.parametrize("optimizer", ["Adam", "Adagrad"]) def test_call_to_trainer_method(tmpdir, optimizer): - """Test that directly calling the trainer method works""" + """Test that directly calling the trainer method works.""" hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) @@ -153,7 +153,7 @@ def test_call_to_trainer_method(tmpdir, optimizer): def test_datamodule_parameter(tmpdir): - """Test that the datamodule parameter works""" + """Test that the datamodule parameter works.""" seed_everything(1) dm = ClassifDataModule() @@ -171,7 +171,8 @@ def test_datamodule_parameter(tmpdir): def test_accumulation_and_early_stopping(tmpdir): - """Test that early stopping of learning rate finder works, and that accumulation also works for this feature""" + """Test that early stopping of learning rate finder works, and that accumulation also works for this + feature.""" class TestModel(BoringModel): def __init__(self): @@ -188,7 +189,7 @@ def __init__(self): def test_suggestion_parameters_work(tmpdir): - """Test that default skipping does not alter results in basic case""" + """Test that default skipping does not alter results in basic case.""" dm = ClassifDataModule() model = ClassificationModel() @@ -204,7 +205,7 @@ def test_suggestion_parameters_work(tmpdir): def test_suggestion_with_non_finite_values(tmpdir): - """Test that non-finite values does not alter results""" + """Test that non-finite values does not alter results.""" hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(**hparams) @@ -221,14 +222,14 @@ def test_suggestion_with_non_finite_values(tmpdir): def test_lr_finder_fails_fast_on_bad_config(tmpdir): - """Test that tune fails if the model does not have a lr BEFORE running lr find""" + """Test that tune fails if the model does not have a lr BEFORE running lr find.""" trainer = Trainer(default_root_dir=tmpdir, max_steps=2, auto_lr_find=True) with pytest.raises(MisconfigurationException, match="should have one of these fields"): trainer.tune(BoringModel()) def test_lr_find_with_bs_scale(tmpdir): - """Test that lr_find runs with batch_size_scaling""" + """Test that lr_find runs with batch_size_scaling.""" class BoringModelTune(BoringModel): def __init__(self, learning_rate=0.1, batch_size=2): diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py index e49fcf8686e4c..a78ef828ca3c1 100644 --- a/tests/tuner/test_scale_batch_size.py +++ b/tests/tuner/test_scale_batch_size.py @@ -202,8 +202,8 @@ def test_call_to_trainer_method(tmpdir, scale_method): def test_error_on_dataloader_passed_to_fit(tmpdir): - """Verify that when the auto scale batch size feature raises an error - if a train dataloader is passed to fit""" + """Verify that when the auto scale batch size feature raises an error if a train dataloader is passed to + fit.""" # only train passed to fit model = EvalModelTemplate() @@ -233,7 +233,7 @@ def test_auto_scale_batch_size_with_amp(tmpdir): def test_scale_batch_size_no_trials(tmpdir): - """Check the result is correct even when no trials are run""" + """Check the result is correct even when no trials are run.""" trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, limit_train_batches=1, auto_scale_batch_size="power" ) diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index 62c543619ee4d..d34a0c64a63cc 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -11,7 +11,7 @@ def setup_ddp(rank, world_size): - """Setup ddp enviroment""" + """Setup ddp enviroment.""" os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "8088" diff --git a/tests/utilities/test_argparse.py b/tests/utilities/test_argparse.py index 8672795ea2787..4f450f1625981 100644 --- a/tests/utilities/test_argparse.py +++ b/tests/utilities/test_argparse.py @@ -153,10 +153,7 @@ def extract_help_text(parser): ], ) def test_add_argparse_args(cls, name): - """ - Tests that ``add_argparse_args`` handles argument groups correctly, and - can be parsed. - """ + """Tests that ``add_argparse_args`` handles argument groups correctly, and can be parsed.""" parser = ArgumentParser() parser_main = parser.add_argument_group("main") parser_main.add_argument("--main_arg", type=str, default="") @@ -186,10 +183,8 @@ def test_negative_add_argparse_args(): def test_add_argparse_args_no_argument_group(): - """ - Tests that ``add_argparse_args(..., use_argument_group=False)`` (old - workflow) handles argument groups correctly, and can be parsed. - """ + """Tests that ``add_argparse_args(..., use_argument_group=False)`` (old workflow) handles argument groups + correctly, and can be parsed.""" parser = ArgumentParser() parser.add_argument("--main_arg", type=str, default="") parser_old = parser # For testing. @@ -220,9 +215,7 @@ def test_int_or_float_type(): @pytest.mark.parametrize(["arg", "expected"], [["--precision=16", 16], ["--precision=bf16", "bf16"]]) def test_precision_parsed_correctly(arg, expected): - """ - Test to ensure that the precision flag is passed correctly when adding argparse args. - """ + """Test to ensure that the precision flag is passed correctly when adding argparse args.""" parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) fake_argv = [arg] diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 15f0b84e860b4..9a06648e8f5fd 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -118,10 +118,8 @@ def test_fast_forward_getattr(): def test_fast_forward_on_batch_sampler(): - """ - This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrived - the right next batch on restart. - """ + """This test ensures ``FastForwardSampler`` applied to ``BatchSampler`` correctly retrived the right next batch + on restart.""" dataset = range(15) sampler = SequentialSampler(dataset) batch_sampler = BatchSampler(sampler, 3, False) @@ -144,10 +142,8 @@ def test_fast_forward_on_batch_sampler(): def test_fast_forward_on_sequential_sampler(): - """ - This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrived - the right next batch on restart. - """ + """This test ensures ``FastForwardSampler`` applied to ``SequentialSampler`` correctly retrived the right next + batch on restart.""" dataset = range(15) sequential_sampler = SequentialSampler(dataset) sampler = FastForwardSampler(sequential_sampler) @@ -170,10 +166,8 @@ def test_fast_forward_on_sequential_sampler(): @pytest.mark.skipif(torch.cuda.is_available(), reason="todo (tchaton) Need more investigation") def test_fast_forward_on_random_sampler(): - """ - This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrived - the right next batch on restart. - """ + """This test ensures ``FastForwardSampler`` applied to ``RandomSampler`` correctly retrived the right next + batch on restart.""" seed = 42 seed_everything(42) @@ -250,10 +244,8 @@ def __next__(self): @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI") @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_fast_forward_sampler_over_iterable_dataset(num_workers): - """ - This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being - used to capture workers states. - """ + """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being used to capture + workers states.""" batch_size = 3 initial_seed = seed_everything(42) generator = torch.Generator() @@ -364,7 +356,7 @@ def _test_fast_forward_sampler_with_distributed_sampler(rank, worldsize): @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 25 sec and should be skipped in Azure CI") @RunIf(skip_windows=True) def test_fast_forward_sampler_with_distributed_sampler(): - """Make sure result logging works with DDP""" + """Make sure result logging works with DDP.""" tutils.set_random_master_port() worldsize = 2 mp.spawn(_test_fast_forward_sampler_with_distributed_sampler, args=(worldsize,), nprocs=worldsize) @@ -638,7 +630,7 @@ def test_fast_forward_sampler_iterative_dataset(): @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI") @RunIf(skip_windows=True) def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(): - """Make sure result logging works with DDP""" + """Make sure result logging works with DDP.""" tutils.set_random_master_port() worldsize = 2 mp.spawn( @@ -700,9 +692,7 @@ def create_dataloader(): @RunIf(min_torch="1.7.0") @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): - """ - This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled. - """ + """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.""" class CustomBatchSampler(BatchSampler): pass @@ -807,8 +797,7 @@ def __len__(self): @pytest.mark.parametrize("batch_size", [1, 2, 3]) def test_dataset_rng_states_restart(dataset_class, num_workers, batch_size): """Test that the sequence of batches coming from a random number generator continues with the correct sequence - after reloading the state. - """ + after reloading the state.""" def create_dataset_sampler(): dset = CaptureMapDataset(dataset_class(16, 8)) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 3abd0cafcf9b8..76422c783b805 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -45,7 +45,7 @@ @mock.patch("argparse.ArgumentParser.parse_args") def test_default_args(mock_argparse, tmpdir): - """Tests default argument parser for Trainer""" + """Tests default argument parser for Trainer.""" mock_argparse.return_value = Namespace(**Trainer.default_attributes()) parser = LightningArgumentParser(add_help=False, parse_as_dict=False) @@ -60,9 +60,7 @@ def test_default_args(mock_argparse, tmpdir): @pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []]) def test_add_argparse_args_redefined(cli_args): - """Redefines some default Trainer arguments via the cli and - tests the Trainer initialization correctness. - """ + """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness.""" parser = LightningArgumentParser(add_help=False, parse_as_dict=False) parser.add_lightning_class_args(Trainer, None) @@ -788,8 +786,7 @@ def test_lightning_cli_subcommands(): def test_lightning_cli_custom_subcommand(): class TestTrainer(Trainer): def foo(self, model: LightningModule, x: int, y: float = 1.0): - """ - Sample extra function. + """Sample extra function. Args: model: A model diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index 45c8f1a9a1d4f..a04e56b7aabad 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -24,9 +24,7 @@ @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_collate_checkpoint(tmpdir): - """ - Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file. - """ + """Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file.""" model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16 diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index f209a310eea39..75884d8cfc505 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -52,10 +52,8 @@ def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx) @pytest.mark.parametrize(["dst_device"], [pytest.param(torch.device("cpu")), pytest.param(torch.device("cuda", 0))]) @RunIf(min_gpus=1) def test_submodules_device_and_dtype(dst_device, dst_dtype): - """ - Test that the device and dtype property updates propagate through mixed nesting of regular - nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule). - """ + """Test that the device and dtype property updates propagate through mixed nesting of regular nn.Modules and + the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).""" model = TopModule() assert model.device == torch.device("cpu") diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index bf54bbae83568..86d04af9d2eb3 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -97,16 +97,16 @@ def test_misconfiguration_error(): def get_cycles_per_ms() -> float: - """ - Get 10 values and remove the 2 max and 2 min and return the avg. + """Get 10 values and remove the 2 max and 2 min and return the avg. + This is to avoid system disturbance that skew the results, e.g. the very first cuda call likely does a bunch of init, which takes much longer than subsequent calls. """ def measure() -> float: - """ - Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep` - Copied from: https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81 + """Measure and return approximate number of cycles per millisecond for `torch.cuda._sleep` Copied from: + + https://github.com/pytorch/pytorch/blob/v1.9.0/test/test_cuda.py#L81. """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) @@ -302,9 +302,7 @@ def train_dataloader(self): def test_training_step_with_dataloader_access(tmpdir) -> None: - """ - A baseline functional test for `training_step` with dataloader access. - """ + """A baseline functional test for `training_step` with dataloader access.""" trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) m = AsyncBoringModel() trainer.fit(m) @@ -313,10 +311,8 @@ def test_training_step_with_dataloader_access(tmpdir) -> None: @pytest.mark.parametrize("trigger_stop_iteration", [False, True]) def test_stop_iteration(trigger_stop_iteration, tmpdir): - """ - Verify that StopIteration properly terminates the training when this is trigged - from the current `dataloader_iter` - """ + """Verify that StopIteration properly terminates the training when this is trigged from the current + `dataloader_iter`""" EXPECT_NUM_BATCHES_PROCESSED = 2 class TestModel(AsyncBoringModel): @@ -345,10 +341,8 @@ def train_dataloader(self): def test_on_train_batch_start_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_start` is overridden on the `LightningModule`. - """ + """Verify that a `MisconfigurationException` is raised when `on_train_batch_start` is overridden on the + `LightningModule`.""" class InvalidModel(AsyncBoringModel): def on_train_batch_start(self, batch, batch_idx, dataloader_idx): @@ -361,10 +355,8 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): def test_on_train_batch_end_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `on_train_batch_end` is overridden on the `LightningModule`. - """ + """Verify that a `MisconfigurationException` is raised when `on_train_batch_end` is overridden on the + `LightningModule`.""" class InvalidModel(AsyncBoringModel): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): @@ -377,10 +369,8 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def test_tbptt_split_batch_overridden(tmpdir) -> None: - """ - Verify that a `MisconfigurationException` is raised when - `tbptt_split_batch` is overridden on the `LightningModule`. - """ + """Verify that a `MisconfigurationException` is raised when `tbptt_split_batch` is overridden on the + `LightningModule`.""" class InvalidModel(AsyncBoringModel): def __init__(self) -> None: diff --git a/tests/utilities/test_imports.py b/tests/utilities/test_imports.py index e1c494fe4754b..bf2c2c4f70a9f 100644 --- a/tests/utilities/test_imports.py +++ b/tests/utilities/test_imports.py @@ -16,7 +16,7 @@ def test_module_exists(): - """Test if the some 3rd party libs are available""" + """Test if the some 3rd party libs are available.""" assert _module_available("torch") assert _module_available("torch.nn.parallel") assert not _module_available("torch.nn.asdf") diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py index 0d993bee18ff2..b919070268d0c 100644 --- a/tests/utilities/test_model_summary.py +++ b/tests/utilities/test_model_summary.py @@ -25,7 +25,7 @@ class EmptyModule(LightningModule): - """A module that has no layers""" + """A module that has no layers.""" def __init__(self): super().__init__() @@ -54,7 +54,7 @@ def forward(self, x): class UnorderedModel(LightningModule): - """A model in which the layers not defined in order of execution""" + """A model in which the layers not defined in order of execution.""" def __init__(self): super().__init__() @@ -326,7 +326,7 @@ def test_lazy_model_summary(): def test_max_depth_equals_mode_interface(): - """Test summarize(model, full/top) interface mapping matches max_depth""" + """Test summarize(model, full/top) interface mapping matches max_depth.""" model = DeepNestedModel() summary_top = summarize(model, mode="top") @@ -340,7 +340,7 @@ def test_max_depth_equals_mode_interface(): @pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999]) def test_max_depth_param(max_depth): - """Test that only the modules up to the desired depth are shown""" + """Test that only the modules up to the desired depth are shown.""" model = DeepNestedModel() summary = ModelSummary(model, max_depth=max_depth) for lname in summary.layer_names: diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index b29dabe42cee9..4754c8a620383 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -94,7 +94,7 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat def test_lightning_hasattr(tmpdir, model_cases): - """Test that the lightning_hasattr works in all cases""" + """Test that the lightning_hasattr works in all cases.""" model1, model2, model3, model4, model5, model6, model7 = models = model_cases assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable" assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable" @@ -113,7 +113,7 @@ def test_lightning_hasattr(tmpdir, model_cases): def test_lightning_getattr(tmpdir, model_cases): - """Test that the lightning_getattr works in all cases""" + """Test that the lightning_getattr works in all cases.""" models = model_cases for i, m in enumerate(models[:3]): value = lightning_getattr(m, "learning_rate") @@ -133,7 +133,7 @@ def test_lightning_getattr(tmpdir, model_cases): def test_lightning_setattr(tmpdir, model_cases): - """Test that the lightning_setattr works in all cases""" + """Test that the lightning_setattr works in all cases.""" models = model_cases for m in models[:3]: lightning_setattr(m, "learning_rate", 10) diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index 096c9780cccfa..ca103b0a2a318 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -9,10 +9,7 @@ @mock.patch.dict(os.environ, {}, clear=True) def test_seed_stays_same_with_multiple_seed_everything_calls(): - """ - Ensure that after the initial seed everything, - the seed stays the same for the same run. - """ + """Ensure that after the initial seed everything, the seed stays the same for the same run.""" with pytest.warns(UserWarning, match="No correct seed found"): seed_utils.seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") @@ -27,18 +24,14 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True) def test_correct_seed_with_environment_variable(): - """ - Ensure that the PL_GLOBAL_SEED environment is read - """ + """Ensure that the PL_GLOBAL_SEED environment is read.""" assert seed_utils.seed_everything() == 2020 @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) @mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) def test_invalid_seed(): - """ - Ensure that we still fix the seed even if an invalid seed is given - """ + """Ensure that we still fix the seed even if an invalid seed is given.""" with pytest.warns(UserWarning, match="No correct seed found"): seed = seed_utils.seed_everything() assert seed == 123 @@ -48,9 +41,7 @@ def test_invalid_seed(): @mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) @pytest.mark.parametrize("seed", (10e9, -10e9)) def test_out_of_bounds_seed(seed): - """ - Ensure that we still fix the seed even if an out-of-bounds seed is given - """ + """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" with pytest.warns(UserWarning, match="is not in bounds"): actual = seed_utils.seed_everything(seed) assert actual == 123 diff --git a/tests/utilities/test_warnings.py b/tests/utilities/test_warnings.py index 2e0c372e5c39f..d1222672b7595 100644 --- a/tests/utilities/test_warnings.py +++ b/tests/utilities/test_warnings.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Test that the warnings actually appear and they have the correct `stacklevel` +"""Test that the warnings actually appear and they have the correct `stacklevel` Needs to be run outside of `pytest` as it captures all the warnings. """ @@ -26,6 +25,7 @@ if running_special: stderr = StringIO() + # recording with redirect_stderr(stderr): _warn("test1") _warn("test2", DeprecationWarning) diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py index b01bf81e61da0..8c1c92e7021c8 100644 --- a/tests/utilities/test_xla_device_utils.py +++ b/tests/utilities/test_xla_device_utils.py @@ -23,13 +23,13 @@ @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent") def test_tpu_device_absence(): - """Check tpu_device_exists returns False when torch_xla is not available""" + """Check tpu_device_exists returns False when torch_xla is not available.""" assert not xla_utils.XLADeviceUtils.tpu_device_exists() @RunIf(tpu=True) def test_tpu_device_presence(): - """Check tpu_device_exists returns True when TPU is available""" + """Check tpu_device_exists returns True when TPU is available.""" assert xla_utils.XLADeviceUtils.tpu_device_exists() @@ -41,7 +41,7 @@ def sleep_fn(sleep_time: float) -> bool: @patch("pytorch_lightning.utilities.xla_device.TPU_CHECK_TIMEOUT", 3) @pytest.mark.skipif(not _XLA_AVAILABLE, reason="test requires torch_xla to be present") def test_result_returns_within_timeout_seconds(): - """Check that pl_multi_process returns within 3 seconds""" + """Check that pl_multi_process returns within 3 seconds.""" fn = xla_utils.pl_multi_process(sleep_fn) start = time.time()