From aac82f567f3647f7b33636545fd760125b44e841 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 18 Dec 2024 14:20:18 +0000 Subject: [PATCH] Trying a better way of getting length Signed-off-by: Eric Kerfoot --- monai/engines/workflow.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 6ddec837d4..da1fefe8fd 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -13,7 +13,7 @@ import warnings from collections.abc import Callable, Iterable, Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Sized import torch import torch.distributed as dist @@ -133,12 +133,9 @@ def __init__( def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - # if the epoch_length isn't given, attempt to get it from the length of the data loader - if epoch_length is None: - try: - epoch_length = len(data_loader) - except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type - pass # deliberately leave epoch_length as None + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader.dataset, Sized): + epoch_length = len(data_loader.dataset) # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State(