Skip to content

Commit

Permalink
Merge pull request #18 from msamogh/torch-1.2
Browse files Browse the repository at this point in the history
Support Torch 1.2
  • Loading branch information
msamogh authored Aug 28, 2019
2 parents 32955e7 + a9aff15 commit d3ad194
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 12 deletions.
46 changes: 42 additions & 4 deletions nonechucks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,43 @@
from .dataset import SafeDataset
from .sampler import SafeSampler
from .dataloader import SafeDataLoader
import logging

__all__ = ["SafeDataset", "SafeSampler", "SafeDataLoader"]
import torch
import torch.utils.data


logger = logging.getLogger(__name__)


def _get_pytorch_version():
version = torch.__version__
major, minor, patch = [int(x) for x in version.split(".")]
if major != 1:
raise RuntimeError(
"nonechucks only supports PyTorch major version 1 at the moment."
)
if minor > 2:
logger.warn(
"nonechucks may not work properly with this version of PyTorch ({}). "
"It has only been tested on PyTorch versions 1.0, 1.1, and 1.2".format(
version
)
)
return major, minor


MAJOR, MINOR = _get_pytorch_version()

if MINOR > 1:
SingleProcessDataLoaderIter = (
torch.utils.data.dataloader._SingleProcessDataLoaderIter
)
MultiProcessingDataLoaderIter = (
torch.utils.data.dataloader._MultiProcessingDataLoaderIter
)
else:
SingleProcessDataLoaderIter = torch.utils.data.dataloader._DataLoaderIter
MultiProcessingDataLoaderIter = torch.utils.data.dataloader._DataLoaderIter


from nonechucks.dataset import SafeDataset
from nonechucks.sampler import SafeSampler
from nonechucks.dataloader import SafeDataLoader
11 changes: 6 additions & 5 deletions nonechucks/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
except ImportError:
from torch.utils.data._utils.collate import default_collate

from .dataset import SafeDataset
from .sampler import SafeSampler
from .utils import batch_len, collate_batches, slice_batch
from nonechucks import SingleProcessDataLoaderIter, MultiProcessingDataLoaderIter
from nonechucks.dataset import SafeDataset
from nonechucks.sampler import SafeSampler
from nonechucks.utils import batch_len, collate_batches, slice_batch


class _SafeDataLoaderCaller(type):
Expand Down Expand Up @@ -44,7 +45,7 @@ def _restore_default_samplers(cls):
data.dataloader.RandomSampler = cls.random


class _SafeDataLoaderIter(data.dataloader._DataLoaderIter):
class _SafeDataLoaderIter(MultiProcessingDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
self.batch_size = loader.batch_size
Expand Down Expand Up @@ -145,4 +146,4 @@ def __init__(self, dataset, **kwargs):
def __iter__(self):
if self.num_workers > 0:
return _SafeDataLoaderIter(self)
return data.dataloader._DataLoaderIter(self)
return SingleProcessDataLoaderIter(self)
2 changes: 1 addition & 1 deletion nonechucks/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.utils.data

from .utils import memoize
from nonechucks.utils import memoize


class SafeDataset(torch.utils.data.Dataset):
Expand Down
2 changes: 1 addition & 1 deletion nonechucks/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.utils.data

from .dataset import SafeDataset
from nonechucks.dataset import SafeDataset


class SafeSampler(torch.utils.data.sampler.Sampler):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

setup(
name="nonechucks",
version="0.3.1",
version="0.4.0",
url="https://github.com/msamogh/nonechucks",
license="MIT",
author="Amogh Mannekote",
Expand Down

0 comments on commit d3ad194

Please sign in to comment.