Skip to content

Commit

Permalink
set default train_len for iterator dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
linshokaku committed Oct 30, 2023
1 parent c5b4d58 commit c28a9cf
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pytorch_pfn_extras/training/_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,15 @@ def run(
- :meth:`pytorch_pfn_extras.training._evaluator.Evaluator`
"""
if train_len is None:
train_len = len(train_loader) # type: ignore[arg-type]
if hasattr(train_loader, "__len__"):
train_len = len(train_loader) # type: ignore[arg-type]
else:
train_len = 1
if eval_len is None and val_loader is not None:
eval_len = len(val_loader) # type: ignore[arg-type]
if hasattr(eval_len, "__len__"):
eval_len = len(val_loader) # type: ignore[arg-type]
else:
eval_len = 1

self._train_len = train_len
self._eval_len = eval_len
Expand Down
36 changes: 36 additions & 0 deletions tests/pytorch_pfn_extras_tests/training_tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,3 +948,39 @@ def test_create_distributed_evaluator():
with mock.patch.object(dist, "is_initialized", return_value=True):
evaluator = engine.create_evaluator(models=model, distributed=True)
assert isinstance(evaluator, DistributedEvaluator)


def test_trainer_run_with_iterator():
model = MyModel()
model = MyModelWithLossFn(model)
model = ppe.to(model, "cpu")
optimier = mock.MagicMock(spec=torch.optim.Optimizer)
evaluator = engine.create_evaluator(models=model)
trainer = engine.create_trainer(
models=model, optimizers=optimier, max_epochs=1, evaluator=evaluator
)
train_iterator = (
{
"x": torch.rand(
20,
),
"t": torch.rand(
10,
),
}
for i in range(10)
)
valid_iterator = (
{
"x": torch.rand(
20,
),
"t": torch.rand(
10,
),
}
for i in range(5)
)
trainer.run(train_iterator, valid_iterator)
assert trainer._train_len == 1
assert trainer._eval_len == 1

0 comments on commit c28a9cf

Please sign in to comment.