From 186f6bf3b1afdb1538ddd31b0bcc7ed858c71ce2 Mon Sep 17 00:00:00 2001 From: RikeEl <75080734+RikeEl@users.noreply.github.com> Date: Wed, 8 Sep 2021 13:19:55 +0200 Subject: [PATCH] Add files via upload --- Scripts/DataSet.py | 62 +++++++++ Scripts/UNet.py | 108 ++++++++++++++++ Scripts/__pycache__/DataSet.cpython-38.pyc | Bin 0 -> 2681 bytes Scripts/__pycache__/UNet.cpython-38.pyc | Bin 0 -> 3746 bytes Scripts/main.py | 141 +++++++++++++++++++++ 5 files changed, 311 insertions(+) create mode 100644 Scripts/DataSet.py create mode 100644 Scripts/UNet.py create mode 100644 Scripts/__pycache__/DataSet.cpython-38.pyc create mode 100644 Scripts/__pycache__/UNet.cpython-38.pyc create mode 100644 Scripts/main.py diff --git a/Scripts/DataSet.py b/Scripts/DataSet.py new file mode 100644 index 0000000..1257955 --- /dev/null +++ b/Scripts/DataSet.py @@ -0,0 +1,62 @@ +import SimpleITK as sitk +import glob +import numpy as np +from torch.utils.data import Dataset, DataLoader +import math +import torch + + +class DataLoaderPNG8(Dataset): + def __init__(self, ): + imageDict = getImagesOfFolder("../Data/PNG8") + patientData = [] + segmentation = [] + n_samples = len(imageDict["flair"]) + for i in range(n_samples): + patientData.append( + np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]])) + segmentation.append( + torch.from_numpy(np.array(imageDict["seg"][i])[:, :, np.newaxis]).permute(2, 0, 1).numpy()) + self.patientData = (torch.from_numpy(np.array(patientData)) - 0) / (255.0-0) + self.segmentation = torch.from_numpy(np.array(segmentation)) + self.n_samples = n_samples + + def __getitem__(self, index): + print(index) + return self.patientData[index], self.segmentation[index] + + def __len__(self): + return self.n_samples + + +class DataLoaderPNG16(Dataset): + def __init__(self, ): + imageDict = getImagesOfFolder("../Data/PNG8") + patientData = [] + segmentation = [] + n_samples = len(imageDict["flair"]) + for i in range(n_samples): + patientData.append( + np.array([imageDict["flair"][i], imageDict["t1"][i], imageDict["t1ce"][i], imageDict["t2"][i]])) + segmentation.append( + torch.from_numpy(np.array(imageDict["seg"][i])[:, :, np.newaxis]).permute(2, 0, 1).numpy()) + self.patientData = torch.from_numpy(np.array(patientData)) + self.segmentation = torch.from_numpy(np.array(segmentation)) + self.n_samples = n_samples + + def __getitem__(self, index): + return self.patientData[index], self.segmentation[index] + + def __len__(self): + return self.n_samples + + +def getImagesOfFolder(dir_path): + dict = {"flair": [], "t1": [], "t1ce": [], "t2": [], "seg": []} + for element in glob.glob(dir_path + "/*.png"): + imageNames = element.split("/")[-1] + modality = imageNames.split("_")[-1].split(".")[0] + data = sitk.ReadImage(element) + img = sitk.GetArrayFromImage(data) + dict[modality].append(img) + return dict diff --git a/Scripts/UNet.py b/Scripts/UNet.py new file mode 100644 index 0000000..f70d1bd --- /dev/null +++ b/Scripts/UNet.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) + + +class UNet(nn.Module): + def __init__(self, n_channels, n_classes, bilinear=True): + super(UNet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + + self.inc = DoubleConv(n_channels, 64) + self.down1 = Down(64, 128) + self.down2 = Down(128, 256) + self.down3 = Down(256, 512) + factor = 2 if bilinear else 1 + self.down4 = Down(512, 1024 // factor) + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) + self.outc = OutConv(64, n_classes) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + logits = self.outc(x) + return logits \ No newline at end of file diff --git a/Scripts/__pycache__/DataSet.cpython-38.pyc b/Scripts/__pycache__/DataSet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7effd0be8416cba9503f5f7af7d7f02ac280cca GIT binary patch literal 2681 zcmdT`OK%)S5bmD$K5VZIO6-Kk0uiu5TH^pwh!8T0No0{MQA{|D7LA6T-p$xMvzwl_ zNUYtH!;J$!fP;PX-{8WPBkB_pC;x?9!dJbH<8^rDNRPU@x~m`6)n8T3-I}oS17{VJ(~&3rShj zl8Tt$b#MmXV2$b_v|_D>rTw*7#f4D1Ow7$eoCw(=;mKm5vx-_=01@UQxCBou9Bxlk z#AS?Xp_}+PAkeeV_PgD>aiznt^l9%2YxXstJq?$d*?@98-} zm9rUivDZG|E0JN7(!RY{-bbIa`xlN^TeM>mEV870o{d;kIhNQY!AN!FfX*ZgcBr%- zSDVdcvc)o)LA%|4oJwt}744{%jsSN_+lA;Il944@>5xbdm9gIks|T}sz9-ar7WYK) zW%v4^j~CK*U*y`Aao!W!i-$vzC)&z~+Kr`*A8J<(WapkPcjX|9^6hN+Q2V*~E`E>} z+8+v;Z7ZSOiJ&r!RVs2tI_fGk$jAlIclD&lO|7sY|6okH}xVJRg=*3yoWOHfdr zBXS-@^Ynnp*TFsh`r6g@-9pHseJ9-#pSRaUv84vX_T4;f---(*WVC*3z5U^a$OQG; z*ZPC)sq4{f;M2c?-dTs!@?#2Py$u z@=dS@+!!AXsc`yiKZ<~pD5856=GJ`Y%yXxjDJ{G(?3>4#h$3B%qHK_C_ld7Y(f#eX zKat2c=$N;OG^P)vmzQ-SbU}nfR*WhSxxheHmE>Br%Dz5yB zo&Rrf*RPR)`QBMo< zVPoYo?mO`id2xF6GsHI`sM8k5h_@w;O@WhMi{!?QhhTcgU&4QK$6T>JSS_*6BIO|6 z-5gs;cSvtYH%M$_8%c~O?xY9K1LtEL33AUKAuWwvaLx!b#@>jJum*_|vyi@C8&TE| zS`GJak+b71Uu+KZ-ebOe_yk4sNb{zFQUk%i*B^YV-D22J6*57pw!+dK5ho@sh4VLr z`iv6Kb>tf}B0&UfLIe_RN!Tc!P?N$N;~p|nZnWVneWben)?-q@Ti-y^Y4)A27A zmb_=>6a>SeB)Yrg zX2^}z1qBq9fm|8rBS@Eei$2HP_9D*=x~*>|I^L#W#i^9TJZ!zGRX(5U-x_{`hwTGH57tzHnjT=XYVs9r1gYJD-n;} z6RYpRVoL`WXKL|bais@~H?=fi@nr*+#?;b;r72smw5FDp1|X_l5@s703P=ELNRagvp3G|(=s ze@p7ty{L>2d$}6o#-?TKmNWqhDKvqt?IIcM>sFeD@nMv~aiQD!xSZeh+rw1O@6Ney zgke00iXsfp**9-(9_GVjvxpC~B$bKU%pYevgM2JEA4El&sE<_sD2dC>hp|dWWwE)} zOUjMWzf=QjrB=KM!FZF~=C{Uu=th1sxxUEc#%!Unp!cU2=@hvgw1SXdu~Sy@GXU$H z_X4i2!FDdx8amC7)KAdYg1l;PKUa?;QOzBfWT z=p%UFl4h~GJ)Mn#<+){Yre2`eJ0wiFO!Q352s$sLnXK3F@wxx)WQ7*_qrswnilG3P zSIN^liAA1_J4z)zoS~#`+)VMIes=kbMP^^Y1NwBH*|w+17NcFdh1-6_@-bJ0Xjky^6OyKTm!Q@sQ3X&9T^hG4*a^6LwpxFb+i@DQxTr9VOTO9 zqWpe?JJ39BC;UBaVMoypq8}e>D~?LFK`*;CNs6OeS{!Zl9hvU$|Gn-$Q7_S$SCqtC zuoqOG2pa;nE?(p=YkucC9X@Fycyj{X{oxvU9bXfl&1+<}rO5VNjm#2IP$ACP&oCeM zSTLVu&IGJ)o3nsJA<^E4z$+elMX->{v^>g}`NL2lT1OeLYS*lCgBHar^9&`8yfMSP5gI^!!(ZI#D(BV=)eKqsCC+{tf}rp#~H=E z)1Z^$u~&HySzS&|br-(XpGn*!LGw}fNf7E4)vkJ<#9v7K6#}~x9WHBzBsx@YL-fVb zZM1E)-)eV|AEafmR8aFHMTB?f2jC_SzVw@guE{5D6S5^1eTO_;Rha~p9*ss}^!aK> zktJxWZFJf9Ziq2pC>$BU9{c#r4UnBfsXfAZMMIX5j%puooum)NNJ_agg@TGZ} h8PTe=|A+c$d