From c28a9cf4fe91c46e0aa31f9667c7dcdd6c35a364 Mon Sep 17 00:00:00 2001 From: Linsho Kaku Date: Mon, 30 Oct 2023 19:00:07 +0900 Subject: [PATCH] set default train_len for iterator dataset --- pytorch_pfn_extras/training/_trainer.py | 10 ++++-- .../training_tests/test_trainer.py | 36 +++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/training/_trainer.py b/pytorch_pfn_extras/training/_trainer.py index fe1c46d1c..bd8803ece 100644 --- a/pytorch_pfn_extras/training/_trainer.py +++ b/pytorch_pfn_extras/training/_trainer.py @@ -247,9 +247,15 @@ def run( - :meth:`pytorch_pfn_extras.training._evaluator.Evaluator` """ if train_len is None: - train_len = len(train_loader) # type: ignore[arg-type] + if hasattr(train_loader, "__len__"): + train_len = len(train_loader) # type: ignore[arg-type] + else: + train_len = 1 if eval_len is None and val_loader is not None: - eval_len = len(val_loader) # type: ignore[arg-type] + if hasattr(eval_len, "__len__"): + eval_len = len(val_loader) # type: ignore[arg-type] + else: + eval_len = 1 self._train_len = train_len self._eval_len = eval_len diff --git a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py index b65faba55..60e35ae31 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py @@ -948,3 +948,39 @@ def test_create_distributed_evaluator(): with mock.patch.object(dist, "is_initialized", return_value=True): evaluator = engine.create_evaluator(models=model, distributed=True) assert isinstance(evaluator, DistributedEvaluator) + + +def test_trainer_run_with_iterator(): + model = MyModel() + model = MyModelWithLossFn(model) + model = ppe.to(model, "cpu") + optimier = mock.MagicMock(spec=torch.optim.Optimizer) + evaluator = engine.create_evaluator(models=model) + trainer = engine.create_trainer( + models=model, optimizers=optimier, max_epochs=1, evaluator=evaluator + ) + train_iterator = ( + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(10) + ) + valid_iterator = ( + { + "x": torch.rand( + 20, + ), + "t": torch.rand( + 10, + ), + } + for i in range(5) + ) + trainer.run(train_iterator, valid_iterator) + assert trainer._train_len == 1 + assert trainer._eval_len == 1