diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 6ddec837d4..da1fefe8fd 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Sized import torch import torch.distributed as dist @@ -133,12 +133,9 @@ def __init__( def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - # if the epoch_length isn't given, attempt to get it from the length of the data loader - if epoch_length is None: - try: - epoch_length = len(data_loader) - except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type - pass # deliberately leave epoch_length as None + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader.dataset, Sized): + epoch_length = len(data_loader.dataset) # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State(