Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
friederike-schneider-mint authored Sep 8, 2021
1 parent 2a4b287 commit 186f6bf
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 0 deletions.
62 changes: 62 additions & 0 deletions Scripts/DataSet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import SimpleITK as sitk
import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
import torch


class DataLoaderPNG8(Dataset):
def __init__(self, ):
imageDict = getImagesOfFolder("../Data/PNG8")
patientData = []
segmentation = []
n_samples = len(imageDict["flair"])
for i in range(n_samples):
patientData.append(
np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]]))
segmentation.append(
torch.from_numpy(np.array(imageDict["seg"][i])[:, :, np.newaxis]).permute(2, 0, 1).numpy())
self.patientData = (torch.from_numpy(np.array(patientData)) - 0) / (255.0-0)
self.segmentation = torch.from_numpy(np.array(segmentation))
self.n_samples = n_samples

def __getitem__(self, index):
print(index)
return self.patientData[index], self.segmentation[index]

def __len__(self):
return self.n_samples


class DataLoaderPNG16(Dataset):
def __init__(self, ):
imageDict = getImagesOfFolder("../Data/PNG8")
patientData = []
segmentation = []
n_samples = len(imageDict["flair"])
for i in range(n_samples):
patientData.append(
np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]]))
segmentation.append(
torch.from_numpy(np.array(imageDict["seg"][i])[:, :, np.newaxis]).permute(2, 0, 1).numpy())
self.patientData = torch.from_numpy(np.array(patientData))
self.segmentation = torch.from_numpy(np.array(segmentation))
self.n_samples = n_samples

def __getitem__(self, index):
return self.patientData[index], self.segmentation[index]

def __len__(self):
return self.n_samples


def getImagesOfFolder(dir_path):
dict = {"flair": [], "t1": [], "t1ce": [], "t2": [], "seg": []}
for element in glob.glob(dir_path + "/*.png"):
imageNames = element.split("/")[-1]
modality = imageNames.split("_")[-1].split(".")[0]
data = sitk.ReadImage(element)
img = sitk.GetArrayFromImage(data)
dict[modality].append(img)
return dict
108 changes: 108 additions & 0 deletions Scripts/UNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""

def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.double_conv(x)


class Down(nn.Module):
"""Downscaling with maxpool then double conv"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)

def forward(self, x):
return self.maxpool_conv(x)


class Up(nn.Module):
"""Upscaling then double conv"""

def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()

# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)

def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)


class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
return self.conv(x)


class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear

self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)

def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
Binary file added Scripts/__pycache__/DataSet.cpython-38.pyc
Binary file not shown.
Binary file added Scripts/__pycache__/UNet.cpython-38.pyc
Binary file not shown.
141 changes: 141 additions & 0 deletions Scripts/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import torch.optim
import math

import UNet
from DataSet import DataLoaderPNG8
from DataSet import DataLoaderPNG16
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


PATH = "../tmp/state_dict_model.pt"
################################################################################################################
#########################################Load Data##############################################################
################################################################################################################
def loadData(data,batch_s):
indices = list(range(len(data)))
split_index1=int(len(data)*0.6)
split_index2 = int(len(data) * 0.2)+split_index1
trainSampler = SubsetRandomSampler(indices[:split_index1])
testSampler = SubsetRandomSampler(indices[split_index1:split_index2])
valSampler = SubsetRandomSampler(indices[split_index2:])

trainDataLoader=DataLoader(dataset=data,
batch_size=batch_s,
sampler=trainSampler,
drop_last=True)

testDataLoader=DataLoader(dataset=data,
batch_size=batch_s,
sampler=testSampler,
drop_last=True)

valDataLoader=DataLoader(dataset=data,
batch_size=batch_s,
sampler=valSampler,
drop_last=True)

#trainIterator = iter(trainDataLoader)
#data = trainIterator.next()
#img,seg = data
#print("Image: ",img,"\n Segmentation: ",seg)


return trainDataLoader, testDataLoader, valDataLoader

def visualize_images(axs, source, target, result, phase, epoch, every_epoche):
offset=0
flair=source[0]
t1=source[1]
t1ce=source[2]
t2=source[3]

flair=flair[0, 0, ...].cpu()
t1=t1[0,0,...].cpu()

target=target[0,0,...].cpu()
preds=(result.detach()[0,0,...].cpu()).float()

axs[epoch//every_epoche][0+offset].imshow(flair)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' flair')
axs[epoch // every_epoche][0 + offset].grid(False)
axs[epoch//every_epoche][0+offset].imshow(t1)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' t1')
axs[epoch // every_epoche][0 + offset].grid(False)
axs[epoch//every_epoche][0+offset].imshow(t1ce)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' t1ce')
axs[epoch // every_epoche][0 + offset].grid(False)
axs[epoch//every_epoche][0+offset].imshow(t2)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' t2')
axs[epoch // every_epoche][0 + offset].grid(False)
axs[epoch//every_epoche][0+offset].imshow(target)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' prediction')
axs[epoch // every_epoche][0 + offset].grid(False)
axs[epoch//every_epoche][0+offset].imshow(preds)
axs[epoch//every_epoche][0+offset].set_title(str(epoch)+': '+phase+' segmentation')
axs[epoch // every_epoche][0 + offset].grid(False)
return axs

################################################################################################################
#############################################Training and Evaluation############################################
################################################################################################################

def train(trainLoader,valLoader,n_epochs,batch_s):
model=UNet.UNet(4,3)
model.train()
l1_loss=nn.L1Loss()
losses_training=[]
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

progress =tqdm(range(n_epochs),desc='progress')
for epoch in progress:
sum_loss = 0
for i,(patDat,seg) in enumerate(trainLoader):
predSeg=model.forward(patDat)
loss=l1_loss(predSeg,seg)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Save
torch.save(model.state_dict(), PATH)
evaluate(valLoader,model,batch_s)
#sum_loss+=loss.item()
# losses_training.append(sum_loss/len(dataSet))
#if epoch%5==0:
# axs = visualize_images(axs,patDat,seg,predSeg,epoch,every_epoche=5)

def evaluate(valLoader,model):
with torch.no_grad():
model.load_state_dict(torch.load(PATH))
model.eval()
l1_loss=nn.L1Loss()
val_loss = 0
for i,(patDat,seg) in enumerate(valLoader):
predSeg=model.forward(patDat)
loss = l1_loss(predSeg, seg)


################################################################################################################
#############################################Test###############################################################
################################################################################################################







################################################################################################################
#############################################Run Code###########################################################
################################################################################################################

if __name__ == '__main__':
batch_s=3
epoche=5
dataSet = DataLoaderPNG8()
trainDataLoader,testDataloader,valDataLoader=loadData(dataSet,batch_s)
train(trainDataLoader,valDataLoader,epoche,batch_s)


0 comments on commit 186f6bf

Please sign in to comment.