Skip to content

Commit

Permalink
CIFAR10 dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
rsk97 committed Mar 8, 2023
1 parent 23b4cf5 commit 265c082
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion dataset/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import torch
import random
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

use_cuda = True # maybe change this to infer the device
Expand Down Expand Up @@ -33,4 +35,25 @@ def get_8gaussians(batch_size):
out = Variable(torch.Tensor(dataset))
if use_cuda:
out = out.cuda()
yield out
yield out

def CIFAR(batch_size = 64):
"""The function returns the train and test loaders as well as the class labels in human readable form"""
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

return train_loader, test_loader, classes

0 comments on commit 265c082

Please sign in to comment.