From 32955e789ef9c1bf6ec29170ff5986c29ab55361 Mon Sep 17 00:00:00 2001 From: Amogh Mannekote Date: Tue, 27 Aug 2019 21:39:44 +0200 Subject: [PATCH] Black formatting --- nonechucks/__init__.py | 2 +- nonechucks/dataloader.py | 37 +++++++++++++++++-------------------- nonechucks/dataset.py | 10 ++++++---- nonechucks/sampler.py | 8 ++++---- nonechucks/utils.py | 4 ++-- setup.py | 32 +++++++++++++++++--------------- tests/test_dataset.py | 9 +++++---- tests/test_sampler.py | 27 ++++++++++++++------------- 8 files changed, 66 insertions(+), 63 deletions(-) diff --git a/nonechucks/__init__.py b/nonechucks/__init__.py index 575fae7..14b9ddb 100644 --- a/nonechucks/__init__.py +++ b/nonechucks/__init__.py @@ -2,4 +2,4 @@ from .sampler import SafeSampler from .dataloader import SafeDataLoader -__all__ = ['SafeDataset', 'SafeSampler', 'SafeDataLoader'] +__all__ = ["SafeDataset", "SafeSampler", "SafeDataLoader"] diff --git a/nonechucks/dataloader.py b/nonechucks/dataloader.py index 9ce38c1..d0f3664 100644 --- a/nonechucks/dataloader.py +++ b/nonechucks/dataloader.py @@ -2,6 +2,7 @@ from future.utils import with_metaclass import torch.utils.data as data + try: from torch.utils.data.dataloader import default_collate except ImportError: @@ -32,9 +33,11 @@ def safe_sampler_callable(sampler_cls, dataset): return SafeSampler(dataset, sampler_cls(dataset)) data.dataloader.SequentialSampler = partial( - safe_sampler_callable, data.SequentialSampler) + safe_sampler_callable, data.SequentialSampler + ) data.dataloader.RandomSampler = partial( - safe_sampler_callable, data.RandomSampler) + safe_sampler_callable, data.RandomSampler + ) def _restore_default_samplers(cls): data.dataloader.SequentialSampler = cls.sequential @@ -42,7 +45,6 @@ def _restore_default_samplers(cls): class _SafeDataLoaderIter(data.dataloader._DataLoaderIter): - def __init__(self, loader): super().__init__(loader) self.batch_size = loader.batch_size @@ -70,8 +72,7 @@ def _process_next_batch(self, curr_batch): while n_empty_slots > 0: # check if curr_batch is the final batch if self.batches_outstanding == 0 and not self.reorder_dict: - if (not self.drop_last) or \ - (batch_len(curr_batch) == self.batch_size): + if (not self.drop_last) or (batch_len(curr_batch) == self.batch_size): return curr_batch # raises StopIteration if no more elements left, which exits the @@ -85,21 +86,16 @@ def _process_next_batch(self, curr_batch): # The remaining elements of next_batch are added back into the # dict for future consumption. self.rcvd_idx -= 1 - curr_batch = collate_batches([ - curr_batch, - slice_batch(next_batch, end=n_empty_slots) - ]) + curr_batch = collate_batches( + [curr_batch, slice_batch(next_batch, end=n_empty_slots)] + ) self.reorder_dict[self.rcvd_idx] = slice_batch( - next_batch, - start=n_empty_slots + next_batch, start=n_empty_slots ) else: curr_batch = collate_batches([curr_batch, next_batch]) - n_empty_slots -= min( - n_empty_slots, - batch_len(next_batch) - ) + n_empty_slots -= min(n_empty_slots, batch_len(next_batch)) self.coalescing_in_progress = False return curr_batch @@ -130,13 +126,14 @@ def __init__(self, dataset, **kwargs): # drop_last is handled transparently by _SafeDataLoaderIter (bypassing # DataLoader). Since drop_last cannot be changed after initializing the # DataLoader instance, it needs to be intercepted here. - assert isinstance(dataset, SafeDataset), \ - "dataset must be an instance of SafeDataset." + assert isinstance( + dataset, SafeDataset + ), "dataset must be an instance of SafeDataset." self.drop_last_original = False - if 'drop_last' in kwargs: - self.drop_last_original = kwargs['drop_last'] - kwargs['drop_last'] = False + if "drop_last" in kwargs: + self.drop_last_original = kwargs["drop_last"] + kwargs["drop_last"] = False super(SafeDataLoader, self).__init__(dataset, **kwargs) self.safe_dataset = self.dataset diff --git a/nonechucks/dataset.py b/nonechucks/dataset.py index 825c937..c828d4c 100644 --- a/nonechucks/dataset.py +++ b/nonechucks/dataset.py @@ -64,8 +64,7 @@ def is_index_built(self): """Returns True if all indices of the original dataset have been classified into safe_samples_indices or _unsafe_samples_indices. """ - return len(self.dataset) == len(self._safe_indices) + \ - len(self._unsafe_indices) + return len(self.dataset) == len(self._safe_indices) + len(self._unsafe_indices) @property def num_samples_examined(self): @@ -78,8 +77,11 @@ def __len__(self): return len(self.dataset) def __iter__(self): - return (self._safe_get_item(i) for i in range(len(self)) - if self._safe_get_item(i) is not None) + return ( + self._safe_get_item(i) + for i in range(len(self)) + if self._safe_get_item(i) is not None + ) @memoize def __getitem__(self, idx): diff --git a/nonechucks/sampler.py b/nonechucks/sampler.py index 6346113..1140715 100644 --- a/nonechucks/sampler.py +++ b/nonechucks/sampler.py @@ -41,8 +41,9 @@ def __init__(self, dataset, sampler=None, step_to_index_fn=None): specified, the default function returns the `num_samples_examined` as the output. """ - assert isinstance(dataset, SafeDataset), \ - "dataset must be an instance of SafeDataset." + assert isinstance( + dataset, SafeDataset + ), "dataset must be an instance of SafeDataset." self.dataset = dataset if sampler is None: @@ -64,8 +65,7 @@ def __iter__(self): def _get_next_index(self): """Helper function that calls `step_to_index_fn` and decides whether to sample directly from `dataset` or through `sampler`.""" - index = self.step_to_index_fn( - self.num_valid_samples, self.num_samples_examined) + index = self.step_to_index_fn(self.num_valid_samples, self.num_samples_examined) if self.sampler is not None: index = self.sampler_indices[index] return index diff --git a/nonechucks/utils.py b/nonechucks/utils.py index d23edda..9ad08a6 100644 --- a/nonechucks/utils.py +++ b/nonechucks/utils.py @@ -4,6 +4,7 @@ from functools import partial import torch + try: from torch.utils.data.dataloader import default_collate except ImportError: @@ -59,8 +60,7 @@ def collate_batches(batches, collate_fn=default_collate): elif isinstance(batches[0], collections.Sequence): return list(chain(*batches)) elif isinstance(batches[0], collections.Mapping): - return {key: default_collate([d[key] for d in batches]) - for key in batches[0]} + return {key: default_collate([d[key] for d in batches]) for key in batches[0]} raise TypeError((error_msg.format(type(batches[0])))) diff --git a/setup.py b/setup.py index 417c001..6fa5c0b 100644 --- a/setup.py +++ b/setup.py @@ -1,18 +1,20 @@ from setuptools import setup, find_packages -setup(name='nonechucks', - version='0.3.1', - url='https://github.com/msamogh/nonechucks', - license='MIT', - author='Amogh Mannekote', - author_email='msamogh@gmail.com', - description="""nonechucks is a library that provides wrappers for """ + - """PyTorch's datasets, samplers, and transforms to """ + - """allow for dropping unwanted or invalid samples """ + - """dynamically.""", - install_requires=["future"], - packages=find_packages(), - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - zip_safe=False) +setup( + name="nonechucks", + version="0.3.1", + url="https://github.com/msamogh/nonechucks", + license="MIT", + author="Amogh Mannekote", + author_email="msamogh@gmail.com", + description="""nonechucks is a library that provides wrappers for """ + + """PyTorch's datasets, samplers, and transforms to """ + + """allow for dropping unwanted or invalid samples """ + + """dynamically.""", + install_requires=["future"], + packages=find_packages(), + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + zip_safe=False, +) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4b9a8a7..def39ff 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -16,8 +16,7 @@ class SafeDatasetTest(unittest.TestCase): """Unit tests for `SafeDataset`.""" - SafeDatasetPair = collections.namedtuple( - 'SafeDatasetPair', ['unsafe', 'safe']) + SafeDatasetPair = collections.namedtuple("SafeDatasetPair", ["unsafe", "safe"]) @classmethod def get_safe_dataset_pair(cls, dataset, **kwargs): @@ -25,7 +24,8 @@ def get_safe_dataset_pair(cls, dataset, **kwargs): both the unsafe and safe versions of the dataset. """ return SafeDatasetTest.SafeDatasetPair( - dataset, nonechucks.SafeDataset(dataset, **kwargs)) + dataset, nonechucks.SafeDataset(dataset, **kwargs) + ) def setUp(self): tensor_data = data.TensorDataset(torch.arange(0, 10)) @@ -75,5 +75,6 @@ def test_import(self): self.assertIsNotNone(SafeDataset) self.assertIsNotNone(SafeSampler) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 8632d44..a0b95bb 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -12,30 +12,30 @@ class SafeSamplerTest(unittest.TestCase): - def test_sequential_sampler(self): dataset = data.TensorDataset(torch.arange(0, 10)) dataset = nonechucks.SafeDataset(dataset) dataloader = data.DataLoader( - dataset, - sampler=nonechucks.SafeSequentialSampler(dataset)) + dataset, sampler=nonechucks.SafeSequentialSampler(dataset) + ) for i_batch, sample_batched in enumerate(dataloader): - print('Sample {}: {}'.format(i_batch, sample_batched)) + print("Sample {}: {}".format(i_batch, sample_batched)) def test_first_last_sampler(self): dataset = data.TensorDataset(torch.arange(0, 10)) dataset = nonechucks.SafeDataset(dataset) dataloader = data.DataLoader( - dataset, - sampler=nonechucks.SafeFirstAndLastSampler(dataset)) + dataset, sampler=nonechucks.SafeFirstAndLastSampler(dataset) + ) for i_batch, sample_batched in enumerate(dataloader): - print('Sample {}: {}'.format(i_batch, sample_batched)) + print("Sample {}: {}".format(i_batch, sample_batched)) - @mock.patch('torch.utils.data.TensorDataset.__getitem__') - @mock.patch('torch.utils.data.TensorDataset.__len__') + @mock.patch("torch.utils.data.TensorDataset.__getitem__") + @mock.patch("torch.utils.data.TensorDataset.__len__") def test_sampler_wrapper(self, mock_len, mock_get_item): def side_effect(idx): return [0, 1, None, 3, 4, 5][idx] + mock_get_item.side_effect = side_effect mock_len.return_value = 6 dataset = data.TensorDataset(torch.arange(0, 10)) @@ -43,10 +43,11 @@ def side_effect(idx): self.assertEqual(len(dataset), 6) sequential_sampler = data.SequentialSampler(dataset) dataloader = data.DataLoader( - dataset, - sampler=nonechucks.SafeSampler(dataset, sequential_sampler)) + dataset, sampler=nonechucks.SafeSampler(dataset, sequential_sampler) + ) for i_batch, sample_batched in enumerate(dataloader): - print('Sample {}: {}'.format(i_batch, sample_batched)) + print("Sample {}: {}".format(i_batch, sample_batched)) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()