Skip to content

Commit

Permalink
trainer and loader works :D
Browse files Browse the repository at this point in the history
  • Loading branch information
madhavkhoslaa committed Jun 28, 2019
1 parent 5b869e4 commit f94b5be
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions dataloader/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import glob
import skimage
import torch
import numpy as np


class ImageLoader():
def __init__(self, Images, Annotations,
Expand Down Expand Up @@ -36,8 +38,6 @@ def __getitem__(self, index):
else:
image = skimage.io.imread(self.images[index])
label = skimage.io.imread(self.target_images[index])
image= torch.from_numpy(image)
label= torch.from_numpy(label)
if self.transform:
image = self.transform(image)
if torch.cuda.is_available():
Expand Down
9 changes: 6 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
from dataloader.dataloader import ImageLoader, TrainSet, TestSet
from collections import defaultdict
import torch
params = hyperparameters(train_percentage=0.6, batch_size=10, epoch=4)
from torchvision import transforms

transforms_compose= transforms.Compose([transforms.ToTensor()])
params = hyperparameters(train_percentage=0.6, batch_size=1, epoch=4)
if torch.cuda.is_available():
net= UNeT(n_class=1).cuda()
else:
Expand All @@ -20,7 +23,7 @@
Annotations=ANNOTATIONS_DIR,
train_percentage=0.7)
loss_val = Loss()
Train = TrainSet(Images.train_set, extension="tif", transform=None)
Train = TrainSet(Images.train_set, extension="tif", transform=transforms_compose)
Test = TestSet(Images.test_set, extension="tif", transform=None)
TrainLoder = DataLoader(
Train,
Expand All @@ -37,7 +40,7 @@
running_loss = 0.0
for i, data in enumerate(TrainLoder, 0):
inputs, labels= data["Image"], data["Label"]
inputs, labels= inputs.transpose(0, 3).transpose(1, 2), labels.transpose(0, 3).transpose(1, 2)
#inputs, labels= inputs.permute(0, 3, 1, 2), labels.permute(0, 3, 1, 2)
optimizer.zero_grad()
outputs = net(inputs)
loss = loss_val.calc_loss(outputs, labels, metrics, bce_weight=0.5)
Expand Down

0 comments on commit f94b5be

Please sign in to comment.