From 0b007d1afe8c278a665419a92df585fc67125bf8 Mon Sep 17 00:00:00 2001 From: Jihun Lee <37467083+BlindedShooter@users.noreply.github.com> Date: Thu, 11 Apr 2024 10:27:30 +0900 Subject: [PATCH] Fixed detection api retrieval from dataset Added _FlatDataWithTransform for get_detection_api_from_dataset recursion condition. Previously it recursed when the dataset is a instance of (Subset, AvalancheDataset, ConcatDataset) but _FlatDataWithTransform can be inside of the recursion (examples/detection_lvis.py) --- avalanche/evaluation/metrics/detection.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/avalanche/evaluation/metrics/detection.py b/avalanche/evaluation/metrics/detection.py index 2e01df060..513dc53d6 100644 --- a/avalanche/evaluation/metrics/detection.py +++ b/avalanche/evaluation/metrics/detection.py @@ -19,6 +19,7 @@ ) from avalanche.benchmarks.utils.data import AvalancheDataset +from avalanche.benchmarks.utils.data import _FlatDataWithTransform try: from lvis import LVIS @@ -470,12 +471,12 @@ def get_detection_api_from_dataset( recursion_result = get_detection_api_from_dataset( dataset.dataset, supported_types, none_if_not_found=True ) - elif isinstance(dataset, AvalancheDataset) and len(dataset._datasets) == 1: + elif isinstance(dataset, (AvalancheDataset, _FlatDataWithTransform)) and len(dataset._datasets) == 1: recursion_result = get_detection_api_from_dataset( dataset._datasets[0], supported_types, none_if_not_found=True ) - elif isinstance(dataset, (AvalancheDataset, ConcatDataset)): - if isinstance(dataset, AvalancheDataset): + elif isinstance(dataset, (AvalancheDataset, ConcatDataset, _FlatDataWithTransform)): + if isinstance(dataset, (AvalancheDataset, _FlatDataWithTransform)): datasets_list = dataset._datasets else: datasets_list = dataset.datasets