Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features] Intialize dataset with ann_file #122

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion mmflow/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import warnings
from abc import ABCMeta, abstractmethod
from typing import Optional, Sequence, Union

Expand All @@ -16,19 +17,26 @@ class BaseDataset(Dataset, metaclass=ABCMeta):
Args:
data_root (str): Directory for dataset.
pipeline (Sequence[dict]): Processing pipeline.
ann_file: Annotation file path. Defaults to None.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
test_mode (bool): Whether the dataset works for model testing or
training.
"""

def __init__(self,
data_root: str,
pipeline: Sequence[dict],
ann_file: Optional[str] = None,
file_client_args: dict = dict(backend='disk'),
test_mode: bool = False) -> None:
super().__init__()
self.data_root = data_root
self.pipeline = Compose(pipeline)
self.test_mode = test_mode
self.dataset_name = self.__class__.__name__
self.file_client_args = file_client_args
"""
data_infos is the list of data_info containing img_info and ann_info
data_info
Expand All @@ -41,7 +49,38 @@ def __init__(self,
"""
self.data_infos = []

self.load_data_info()
if ann_file is None:
warnings.warn(message='ann_file is None, please use '
'tools/prepare_dataset to generate ann_file')
self.load_data_info()
else:
self.load_ann_file(ann_file)

def load_ann_file(self, ann_file: str) -> None:
"""Load annotation file.

Args:
ann_file (str): The json file contains the data sample
information.
"""
ann = mmcv.load(
ann_file,
file_format='json',
file_client_args=self.file_client_args)
self.data_infos = ann['data_list']
self.img1_dir = osp.join(self.data_root,
self.data_infos[0]['img1_dir'])
self.img2_dir = osp.join(self.data_root,
self.data_infos[0]['img2_dir'])
self.flow_dir = osp.join(self.data_root,
self.data_infos[0]['flow_dir'])
for data_info in self.data_infos:
data_info['img_info']['filename1'] = \
osp.join(self.img1_dir, data_info['img_info']['filename1'])
data_info['img_info']['filename2'] = \
osp.join(self.img2_dir, data_info['img_info']['filename2'])
data_info['ann_info']['filename_flow'] = osp.join(
self.flow_dir, data_info['ann_info']['filename_flow'])

@abstractmethod
def load_data_info(self):
Expand Down