New Model: PialNN/pialnn/1.0.0 #4

Closed
wants to merge 2 commits
33 changes: 33 additions & 0 deletions PialNN/pialnn/1.0.0/config.py
@@ -0,0 +1,33 @@
import argparse

def load_config():

    # args
    parser = argparse.ArgumentParser(description="PialNN")

    # data
    parser.add_argument('--data_path', default="./data/train/", type=str, help="path of the dataset")
    parser.add_argument('--hemisphere', default="lh", type=str, help="left or right hemisphere (lh or rh)")
    # model file
    parser.add_argument('--model', help="path to best model")
    # model
    parser.add_argument('--nc', default=128, type=int, help="num of channels")
    parser.add_argument('--K', default=5, type=int, help="kernel size")
    parser.add_argument('--n_scale', default=3, type=int, help="num of scales for image pyramid")
    parser.add_argument('--n_smooth', default=1, type=int, help="num of Laplacian smoothing layers")
    parser.add_argument('--lambd', default=1.0, type=float, help="Laplacian smoothing weights")
    # training
    parser.add_argument('--train_data_ratio', default=0.8, type=float, help="percentage of training data")
    parser.add_argument('--lr', default=1e-4, type=float, help="learning rate")
    parser.add_argument('--n_epoch', default=200, type=int, help="total training epochs")
    parser.add_argument('--ckpts_interval', default=10, type=int, help="save checkpoints every n epochs")
    parser.add_argument('--report_training_loss', default=True, type=bool, help="whether to report training loss")
    parser.add_argument('--save_model', default=True, type=bool, help="whether to save trained models")
    parser.add_argument('--save_mesh_train', default=False, type=bool, help="whether to save meshes during training")
    # evaluation
    parser.add_argument('--save_mesh_eval', default=False, type=bool, help="whether to save meshes during evaluation")
    parser.add_argument('--n_test_pts', default=150000, type=int, help="num of points sampled for evaluation")

    config = parser.parse_args()

    return config
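
A minimal usage sketch of how this config loader would typically be consumed from a training entry point (illustration only, not part of this diff; the train.py context is an assumption). Note that argparse's type=bool does not parse the string "False" as False, so the boolean flags above effectively keep their defaults unless changed in code.

# Hypothetical usage sketch (illustration only, not part of this PR).
from config import load_config

if __name__ == "__main__":
    config = load_config()
    # e.g. python train.py --data_path ./data/train/ --hemisphere lh --n_epoch 200
    print(config.data_path, config.hemisphere, config.nc, config.K, config.lr)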
105 changes: 105 additions & 0 deletions PialNN/pialnn/1.0.0/data/dataload.py
@@ -0,0 +1,105 @@
import os
import numpy as np
import torch
from tqdm import tqdm
import nibabel as nib
from torch.utils.data import Dataset


"""
volume: brain MRI volume
v_in: vertices of input white matter surface
f_in: faces of ground truth pial surface
v_gt: vertices of input white matter surface
f_gt: faces of ground truth pial surface
"""

class BrainData():
    def __init__(self, volume, v_in, v_gt, f_in, f_gt):
        self.v_in = torch.Tensor(v_in)
        self.v_gt = torch.Tensor(v_gt)
        self.f_in = torch.LongTensor(f_in)
        self.f_gt = torch.LongTensor(f_gt)
        self.volume = torch.Tensor(volume).unsqueeze(0)


class BrainDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        brain = self.data[i]
        return brain.volume, brain.v_gt, \
               brain.f_gt, brain.v_in, brain.f_in


def load_mri(path):

    brain = nib.load(path)
    brain_arr = brain.get_fdata()
    brain_arr = brain_arr / 255.

    # ====== change to your own transformation ======
    # transpose and clip the data to [192,224,192]
    brain_arr = brain_arr.transpose(1,2,0)
    brain_arr = brain_arr[::-1,:,:]
    brain_arr = brain_arr[:,:,::-1]
    brain_arr = brain_arr[32:-32, 16:-16, 32:-32]
    #================================================

    return brain_arr.copy()


def load_surf(path):
    v, f = nib.freesurfer.io.read_geometry(path)

    # ====== change to your own transformation ======
    # transpose and clip the data to [192,224,192]
    v = v[:,[0,2,1]]
    v[:,0] = v[:,0] - 32
    v[:,1] = - v[:,1] - 15
    v[:,2] = v[:,2] - 32

    # normalize to [-1, 1]
    v = v + 128
    v = (v - [96, 112, 96]) / 112
    f = f.astype(np.int32)
    #================================================

    return v, f


def load_data(data_path, hemisphere):
    """
    data_path: path of the dataset
    hemisphere: which hemisphere to load ("lh" or "rh")
    """

    subject_lists = sorted(os.listdir(data_path))

    dataset = []

    for i in tqdm(range(len(subject_lists))):

        subid = subject_lists[i]

        # load brain MRI
        volume = load_mri(data_path + subid + '/mri/orig.mgz')

        # load ground truth pial surface
        v_gt, f_gt = load_surf(data_path + subid + '/surf/' + hemisphere + '.pial')
        # v_gt, f_gt = load_surf(data_path + subid + '/surf/' + hemisphere + '.pial.deformed')

        # load input white matter surface
        v_in, f_in = load_surf(data_path + subid + '/surf/' + hemisphere + '.white')
        # v_in, f_in = load_surf(data_path + subid + '/surf/' + hemisphere + '.white.deformed')

        braindata = BrainData(volume=volume, v_gt=v_gt, f_gt=f_gt,
                              v_in=v_in, f_in=f_in)
        dataset.append(braindata)

    return dataset
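
A minimal sketch of how these loaders would typically be wired into training (illustration only, not part of this diff; the import path and DataLoader settings are assumptions). batch_size=1 is used because each subject's surfaces have a different number of vertices and cannot be stacked into a single tensor.

# Hypothetical usage sketch (illustration only, not part of this PR).
from torch.utils.data import DataLoader
from data.dataload import load_data, BrainDataset  # assumed import path

data = load_data(data_path="./data/train/", hemisphere="lh")
trainloader = DataLoader(BrainDataset(data), batch_size=1, shuffle=True)

for volume, v_gt, f_gt, v_in, f_in in trainloader:
    # volume: [1, 1, 192, 224, 192] after the clipping in load_mri
    # v_*: vertex coordinates normalized to [-1, 1]; f_*: face indices
    break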


Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
###
@@ -0,0 +1 @@
###
51 changes: 51 additions & 0 deletions PialNN/pialnn/1.0.0/docker/Dockerfile
@@ -0,0 +1,51 @@
FROM nvidia/cuda:12.2.0-devel-ubuntu20.04
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
ENV PATH=/opt/miniconda3/bin:$PATH
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1
ENV PYTHONIOENCODING=UTF-8
ENV PIPENV_VENV_IN_PROJECT=1
ENV JCC_JDK=/usr/lib/jvm/java-8-openjdk-amd64
RUN USE_CUDA=1
RUN CUDA_VERSION=12.2.0
RUN CUDNN_VERSION=8
RUN LINUX_DISTRO=ubuntu
RUN DISTRO_VERSION=20.04
RUN TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6"
RUN rm -f /etc/apt/apt.conf.d/docker-clean; \
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' \
> /etc/apt/apt.conf.d/keep-cache
RUN apt-get update && DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata && apt-get install -y --no-install-recommends \
build-essential \
ca-certificates \
ccache \
curl \
git \
wget \
cmake \
gfortran \
libspatialindex-dev
RUN rm -rf /var/lib/apt/lists/*
ENV PYTHON_VERSION=3.7
ENV CONDA_URL=https://repo.anaconda.com/miniconda/Miniconda3-py37_4.10.3-Linux-x86_64.sh
RUN curl -fsSL -v -o ~/miniconda.sh -O ${CONDA_URL} && \
chmod +x ~/miniconda.sh && \
~/miniconda.sh -b -p /opt/miniconda3

WORKDIR /app
COPY pialnn.requirements.txt .
COPY environment.yml .
RUN conda env update -f environment.yml
SHELL ["conda", "run", "-n", "base", "/bin/bash", "-c"]
#RUN pip cache purge
RUN pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
RUN pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
RUN pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
RUN pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
RUN pip install torch-geometric
RUN pip install rtree
RUN pip install -r pialnn.requirements.txt
RUN conda install -c conda-forge libspatialindex=1.9.3
RUN conda clean -a
ENTRYPOINT ["/bin/bash", "-l", "-c"]
93 changes: 93 additions & 0 deletions PialNN/pialnn/1.0.0/docker/environment.yml
@@ -0,0 +1,93 @@
name: base
channels:
- pytorch3d
- pytorch
- bottler
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- ca-certificates=2022.9.24=ha878542_0
- certifi=2022.9.24=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- cudatoolkit=10.2.89=hfd86e86_1
- dataclasses=0.8=pyhc8e2a94_3
- freetype=2.12.1=h4a9f257_0
- fvcore=0.1.5.post20220512=pyhd8ed1ab_0
- giflib=5.2.1=h7b6447c_0
- intel-openmp=2021.4.0=h06a4308_3561
- iopath=0.1.9=pyhd8ed1ab_0
- jpeg=9e=h7f8727e_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libdeflate=1.8=h7f8727e_5
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libpng=1.6.37=hbc83047_0
- libspatialindex=1.9.3=h9c3ff4c_4
- libstdcxx-ng=11.2.0=h1234567_1
- libtiff=4.4.0=hecacb30_0
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.4=h11a3e52_0
- libwebp-base=1.2.4=h5eee18b_0
- lz4-c=1.9.3=h295c915_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py37h7f8727e_0
- mkl_fft=1.3.1=py37hd3c417c_0
- mkl_random=1.2.2=py37h51133e4_0
- ncurses=6.3=h5eee18b_3
- ninja=1.10.2=h06a4308_5
- ninja-base=1.10.2=hd09550d_5
- numpy=1.21.5=py37h6c91a56_3
- numpy-base=1.21.5=py37ha15fc14_3
- nvidiacub=1.10.0=0
- openssl=1.1.1s=h7f8727e_0
- pillow=9.2.0=py37hace64e9_1
- portalocker=2.6.0=py37h89c1867_0
- python=3.7.13=haa1d7c7_1
- python_abi=3.7=2_cp37m
- pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0
- pytorch3d=0.6.2=py37_cu102_pyt170
- pyyaml=6.0=py37h540881e_4
- readline=8.2=h5eee18b_0
- setuptools=65.4.0=py37h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.39.3=h5082296_0
- tabulate=0.9.0=pyhd8ed1ab_1
- termcolor=2.0.1=pyhd8ed1ab_1
- tk=8.6.12=h1ccaba5_0
- torchvision=0.8.1=py37_cu102
- tqdm=4.64.1=pyhd8ed1ab_0
- typing_extensions=4.3.0=py37h06a4308_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.6=h5eee18b_0
- yacs=0.1.8=pyhd8ed1ab_0
- yaml=0.2.5=h7f98852_2
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.2=ha4553b6_0
- pip:
- charset-normalizer==2.1.1
- idna==3.4
- jinja2==3.1.2
- joblib==1.2.0
- markupsafe==2.1.1
- nvidia-ml-py3==7.352.0
- packaging==21.3
- pip==18.0
- pyparsing==3.0.9
- requests==2.28.1
- rtree==1.0.1
- scikit-learn==1.0.2
- scipy==1.7.3
- threadpoolctl==3.1.0
- torch-cluster==1.5.8
- torch-geometric==2.1.0.post1
- torch-scatter==2.0.5
- torch-sparse==0.6.8
- torch-spline-conv==1.2.0
- trimesh==3.15.8
- urllib3==1.26.12
47 changes: 47 additions & 0 deletions PialNN/pialnn/1.0.0/docker/pialnn.requirements.txt
@@ -0,0 +1,47 @@
certifi==2022.9.24
charset-normalizer==2.1.1
colorama==0.4.6
dataclasses==0.8
freesurfer-surface==2.0.0
freetype-py==2.3.0
fvcore==0.1.5.post20220512
idna==3.3
iopath==0.1.9
Jinja2==3.1.2
joblib==1.2.0
MarkupSafe==2.1.1
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
nibabel==3.2.1
nilearn==0.8.1
numpy==1.21.5
nvidia-ml-py3==7.352.0
packaging==21.3
Pillow==9.2.0
portalocker==2.6.0
PyOpenGL==3.1.0
pyparsing==3.0.9
pyrender==0.1.45
pytorch3d==0.6.2
PyYAML==6.0
requests==2.28.1
Rtree==1.0.1
scikit-learn==1.0.2
scipy==1.7.3
six==1.16.0
tabulate==0.9.0
termcolor==2.0.1
threadpoolctl==3.1.0
torch==1.7.0
torch-cluster==1.5.8
torch-geometric==2.1.0.post1
torch-scatter==2.0.5
torch-sparse==0.6.8
torch-spline-conv==1.2.0
torchvision==0.8.1
tqdm==4.64.1
trimesh==3.15.8
typing-extensions==4.3.0
urllib3==1.26.12
yacs==0.1.8