Commit 186f6bf (1 parent: 2a4b287)
Showing 5 changed files with 311 additions and 0 deletions.
@@ -0,0 +1,62 @@
import SimpleITK as sitk
import glob
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch


class DataLoaderPNG8(Dataset):
    """Dataset for 8-bit PNG slices: stacks the four MR modalities per sample
    and scales intensities to [0, 1]."""

    def __init__(self):
        imageDict = getImagesOfFolder("../Data/PNG8")
        patientData = []
        segmentation = []
        n_samples = len(imageDict["flair"])
        for i in range(n_samples):
            # Stack the four modalities into one (4, H, W) array per sample.
            patientData.append(
                np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]]))
            # Add a channel axis to the segmentation mask: (H, W) -> (1, H, W).
            segmentation.append(np.array(imageDict["seg"][i])[np.newaxis, :, :])
        # Scale 8-bit intensities from [0, 255] to [0, 1].
        self.patientData = torch.from_numpy(np.array(patientData)).float() / 255.0
        # Cast masks to float so they can be fed directly to the L1 loss used in training.
        self.segmentation = torch.from_numpy(np.array(segmentation)).float()
        self.n_samples = n_samples

    def __getitem__(self, index):
        return self.patientData[index], self.segmentation[index]

    def __len__(self):
        return self.n_samples

class DataLoaderPNG16(Dataset):
    """Dataset for 16-bit PNG slices; intensities are kept unscaled."""

    def __init__(self):
        imageDict = getImagesOfFolder("../Data/PNG16")  # was "../Data/PNG8", presumably a copy-paste slip
        patientData = []
        segmentation = []
        n_samples = len(imageDict["flair"])
        for i in range(n_samples):
            patientData.append(
                np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]]))
            segmentation.append(np.array(imageDict["seg"][i])[np.newaxis, :, :])
        self.patientData = torch.from_numpy(np.array(patientData)).float()
        self.segmentation = torch.from_numpy(np.array(segmentation)).float()
        self.n_samples = n_samples

    def __getitem__(self, index):
        return self.patientData[index], self.segmentation[index]

    def __len__(self):
        return self.n_samples

def getImagesOfFolder(dir_path):
    """Read all PNGs in dir_path and group them by modality, inferred from the file-name suffix."""
    images = {"flair": [], "t1": [], "t1ce": [], "t2": [], "seg": []}
    for element in glob.glob(dir_path + "/*.png"):
        imageName = element.split("/")[-1]
        modality = imageName.split("_")[-1].split(".")[0]
        data = sitk.ReadImage(element)
        img = sitk.GetArrayFromImage(data)
        images[modality].append(img)
    return images
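
A minimal usage sketch for the dataset above (the folder layout, the file-name suffixes such as _flair.png and _seg.png, and the batch size are assumptions, not part of the commit):

# Hypothetical usage; assumes ../Data/PNG8 contains PNGs whose names end in
# _flair.png, _t1.png, _t1ce.png, _t2.png and _seg.png for each slice.
from torch.utils.data import DataLoader
from DataSet import DataLoaderPNG8

dataset = DataLoaderPNG8()                        # loads every slice into memory up front
loader = DataLoader(dataset, batch_size=4, shuffle=True)
images, masks = next(iter(loader))
print(images.shape)                               # (4, 4, H, W): batch x modality x height x width
print(masks.shape)                                # (4, 1, H, W): one-channel segmentation mask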
@@ -0,0 +1,108 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
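
A quick shape sanity check for the model above (a sketch; the 240x240 input size is only an example, and the module name UNet.py is inferred from the training script's import):

import torch
from UNet import UNet  # assumes the file above is saved as UNet.py

model = UNet(n_channels=4, n_classes=3)           # 4 MR modalities in, 3 output channels
dummy = torch.randn(1, 4, 240, 240)               # one sample, four modalities
out = model(dummy)
print(out.shape)                                  # torch.Size([1, 3, 240, 240]): per-pixel logits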
2 binary files not shown.
@@ -0,0 +1,141 @@
import torch.optim

import UNet
from DataSet import DataLoaderPNG8
from DataSet import DataLoaderPNG16
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


PATH = "../tmp/state_dict_model.pt"

################################################################################################################
######################################## Load Data #############################################################
################################################################################################################
def loadData(data, batch_s):
    # Split the dataset 60/20/20 into training, test and validation indices.
    indices = list(range(len(data)))
    split_index1 = int(len(data) * 0.6)
    split_index2 = int(len(data) * 0.2) + split_index1
    trainSampler = SubsetRandomSampler(indices[:split_index1])
    testSampler = SubsetRandomSampler(indices[split_index1:split_index2])
    valSampler = SubsetRandomSampler(indices[split_index2:])

    trainDataLoader = DataLoader(dataset=data,
                                 batch_size=batch_s,
                                 sampler=trainSampler,
                                 drop_last=True)

    testDataLoader = DataLoader(dataset=data,
                                batch_size=batch_s,
                                sampler=testSampler,
                                drop_last=True)

    valDataLoader = DataLoader(dataset=data,
                               batch_size=batch_s,
                               sampler=valSampler,
                               drop_last=True)

    #trainIterator = iter(trainDataLoader)
    #img, seg = next(trainIterator)
    #print("Image: ", img, "\n Segmentation: ", seg)

    return trainDataLoader, testDataLoader, valDataLoader

def visualize_images(axs, source, target, result, phase, epoch, every_epoche):
    # Plot the four input modalities, the ground-truth mask and the prediction of the
    # first sample in the batch on one row of the figure grid.
    row = epoch // every_epoche
    flair = source[0, 0, ...].cpu()
    t1 = source[0, 1, ...].cpu()
    t1ce = source[0, 2, ...].cpu()
    t2 = source[0, 3, ...].cpu()
    target = target[0, 0, ...].cpu()
    preds = result.detach()[0, 0, ...].cpu().float()

    panels = [(flair, 'flair'), (t1, 't1'), (t1ce, 't1ce'), (t2, 't2'),
              (target, 'segmentation'), (preds, 'prediction')]
    for col, (img, name) in enumerate(panels):
        axs[row][col].imshow(img)
        axs[row][col].set_title(str(epoch) + ': ' + phase + ' ' + name)
        axs[row][col].grid(False)
    return axs

################################################################################################################
######################################## Training and Evaluation ###############################################
################################################################################################################

def train(trainLoader, valLoader, n_epochs, batch_s):
    model = UNet.UNet(4, 3)
    model.train()
    l1_loss = nn.L1Loss()
    losses_training = []
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    progress = tqdm(range(n_epochs), desc='progress')
    for epoch in progress:
        sum_loss = 0
        for i, (patDat, seg) in enumerate(trainLoader):
            predSeg = model(patDat)
            loss = l1_loss(predSeg, seg)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
        losses_training.append(sum_loss / len(trainLoader))
        # Save a checkpoint and run validation once per epoch.
        torch.save(model.state_dict(), PATH)
        evaluate(valLoader, model)
        model.train()  # evaluate() switches the model to eval mode
        #if epoch % 5 == 0:
        #    axs = visualize_images(axs, patDat, seg, predSeg, 'train', epoch, every_epoche=5)

def evaluate(valLoader, model):
    with torch.no_grad():
        model.load_state_dict(torch.load(PATH))
        model.eval()
        l1_loss = nn.L1Loss()
        val_loss = 0
        for i, (patDat, seg) in enumerate(valLoader):
            predSeg = model(patDat)
            loss = l1_loss(predSeg, seg)
            val_loss += loss.item()
        return val_loss / max(len(valLoader), 1)

################################################################################################################
######################################## Test ##################################################################
################################################################################################################


################################################################################################################
######################################## Run Code ##############################################################
################################################################################################################

if __name__ == '__main__':
    batch_s = 3
    epoche = 5
    dataSet = DataLoaderPNG8()
    trainDataLoader, testDataLoader, valDataLoader = loadData(dataSet, batch_s)
    train(trainDataLoader, valDataLoader, epoche, batch_s)
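
For reference, loadData above slices the index list 60/20/20 before handing it to the samplers; a toy illustration of that arithmetic (the sample count of 10 is made up):

# Illustrative only: mirrors the split logic in loadData for a dataset of 10 samples.
indices = list(range(10))
split_index1 = int(10 * 0.6)                      # 6 -> training uses indices 0..5
split_index2 = int(10 * 0.2) + split_index1       # 8 -> test uses indices 6..7
print(indices[:split_index1])                     # [0, 1, 2, 3, 4, 5]
print(indices[split_index1:split_index2])         # [6, 7]
print(indices[split_index2:])                     # [8, 9] -> validation

Note that the index list is not shuffled before slicing, so the three subsets are contiguous blocks of the dataset; SubsetRandomSampler only randomizes the order within each block.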