From 8279b4605cb1a8f8ed51ffc6fc8d23c4ccc258db Mon Sep 17 00:00:00 2001 From: jingpengw Date: Mon, 29 Aug 2022 11:25:05 -0400 Subject: [PATCH] semantic training --- README.md | 1 + environment.yml | 286 ++++++++++++++++++ neutorch/dataset/base.py | 13 +- neutorch/dataset/semantic.py | 104 +++++++ neutorch/dataset/synapses.py | 20 +- neutorch/{cli => train}/__init__.py | 0 neutorch/train/base.py | 212 +++++++++++++ .../post_synapses.py} | 0 .../pre_synapses.py} | 1 + neutorch/train/semantic.py | 39 +++ neutorch/{cli => train}/train_denoise.py | 0 setup.py | 7 +- 12 files changed, 663 insertions(+), 20 deletions(-) create mode 100644 environment.yml create mode 100644 neutorch/dataset/semantic.py rename neutorch/{cli => train}/__init__.py (100%) create mode 100644 neutorch/train/base.py rename neutorch/{cli/train_post_synapses.py => train/post_synapses.py} (100%) rename neutorch/{cli/train_pre_synapses.py => train/pre_synapses.py} (98%) create mode 100644 neutorch/train/semantic.py rename neutorch/{cli => train}/train_denoise.py (100%) diff --git a/README.md b/README.md index 5fb495d..d5083f0 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Neuron segmentation and synapse detection using PyTorch # Features - [x] Training using whole terabyte or even petabyte of image volume. - [x] Training using multiple version of image datasets as data augmentation. +- [x] Data augmentation without zero-filling. # Install python setup.py install diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..4ecbce4 --- /dev/null +++ b/environment.yml @@ -0,0 +1,286 @@ +name: pytorch +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=1_gnu + - backcall=0.2.0=pyhd3eb1b0_0 + - blas=1.0=mkl + - blas-devel=3.9.0=11_linux64_mkl + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.18.1=h7f8727e_0 + - ca-certificates=2021.10.26=h06a4308_2 + - certifi=2021.10.8=py38h06a4308_2 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - conda=4.11.0=py38h578d9bd_0 + - conda-package-handling=1.7.3=py38h27cfd23_1 + - cudatoolkit=11.1.74=h6bb024c_0 + - ffmpeg=4.2.2=h20bf706_0 + - freetype=2.11.0=h70c0345_0 + - giflib=5.2.1=h7b6447c_0 + - gmp=6.2.1=h2531618_2 + - gnutls=3.6.15=he1e5248_0 + - icu=58.2=he6710b0_3 + - intel-openmp=2021.4.0=h06a4308_3561 + - ipython=7.26.0=py38hb070fc8_0 + - jedi=0.18.0=py38h06a4308_1 + - jpeg=9b=h024ee3a_2 + - krb5=1.19.2=hac12032_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libarchive=3.4.2=h62408e4_0 + - libblas=3.9.0=11_linux64_mkl + - libcblas=3.9.0=11_linux64_mkl + - libcurl=7.80.0=h0b77cf5_0 + - libedit=3.1.20210910=h7f8727e_0 + - libev=4.33=h7f8727e_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1d223b6_11 + - libgomp=11.2.0=h1d223b6_11 + - libidn2=2.3.2=h7f8727e_0 + - liblapack=3.9.0=11_linux64_mkl + - liblapacke=3.9.0=11_linux64_mkl + - libnghttp2=1.46.0=hce63b2e_0 + - libopus=1.3.1=h7b6447c_0 + - libpng=1.6.37=hbc83047_0 + - libsolv=0.7.16=h8b12597_0 + - libssh2=1.9.0=h1ba5d50_1 + - libstdcxx-ng=11.2.0=he4da1e4_11 + - libtasn1=4.16.0=h27cfd23_0 + - libtiff=4.2.0=h85742a9_0 + - libunistring=0.9.10=h27cfd23_0 + - libuv=1.40.0=h7b6447c_0 + - libvpx=1.7.0=h439df22_0 + - libwebp=1.2.0=h89dd481_0 + - libwebp-base=1.2.0=h27cfd23_0 + - libxml2=2.9.12=h03d6c58_0 + - lz4-c=1.9.3=h295c915_1 + - mamba=0.7.3=py38h9709c9f_0 + - matplotlib-inline=0.1.2=pyhd3eb1b0_2 + - mkl=2021.3.0=h06a4308_520 + - mkl-devel=2021.3.0=h66538d2_520 + - mkl-include=2021.3.0=h06a4308_520 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - ncurses=6.3=h7f8727e_2 + - nettle=3.7.3=hbbd107a_1 + - ninja=1.10.2=py38hd09550d_3 + - olefile=0.46=pyhd3eb1b0_0 + - openh264=2.1.1=h4ff587b_0 + - openssl=1.1.1m=h7f8727e_0 + - parso=0.8.3=pyhd3eb1b0_0 + - pexpect=4.8.0=pyhd3eb1b0_3 + - pickleshare=0.7.5=pyhd3eb1b0_1003 + - pip=21.2.4=py38h06a4308_0 + - prompt-toolkit=3.0.20=pyhd3eb1b0_0 + - ptyprocess=0.7.0=pyhd3eb1b0_2 + - pybind11=2.8.1=py38hd09550d_0 + - pycosat=0.6.3=py38h7b6447c_1 + - pygments=2.10.0=pyhd3eb1b0_0 + - pysocks=1.7.1=py38h06a4308_0 + - python=3.8.12=h12debd9_0 + - python_abi=3.8=2_cp38 + - pytorch=1.9.1=py3.8_cuda11.1_cudnn8.0.5_0 + - readline=8.1.2=h7f8727e_1 + - reproc=14.2.1=h36c2ea0_0 + - reproc-cpp=14.2.1=h58526e2_0 + - ruamel_yaml=0.15.100=py38h27cfd23_0 + - setuptools=58.0.4=py38h06a4308_0 + - sqlite=3.37.0=hc218d9a_0 + - tk=8.6.11=h1ccaba5_0 + - torchaudio=0.9.1=py38 + - traitlets=5.1.1=pyhd3eb1b0_0 + - typing_extensions=3.10.0.2=pyh06a4308_0 + - wcwidth=0.2.5=pyhd3eb1b0_0 + - x264=1!157.20191217=h7b6447c_0 + - xtensor=0.24.0=h4bd325d_0 + - xtensor-blas=0.20.0=h4bd325d_0 + - xtensor-python=0.26.0=py38h2b96118_1 + - xtl=0.7.4=h4bd325d_0 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.11=h7f8727e_4 + - zstd=1.4.9=haebb681_0 + - pip: + - absl-py==0.12.0 + - argon2-cffi==21.3.0 + - astunparse==1.6.3 + - attrs==21.1.0 + - awscli==1.19.72 + - bleach==4.1.0 + - boto3==1.17.68 + - botocore==1.20.112 + - brotli==1.0.9 + - brotlipy==0.7.0 + - cachetools==4.2.2 + - cffi==1.14.5 + - chardet==4.0.0 + - click==8.1.3 + - cloud-files==4.9.1 + - cloud-volume==8.8.1 + - colorama==0.4.3 + - compressed-segmentation==2.2.0 + - compresso==3.0.0 + - connected-components-3d==3.2.0 + - crc32c==2.2 + - cryptography==3.4.7 + - cycler==0.10.0 + - cython==0.29.21 + - decorator==4.4.2 + - deflate==0.3.0 + - defusedxml==0.7.1 + - deprecated==1.2.12 + - detect-secrets==1.1.0 + - dijkstra3d==1.9.1 + - dill==0.3.3 + - docutils==0.15.2 + - dracopy==1.1.1 + - edt==2.1.0 + - einops==0.4.0 + - entrypoints==0.3 + - fasteners==0.16 + - fastremap==1.13.2 + - fill-voids==2.0.1 + - flatbuffers==2.0 + - fpzip==1.1.4 + - fvcore==0.1.5.post20220512 + - gast==0.5.3 + - gevent==21.1.2 + - google-api-core==1.26.3 + - google-apitools==0.5.32 + - google-auth==1.30.0 + - google-auth-oauthlib==0.4.4 + - google-cloud-core==1.6.0 + - google-cloud-storage==1.38.0 + - google-crc32c==1.1.2 + - google-pasta==0.2.0 + - google-resumable-media==1.2.0 + - googleapis-common-protos==1.53.0 + - greenlet==1.1.0 + - grpcio==1.37.1 + - h5py==3.2.1 + - httplib2==0.19.1 + - humanize==3.5.0 + - idna==2.10 + - imageio==2.9.0 + - inflection==0.5.1 + - iniconfig==1.1.1 + - iopath==0.1.10 + - ipykernel==6.7.0 + - ipywidgets==8.0.0b1 + - jinja2==3.0.3 + - jmespath==0.10.0 + - joblib==1.0.1 + - json5==0.9.5 + - jsonschema==3.2.0 + - jupyter==1.0.0 + - jupyter-client==7.1.2 + - jupyter-console==6.4.0 + - jupyter-core==4.9.1 + - jupyterlab-pygments==0.1.2 + - jupyterlab-widgets==2.0.0b1 + - keras==2.8.0rc1 + - keras-preprocessing==1.1.2 + - kimimaro==2.1.1 + - kiwisolver==1.3.1 + - libclang==12.0.0 + - markdown==3.3.4 + - matplotlib==3.4.1 + - mistune==0.8.4 + - multiprocess==0.70.11.1 + - nbclient==0.5.10 + - nbconvert==6.4.0 + - nbformat==5.1.3 + - nest-asyncio==1.5.4 + - networkx==2.5.1 + - neuroglancer==2.29 + - nibabel==3.2.1 + - nose==1.3.7 + - notebook==6.4.7 + - numpy==1.20.2 + - oauth2client==4.1.3 + - oauthlib==3.1.0 + - opencv-python==4.5.1.48 + - opt-einsum==3.3.0 + - orjson==3.5.2 + - packaging==20.9 + - pandas==1.4.3 + - pandocfilters==1.5.0 + - pathos==0.2.7 + - pillow==8.2.0 + - pluggy==0.13.1 + - portalocker==2.5.1 + - posix-ipc==1.0.5 + - pox==0.2.9 + - ppft==1.6.6.3 + - prometheus-client==0.12.0 + - protobuf==3.15.8 + - psutil==5.4.3 + - py==1.10.0 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pycparser==2.20 + - pynrrd==0.4.2 + - pyopenssl==20.0.1 + - pyparsing==2.4.7 + - pyrsistent==0.17.3 + - pysimdjson==3.2.0 + - pyspng-seunglab==1.0.0 + - pytest==6.2.4 + - python-dateutil==2.8.1 + - python-jsonschema-objects==0.3.14 + - pytz==2021.1 + - pywavelets==1.1.1 + - pyyaml==5.4.1 + - pyzmq==22.3.0 + - qtconsole==5.2.2 + - qtpy==2.0.0 + - requests==2.25.1 + - requests-oauthlib==1.3.0 + - rsa==4.7.2 + - s3transfer==0.4.0 + - scikit-image==0.18.1 + - scikit-learn==0.24.2 + - scipy==1.6.3 + - send2trash==1.8.1b0 + - simpleitk==1.2.4 + - simplejpeg==1.6.0 + - six==1.15.0 + - sockjs-tornado==1.0.7 + - tabulate==0.8.10 + - tenacity==7.0.0 + - tensorboard==2.8.0 + - tensorboard-data-server==0.6.0 + - tensorboard-plugin-wit==1.8.0 + - tensorflow==2.8.0rc1 + - tensorflow-io-gcs-filesystem==0.23.1 + - terminado==0.12.1 + - testpath==0.5.0 + - tf-estimator-nightly==2.8.0.dev2021122109 + - threadpoolctl==2.1.0 + - tifffile==2021.4.8 + - tinybrain==1.2.1 + - toml==0.10.2 + - torch==1.8.1 + - torch-tb-profiler==0.2.1 + - torchio==0.18.36 + - torchvision==0.9.1 + - tornado==6.1 + - tqdm==4.64.0 + - typing-extensions==3.7.4.3 + - urllib3==1.26.4 + - werkzeug==1.0.1 + - wheel==0.37.1 + - widgetsnbextension==4.0.0b1 + - wrapt==1.12.1 + - yacs==0.1.8 + - zmesh==0.5.0 + - zope-event==4.5.0 + - zope-interface==5.4.0 + - zstandard==0.15.2 +prefix: /mnt/home/jwu/anaconda3/envs/pytorch diff --git a/neutorch/dataset/base.py b/neutorch/dataset/base.py index 78b2f61..a0e32bd 100644 --- a/neutorch/dataset/base.py +++ b/neutorch/dataset/base.py @@ -1,3 +1,4 @@ +from abc import abstractproperty from typing import Union from functools import cached_property import math @@ -24,6 +25,12 @@ def worker_init_fn(worker_id: int): dataset.start = overall_start + worker_id * per_worker dataset.end = min(dataset.start + per_worker, overall_end) +def path_to_dataset_name(path: str, dataset_names: list): + for dataset_name in dataset_names: + if dataset_name in path: + return dataset_name + + class DatasetBase(torch.utils.data.IterableDataset): def __init__(self, @@ -48,8 +55,10 @@ def __init__(self, self.transform.shrink_size[:3] + \ self.transform.shrink_size[-3:] - # inherite this class and build the samples - self.samples = None + @cached_property + @abstractproperty + def samples(self): + pass @cached_property def sample_num(self): diff --git a/neutorch/dataset/semantic.py b/neutorch/dataset/semantic.py new file mode 100644 index 0000000..8073e67 --- /dev/null +++ b/neutorch/dataset/semantic.py @@ -0,0 +1,104 @@ +import os +from functools import cached_property + +from tqdm import tqdm + +from chunkflow.chunk import Chunk +from chunkflow.lib.cartesian_coordinate import Cartesian +from chunkflow.volume import Volume + +from neutorch.dataset.base import DatasetBase, path_to_dataset_name +from neutorch.dataset.ground_truth_sample import GroundTruthSample +from neutorch.dataset.transform import * + + +class SemanticDataset(DatasetBase): + def __init__(self, path_list: list, + sample_name_to_image_versions: dict, + patch_size: Cartesian = Cartesian(128, 128, 128)): + super().__init__(patch_size=patch_size) + + self.path_list = path_list + self.sample_name_to_image_versions = sample_name_to_image_versions + + self.vols = {} + for dataset_name, dir_list in sample_name_to_image_versions.items(): + vol_list = [] + for dir_path in dir_list: + vol = Volume.from_cloudvolume_path( + 'file://' + dir_path, + bounded = True, + fill_missing = False, + parallel = True, + green_threads = False, + ) + vol_list.append(vol) + self.vols[dataset_name] = vol_list + + self.compute_sample_weights() + self.setup_iteration_range() + + @cached_property + def samples(self): + samples = [] + for sem_path in tqdm(self.path_list): + assert os.path.exists(sem_path) + sem = Chunk.from_h5(sem_path) + + images = [] + dataset_name = path_to_dataset_name( + sem_path, + self.sample_name_to_image_versions.keys() + ) + for vol in self.vols[dataset_name]: + image = vol.cutout(sem.bbox) + images.append(image) + + target = (sem.array>0) + target = target.astype(np.float32) + sample = GroundTruthSample( + images, + target=target, + patch_size=self.patch_size_before_transform + ) + samples.append(sample) + + return samples + + def _prepare_transform(self): + self.transform = Compose([ + NormalizeTo01(probability=1.), + AdjustBrightness(), + AdjustContrast(), + Gamma(), + OneOf([ + Noise(), + GaussianBlur2D(), + ]), + BlackBox(), + Perspective2D(), + # RotateScale(probability=1.), + #DropSection(), + Flip(), + Transpose(), + MissAlignment(), + ]) + + +if __name__ == '__main__': + + from yacs.config import CfgNode + + cfg_file = '/mnt/home/jwu/wasp/jwu/15_rna_granule_net/11/config.yaml' + with open(cfg_file) as file: + cfg = CfgNode.load_cfg(file) + cfg.freeze() + + sd = SemanticDataset( + path_list=['/mnt/ceph/users/neuro/wasp_em/jwu/40_gt/13_wasp_sample3/vol_01700/rna_v1.h5'], + sample_name_to_image_versions=cfg.dataset.sample_name_to_image_versions, + patch_size=Cartesian(128, 128, 128), + ) + + # print(sd.samples) + diff --git a/neutorch/dataset/synapses.py b/neutorch/dataset/synapses.py index 81e138d..047f86e 100644 --- a/neutorch/dataset/synapses.py +++ b/neutorch/dataset/synapses.py @@ -1,11 +1,8 @@ -import os from time import time, sleep -from collections import OrderedDict from functools import cached_property from typing import Union, List import numpy as np -from scipy.stats import describe from chunkflow.lib.cartesian_coordinate import Cartesian, BoundingBox from chunkflow.lib.synapses import Synapses @@ -13,16 +10,9 @@ import torch -from neutorch.dataset.ground_truth_sample import PostSynapseGroundTruth -from neutorch.dataset.transform import * -from .base import DatasetBase -from .ground_truth_sample import GroundTruthSampleWithPointAnnotation - - -def syns_path_to_dataset_name(syns_path: str, dataset_names: list): - for dataset_name in dataset_names: - if dataset_name in syns_path: - return dataset_name +from .transform import * +from .base import DatasetBase, path_to_dataset_name +from .ground_truth_sample import GroundTruthSampleWithPointAnnotation, PostSynapseGroundTruth class SynapsesDatasetBase(DatasetBase): @@ -40,7 +30,7 @@ def __init__(self, vol = Volume.from_cloudvolume_path( 'file://' + dir_path, bounded = True, - fill_missing = True, + fill_missing = False, parallel=True, ) vol_list.append(vol) @@ -50,7 +40,7 @@ def __init__(self, def syns_path_to_images(self, syns_path: str, bbox: BoundingBox): images = [] - dataset_name = syns_path_to_dataset_name( + dataset_name = path_to_dataset_name( syns_path, self.sample_name_to_image_versions.keys() ) diff --git a/neutorch/cli/__init__.py b/neutorch/train/__init__.py similarity index 100% rename from neutorch/cli/__init__.py rename to neutorch/train/__init__.py diff --git a/neutorch/train/base.py b/neutorch/train/base.py new file mode 100644 index 0000000..c629816 --- /dev/null +++ b/neutorch/train/base.py @@ -0,0 +1,212 @@ +from abc import ABC, abstractproperty +from functools import cached_property +from glob import glob + +import random +import os +from time import time + +from yacs.config import CfgNode +import numpy as np + +from chunkflow.lib.cartesian_coordinate import Cartesian + +import torch +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from neutorch.dataset.patch import collate_batch + +from neutorch.model.IsoRSUNet import Model +from neutorch.model.io import save_chkpt, load_chkpt, log_tensor +from neutorch.loss import BinomialCrossEntropyWithLogits +from neutorch.dataset.base import worker_init_fn + + +class TrainerBase(ABC): + def __init__(self, cfg: CfgNode, + batch_size: int = 1) -> None: + if isinstance(cfg, str) and os.path.exists(cfg): + with open(cfg) as file: + cfg = CfgNode.load_cfg(file) + cfg.freeze() + + if cfg.system.seed is not None: + random.seed(cfg.system.seed) + + self.cfg = cfg + self.batch_size = batch_size + self.patch_size=Cartesian.from_collection(cfg.train.patch_size) + + self._split_path_list() + + @cached_property + def path_list(self): + glob_path = os.path.expanduser(self.cfg.dataset.glob_path) + path_list = glob(glob_path, recursive=True) + path_list = sorted(path_list) + print(f'path_list \n: {path_list}') + assert len(path_list) > 1 + assert len(path_list) % 2 == 0, \ + "the image and synapses should be paired." + return path_list + + def _split_path_list(self): + training_path_list = [] + validation_path_list = [] + for path in self.path_list: + assignment_flag = False + for validation_name in self.cfg.dataset.validation_names: + if validation_name in path: + validation_path_list.append(path) + assignment_flag = True + + for test_name in self.cfg.dataset.test_names: + if test_name in path: + assignment_flag = True + + if not assignment_flag: + training_path_list.append(path) + + print(f'split {len(self.path_list)} ground truth samples to {len(training_path_list)} training samples, {len(validation_path_list)} validation samples, and {len(self.path_list)-len(training_path_list)-len(validation_path_list)} test samples.') + self.training_path_list = training_path_list + self.validation_path_list = validation_path_list + + @cached_property + def model(self): + model = Model(self.cfg.model.in_channels, self.cfg.model.out_channels) + if torch.cuda.is_available(): + device = torch.device("cuda") + gpu_num = torch.cuda.device_count() + print("Let's use ", gpu_num, " GPUs!") + model = torch.nn.DataParallel( + model, + device_ids=list(range(gpu_num)), + dim=0, + ) + # we normally use one batch for each GPU + self.batch_size *= gpu_num + else: + device = torch.device("cpu") + + # note that we have to wrap the nn.DataParallel(model) before + # loading the model since the dictionary is changed after the wrapping + model = load_chkpt( + model, + self.cfg.train.output_dir, + self.cfg.train.iter_start) + print('send model to device: ', device) + model = model.to(device) + return model + + @cached_property + def optimizer(self): + return torch.optim.Adam( + self.model.parameters(), + lr=self.cfg.train.learning_rate + ) + + + @cached_property + def loss_module(self): + return BinomialCrossEntropyWithLogits() + + @cached_property + @abstractproperty + def training_dataset(self): + pass + + @cached_property + @abstractproperty + def validation_dataset(self): + pass + + @cached_property + def training_data_loader(self): + training_data_loader = DataLoader( + self.training_dataset, + #num_workers=self.cfg.system.cpus, + num_workers=1, + prefetch_factor=1, + drop_last=False, + multiprocessing_context='spawn', + collate_fn=collate_batch, + worker_init_fn=worker_init_fn, + batch_size=self.batch_size, + ) + return training_data_loader + + @cached_property + def validation_data_loader(self): + validation_data_loader = DataLoader( + self.validation_dataset, + num_workers=1, + prefetch_factor=2, + drop_last=False, + multiprocessing_context='spawn', + collate_fn=collate_batch, + batch_size=self.batch_size, + ) + return validation_data_loader + + @cached_property + def validation_data_iter(self): + validation_data_iter = iter(self.validation_data_loader) + return validation_data_iter + + @cached_property + def voxel_num(self): + return np.product(self.patch_size) * self.batch_size + + def __call__(self) -> None: + writer = SummaryWriter(log_dir=self.cfg.train.output_dir) + accumulated_loss = 0. + iter_idx = self.cfg.train.iter_start + for image, target in self.training_data_loader: + iter_idx += 1 + if iter_idx> self.cfg.train.iter_stop: + print('exceeds the maximum iteration: ', self.cfg.train.iter_stop) + return + + ping = time() + # print(f'preparing patch takes {round(time()-ping, 3)} seconds') + logits = self.model(image) + loss = self.loss_module(logits, target) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + accumulated_loss += loss.tolist() + print(f'iteration {iter_idx} takes {round(time()-ping, 3)} seconds.') + + if iter_idx % self.cfg.train.training_interval == 0 and iter_idx > 0: + per_voxel_loss = accumulated_loss / \ + self.cfg.train.training_interval / \ + self.voxel_num + + print(f'training loss {round(per_voxel_loss, 3)}') + accumulated_loss = 0. + predict = torch.sigmoid(logits) + writer.add_scalar('Loss/train', per_voxel_loss, iter_idx) + log_tensor(writer, 'train/image', image, iter_idx) + log_tensor(writer, 'train/prediction', predict, iter_idx) + log_tensor(writer, 'train/target', target, iter_idx) + + if iter_idx % self.cfg.train.validation_interval == 0 and iter_idx > 0: + fname = os.path.join(self.cfg.train.output_dir, f'model_{iter_idx}.chkpt') + print(f'save model to {fname}') + save_chkpt(self.model, self.cfg.train.output_dir, iter_idx, self.optimizer) + + print('evaluate prediction: ') + validation_image, validation_target = next(self.validation_data_iter) + + with torch.no_grad(): + validation_logits = self.model(validation_image) + validation_predict = torch.sigmoid(validation_logits) + validation_loss = self.loss_module(validation_logits, validation_target) + per_voxel_loss = validation_loss.tolist() / self.voxel_num + print(f'iter {iter_idx}: validation loss: {round(per_voxel_loss, 3)}') + writer.add_scalar('Loss/validation', per_voxel_loss, iter_idx) + log_tensor(writer, 'evaluate/image', validation_image, iter_idx) + log_tensor(writer, 'evaluate/prediction', validation_predict, iter_idx) + log_tensor(writer, 'evaluate/target', validation_target, iter_idx) + + writer.close() diff --git a/neutorch/cli/train_post_synapses.py b/neutorch/train/post_synapses.py similarity index 100% rename from neutorch/cli/train_post_synapses.py rename to neutorch/train/post_synapses.py diff --git a/neutorch/cli/train_pre_synapses.py b/neutorch/train/pre_synapses.py similarity index 98% rename from neutorch/cli/train_pre_synapses.py rename to neutorch/train/pre_synapses.py index b0e836e..a7a4195 100644 --- a/neutorch/cli/train_pre_synapses.py +++ b/neutorch/train/pre_synapses.py @@ -82,6 +82,7 @@ def main(config_file: str): else: device = torch.device("cpu") + # since we trained this model using DataParallel, we have to wrap it with DataParallel as well in the inference stage. # note that we have to wrap the nn.DataParallel(model) before # loading the model since the dictionary is changed after the wrapping model = load_chkpt(model, cfg.train.output_dir, cfg.train.iter_start) diff --git a/neutorch/train/semantic.py b/neutorch/train/semantic.py new file mode 100644 index 0000000..825f0cc --- /dev/null +++ b/neutorch/train/semantic.py @@ -0,0 +1,39 @@ +from functools import cached_property + +import click +from yacs.config import CfgNode + +from .base import TrainerBase +from neutorch.dataset.semantic import SemanticDataset + + +class SemanticTrainer(TrainerBase): + def __init__(self, cfg: CfgNode, batch_size: int = 1) -> None: + super().__init__(cfg, batch_size) + + @cached_property + def training_dataset(self): + return SemanticDataset( + self.training_path_list, + self.cfg.dataset.sample_name_to_image_versions, + patch_size=self.patch_size, + ) + + @cached_property + def validation_dataset(self): + return SemanticDataset( + self.validation_path_list, + self.cfg.dataset.sample_name_to_image_versions, + patch_size=self.patch_size, + ) + + +@click.command() +@click.option('--config-file', '-c', + type=click.Path(exists=True, dir_okay=False, file_okay=True, readable=True, resolve_path=True), + default='./config.yaml', + help = 'configuration file containing all the parameters.' +) +def main(config_file: str): + trainer = SemanticTrainer(config_file) + trainer() \ No newline at end of file diff --git a/neutorch/cli/train_denoise.py b/neutorch/train/train_denoise.py similarity index 100% rename from neutorch/cli/train_denoise.py rename to neutorch/train/train_denoise.py diff --git a/setup.py b/setup.py index d556900..e06bedc 100755 --- a/setup.py +++ b/setup.py @@ -12,9 +12,10 @@ packages=find_packages(exclude=['bin']), entry_points=''' [console_scripts] - neutrain-pre=neutorch.cli.train_pre_synapses:main - neutrain-denoise=neutorch.cli.train_denoise:main - neutrain-post=neutorch.cli.train_post_synapses:main + neutrain-sem=neutorch.train.semantic:main + neutrain-pre=neutorch.train.pre_synapses:main + neutrain-denoise=neutorch.train.denoise:main + neutrain-post=neutorch.train.post_synapses:main ''', classifiers=[ 'Development Status :: 4 - Beta',