diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..1068f81
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+test/
+.git/
+*.tar.gz
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..d1bf892
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,42 @@
+FROM pytorch/pytorch
+
+RUN groupadd -r user && useradd -m --no-log-init -r -g user user
+
+RUN mkdir -p /opt/app /input /output \
+ && chown user:user /opt/app /input /output
+
+USER user
+WORKDIR /opt/app
+
+ENV PATH="/home/user/.local/bin:${PATH}"
+
+RUN python -m pip install --user -U pip && python -m pip install --user pip-tools
+COPY --chown=user:user nnUNet/ /opt/app/nnUNet/
+RUN python -m pip install -e nnUNet
+#RUN python -m pip uninstall -y scipy
+#RUN python -m pip install --user --upgrade scipy
+
+COPY --chown=user:user requirements.txt /opt/app/
+RUN python -m pip install --user -r requirements.txt
+
+
+# This is the checkpoint archive; adjust the filename below if your exported nnU-Net results zip is named differently
+COPY --chown=user:user nnUNetTrainer__nnUNetPlans__3d_fullres.zip /opt/algorithm/checkpoint/nnUNet/
+RUN python -c "import zipfile; import os; zipfile.ZipFile('/opt/algorithm/checkpoint/nnUNet/nnUNetTrainer__nnUNetPlans__3d_fullres.zip').extractall('/opt/algorithm/checkpoint/nnUNet/')"
+
+COPY --chown=user:user custom_algorithm.py /opt/app/
+COPY --chown=user:user process.py /opt/app/
+
+# COPY --chown=user:user weights /opt/algorithm/checkpoint
+ENV nnUNet_results="/opt/algorithm/checkpoint/"
+ENV nnUNet_raw="/opt/algorithm/nnUNet_raw_data_base"
+ENV nnUNet_preprocessed="/opt/algorithm/preproc"
+# ENV ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS=64(nope!)
+# ENV nnUNet_def_n_proc=1
+
+#ENTRYPOINT [ "python3", "-m", "process" ]
+
+ENV MKL_SERVICE_FORCE_INTEL=1
+
+# Launch the inference script, forwarding any arguments passed to the container
+ENTRYPOINT python -m process $0 $@
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e68b7fb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+# DMX Solution to HaNSeg Challenge
+
+Head and neck organ-at-risk segmentation from CT & MR images. A contribution to the grand challenge held at MICCAI 2023.
+
+Challenge URL: **[HaN-Seg 2023 challenge](https://han-seg2023.grand-challenge.org/)**
+
+This solution is based on:
+
+ - [ANTsPy](https://antspy.readthedocs.io/en/latest/)
+ - [nnUNetv2](https://github.com/MIC-DKFZ/nnUNet/)
+ - [Zhack47](https://github.com/Zhack47/HaNSeg-QuantIF)
+
+
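+## Building and exporting the container
+
+A typical workflow with the scripts added in this repository (image and archive names are the ones defined in
+`build.sh` and `export.sh`):
+
+```bash
+# build the Docker image
+./build.sh
+
+# build, then save and compress the image for upload as an algorithm container
+./export.sh   # produces HanSeg2023AlgorithmDMX.tar.gz
+```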
diff --git a/build.sh b/build.sh
new file mode 100755
index 0000000..b9c52ff
--- /dev/null
+++ b/build.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
+# docker build --no-cache -t hanseg2023algorithm "$SCRIPTPATH"
+docker build -t hanseg2023algorithm_dmx "$SCRIPTPATH"
diff --git a/custom_algorithm.py b/custom_algorithm.py
new file mode 100644
index 0000000..fb688f4
--- /dev/null
+++ b/custom_algorithm.py
@@ -0,0 +1,195 @@
+import ants
+import SimpleITK as sitk
+
+from evalutils import SegmentationAlgorithm
+
+import logging
+from pathlib import Path
+from typing import (
+ Optional,
+ Pattern,
+ Tuple,
+)
+
+from pandas import DataFrame
+from evalutils.exceptions import FileLoaderError, ValidationError
+from evalutils.validators import DataFrameValidator
+from evalutils.io import (
+ ImageLoader,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class HanSegUniquePathIndicesValidator(DataFrameValidator):
+ """
+    Validates that there is a matching number of CT and MR T1 paths
+ """
+
+ def validate(self, *, df: DataFrame):
+ try:
+ paths_ct = df["path_ct"].tolist()
+ paths_mrt1 = df["path_mrt1"].tolist()
+ paths = paths_ct + paths_mrt1
+ except KeyError:
+ raise ValidationError(
+ "Column `path_ct` or `path_mrt1` not found in DataFrame."
+ )
+
+ assert len(paths_ct) == len(
+ paths_mrt1
+ ), "The number of CT and MR images is not equal."
+
+
+class HanSegUniqueImagesValidator(DataFrameValidator):
+ """
+ Validates that each image in the set is unique
+ """
+
+ def validate(self, *, df: DataFrame):
+ try:
+ hashes_ct = df["hash_ct"].tolist()
+ hashes_mrt1 = df["hash_mrt1"].tolist()
+ hashes = hashes_ct + hashes_mrt1
+ except KeyError:
+ raise ValidationError(
+ "Column `hash_ct` or `hash_mrt1` not found in DataFrame."
+ )
+
+ if len(set(hashes)) != len(hashes):
+ raise ValidationError(
+ "The images are not unique, please submit a unique image for "
+ "each case."
+ )
+
+
+class Hanseg2023Algorithm(SegmentationAlgorithm):
+ def __init__(
+ self,
+ input_path=Path("/input/images/"),
+ output_path=Path("/output/images/head_neck_oar/"),
+ **kwargs,
+ ):
+ super().__init__(
+ validators=dict(
+ input_image=(
+ HanSegUniqueImagesValidator(),
+ HanSegUniquePathIndicesValidator(),
+ )
+ ),
+ input_path=input_path,
+ output_path=output_path,
+ **kwargs,
+ )
+
+    def _load_input_image(self, *, case) -> Tuple:
+        # Returns (ct_image, ct_path, mrt1_image, mrt1_path); the images are ANTs images, not SimpleITK
+ input_image_file_path_ct = case["path_ct"]
+ input_image_file_path_mrt1 = case["path_mrt1"]
+
+ input_image_file_loader = self._file_loaders["input_image"]
+ if not isinstance(input_image_file_loader, ImageLoader):
+ raise RuntimeError("The used FileLoader was not of subclass ImageLoader")
+
+ # Load the image for this case
+
+ #input_image_ct = input_image_file_loader.load_image(input_image_file_path_ct)
+ #input_image_mrt1 = input_image_file_loader.load_image(
+ # input_image_file_path_mrt1
+ #)
+        # Instead of loading the NRRD files as SimpleITK images, converting them to ANTs
+        # images for registration and then back to SimpleITK, we load them directly with
+        # ANTs. This was done back when the time limit was 5 minutes, to save a few seconds.
+        input_image_ct = ants.image_read(str(input_image_file_path_ct))
+        input_image_mrt1 = ants.image_read(str(input_image_file_path_mrt1))
+
+ # Check that it is the expected image
+ #if input_image_file_loader.hash_image(input_image_ct) != case["hash_ct"]:
+ # raise RuntimeError("CT image hashes do not match")
+ #if input_image_file_loader.hash_image(input_image_mrt1) != case["hash_mrt1"]:
+ # raise RuntimeError("MR image hashes do not match")
+
+ return (
+ input_image_ct,
+ input_image_file_path_ct,
+ input_image_mrt1,
+ input_image_file_path_mrt1,
+ )
+
+ def process_case(self, *, idx, case):
+ # Load and test the image for this case
+ (
+ input_image_ct,
+ input_image_file_path_ct,
+ input_image_mrt1,
+ input_image_file_path_mrt1,
+ ) = self._load_input_image(case=case)
+
+        # Run the OAR segmentation for this case
+        segmentation_mask = self.predict(
+            image_ct=input_image_ct, image_mrt1=input_image_mrt1
+        )
+
+ # Write resulting segmentation to output location
+ segmentation_path = self._output_path / input_image_file_path_ct.name.replace(
+ "_CT", "_seg"
+ )
+ self._output_path.mkdir(parents=True, exist_ok=True)
+        sitk.WriteImage(segmentation_mask, str(segmentation_path), True)
+
+ # Write segmentation file path to result.json for this case
+ return {
+ "outputs": [dict(type="metaio_image", filename=segmentation_path.name)],
+ "inputs": [
+ dict(type="metaio_ct_image", filename=input_image_file_path_ct.name),
+ dict(
+ type="metaio_mrt1_image", filename=input_image_file_path_mrt1.name
+ ),
+ ],
+ "error_messages": [],
+ }
+
+ def _load_cases(
+ self,
+ *,
+ folder: Path,
+ file_loader: ImageLoader,
+ file_filter: Optional[Pattern[str]] = None,
+ ) -> DataFrame:
+ cases = []
+
+ paths_ct = sorted(folder.glob("ct/*"), key=self._file_sorter_key)
+ paths_mrt1 = sorted(folder.glob("t1-mri/*"), key=self._file_sorter_key)
+
+ for pth_ct, pth_mr in zip(paths_ct, paths_mrt1):
+ if file_filter is None or (
+ file_filter.match(str(pth_ct)) and file_filter.match(str(pth_mr))
+ ):
+ try:
+ case_ct = file_loader.load(fname=pth_ct)[0]
+ case_mrt1 = file_loader.load(fname=pth_mr)[0]
+ new_cases = [
+ {
+ "hash_ct": case_ct["hash"],
+ "path_ct": case_ct["path"],
+ "hash_mrt1": case_mrt1["hash"],
+ "path_mrt1": case_mrt1["path"],
+ }
+ ]
+ except FileLoaderError:
+ logger.warning(
+ f"Could not load {pth_ct.name} or {pth_mr.name} using {file_loader}."
+ )
+ else:
+ cases += new_cases
+ else:
+ logger.info(
+ f"Skip loading {pth_ct.name} and {pth_mr.name} because it doesn't match {file_filter}."
+ )
+
+ if len(cases) == 0:
+ raise FileLoaderError(
+ f"Could not load any files in {folder} with " f"{file_loader}."
+ )
+
+ return DataFrame(cases)
diff --git a/export.sh b/export.sh
new file mode 100755
index 0000000..54abcec
--- /dev/null
+++ b/export.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+./build.sh
+
+docker save hanseg2023algorithm_dmx | gzip -c > HanSeg2023AlgorithmDMX.tar.gz
diff --git a/nnUNet/LICENSE b/nnUNet/LICENSE
new file mode 100644
index 0000000..8bbe09c
--- /dev/null
+++ b/nnUNet/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/nnUNet/documentation/__init__.py b/nnUNet/documentation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/documentation/benchmarking.md b/nnUNet/documentation/benchmarking.md
new file mode 100644
index 0000000..108a7f2
--- /dev/null
+++ b/nnUNet/documentation/benchmarking.md
@@ -0,0 +1,115 @@
+# nnU-Netv2 benchmarks
+
+Does your system run like it should? Is your epoch time longer than expected? What epoch times should you expect?
+
+Look no further for we have the solution here!
+
+## What does the nnU-Netv2 benchmark do?
+
+nnU-Net's benchmark trains models for 5 epochs. At the end, the fastest epoch will
+be noted down, along with the GPU name, torch version and cudnn version. You can find the benchmark output in the
+corresponding nnUNet_results subfolder (see example below). Don't worry, we also provide scripts to collect your
+results. Or you just start a benchmark and look at the console output. Everything is possible. Nothing is forbidden.
+
+The benchmark implementation revolves around two trainers:
+- `nnUNetTrainerBenchmark_5epochs` runs a regular training for 5 epochs. When completed, writes a .json file with the fastest
+epoch time as well as the GPU used and the torch and cudnn versions. Useful for speed testing the entire pipeline
+(data loading, augmentation, GPU training)
+- `nnUNetTrainerBenchmark_5epochs_noDataLoading` is the same, but it doesn't do any data loading or augmentation. It
+just presents dummy arrays to the GPU. Useful for checking pure GPU speed.
+
+## How to run the nnU-Netv2 benchmark?
+It's quite simple, actually. It looks just like a regular nnU-Net training.
+
+We provide reference numbers for some of the Medical Segmentation Decathlon datasets because they are easily
+accessible: [download here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). If it needs to be
+quick and dirty, focus on Tasks 2 and 4. Download and extract the data and convert them to the nnU-Net format with
+`nnUNetv2_convert_MSD_dataset`.
+Run `nnUNetv2_plan_and_preprocess` for them.
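+
+For Task 4 (Hippocampus), the preparation might look roughly like this (a sketch; the download path and dataset id
+are placeholders):
+
+```bash
+nnUNetv2_convert_MSD_dataset -i /path/to/Task04_Hippocampus
+nnUNetv2_plan_and_preprocess -d 4
+```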
+
+Then, for each dataset, run the following commands (only one per GPU! Or one after the other):
+
+```bash
+nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs
+nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs
+nnUNetv2_train DATASET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
+nnUNetv2_train DATASET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading
+```
+
+If you want to inspect the outcome manually, check (for example!) your
+`nnUNet_results/DATASET_NAME/nnUNetTrainerBenchmark_5epochs__nnUNetPlans__3d_fullres/fold_0/` folder for the `benchmark_result.json` file.
+
+Note that there can be multiple entries in this file if the benchmark was run on different GPU types, torch versions or cudnn versions!
+
+If you want to summarize your results like we did in our [results](#results), check the
+[summary script](../nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py). Here you need to change the
+torch version, cudnn version and dataset you want to summarize, then execute the script. You can find the exact
+values you need to put there in one of your `benchmark_result.json` files.
+
+## Results
+We have tested a variety of GPUs and summarized the results in a
+[spreadsheet](https://docs.google.com/spreadsheets/d/12Cvt_gr8XU2qWaE0XJk5jJlxMEESPxyqW0CWbQhTNNY/edit?usp=sharing).
+Note that you can select the torch and cudnn versions at the bottom! There may be comments in this spreadsheet. Read them!
+
+## Result interpretation
+
+Results are shown as epoch time in seconds. Lower is better (duh). Epoch times can fluctuate between runs, so as
+long as you are within like 5-10% of the numbers we report, everything should be dandy.
+
+If not, here is how you can try to find the culprit!
+
+The first thing to do is to compare the performance between the `nnUNetTrainerBenchmark_5epochs_noDataLoading` and
+`nnUNetTrainerBenchmark_5epochs` trainers. If the difference is about the same as we report in our spreadsheet, but
+both your numbers are worse, the problem is with your GPU:
+
+- Are you certain you compare the correct GPU? (duh)
+- If yes, then you might want to install PyTorch in a different way. Never `pip install torch`! Go to the
+[PyTorch installation](https://pytorch.org/get-started/locally/) page, select the most recent cuda version your
+system supports and only then copy and execute the correct command! Either pip or conda should work
+- If the problem is still not fixed, we recommend you try
+[compiling pytorch from source](https://github.com/pytorch/pytorch#from-source). It's more difficult but that's
+how we roll here at the DKFZ (at least the cool kids here).
+- Another thing to consider is to try exactly the same torch + cudnn version as we did in our spreadsheet.
+Sometimes newer versions can actually degrade performance and there might be bugs from time to time. Older versions
+are also often a lot slower!
+- Finally, some very basic things that could impact your GPU performance:
+ - Is the GPU cooled adequately? Check the temperature with `nvidia-smi`. Hot GPUs throttle performance in order to not self-destruct
+ - Is your OS using the GPU for displaying your desktop at the same time? If so then you can expect a performance
+ penalty (I dunno like 10% !?). That's expected and OK.
+ - Are other users using the GPU as well?
+
+
+If you see a large performance difference between `nnUNetTrainerBenchmark_5epochs_noDataLoading` (fast) and
+`nnUNetTrainerBenchmark_5epochs` (slow) then the problem might be related to data loading and augmentation. As a
+reminder, nnU-net does not use pre-augmented images (offline augmentation) but instead generates augmented training
+samples on the fly during training (no, you cannot switch it to offline). This requires that your system can do partial
+reads of the image files fast enough (SSD storage required!) and that your CPU is powerful enough to run the augmentations.
+
+Check the following:
+
+- [CPU bottleneck] How many CPU threads are running during the training? nnU-Net uses 12 processes for data augmentation by default.
+If you see those 12 running constantly during training, consider increasing the number of processes used for data
+augmentation (provided there is headroom on your CPU!). Increase the number until you see fewer active workers than
+you configured (or just set the number to 32 and forget about it). You can do so by setting the `nnUNet_n_proc_DA`
+environment variable (Linux: `export nnUNet_n_proc_DA=24`). Read [here](set_environment_variables.md) on how to do this.
+If your CPU does not support more processes (setting more processes than your CPU has threads makes
+no sense!) you are out of luck and in desperate need of a system upgrade!
+- [I/O bottleneck] If you don't see 12 (or nnUNet_n_proc_DA if you set it) processes running but your training times
+are still slow then open up `top` (sorry, Windows users. I don't know how to do this on Windows) and look at the value
+left of 'wa' in the row that begins
+with '%Cpu (s)'. If this is >1.0 (arbitrarily set threshold here, essentially look for unusually high 'wa'. In a
+healthy training 'wa' will be almost 0) then your storage cannot keep up with data loading. Make sure to set
+nnUNet_preprocessed to a folder that is located on an SSD. nvme is preferred over SATA. PCIe3 is enough. 3000MB/s
+sequential read recommended.
+- [funky stuff] Sometimes there is funky stuff going on, especially when batch sizes are large, files are small and
+patch sizes are small as well. As part of the data loading process, nnU-Net needs to open and close a file for each
+training sample. Now imagine a dataset like Dataset004_Hippocampus where for the 2d config we have a batch size of
+366 and we run 250 iterations in <10s on an A100. That's a lotta files per second (366 * 250 / 10 = 9150 files per second).
+Oof. If the files are on some network drive (even if it's nvme) then (probably) good night. The good news: nnU-Net
+has got you covered: add `export nnUNet_keep_files_open=True` to your .bashrc and the problem goes away. The neat
+part: it causes new problems if you are not allowed to have enough open files. You may have to increase the number
+of allowed open files. `ulimit -n` gives your current limit (Linux only). It should not be something like 1024.
+Increasing that to 65535 works well for me. See here for how to change these limits:
+[Link](https://kupczynski.info/posts/ubuntu-18-10-ulimits/)
+(works for Ubuntu 18, google for your OS!).
+
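+If data loading turns out to be the bottleneck, the environment tweaks mentioned above could end up in your `.bashrc`
+roughly like this (a sketch; the values are examples and should be adapted to your CPU and file limits):
+
+```bash
+# more data augmentation workers (only useful if your CPU has headroom)
+export nnUNet_n_proc_DA=24
+# keep preprocessed files open instead of reopening them for every training sample
+export nnUNet_keep_files_open=True
+
+# check the open-file limit; raise it (Linux) if nnUNet_keep_files_open is used
+ulimit -n
+ulimit -n 65535
+```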
diff --git a/nnUNet/documentation/changelog.md b/nnUNet/documentation/changelog.md
new file mode 100644
index 0000000..0b56e44
--- /dev/null
+++ b/nnUNet/documentation/changelog.md
@@ -0,0 +1,51 @@
+# What is different in v2?
+
+- We now support **hierarchical labels** (named regions in nnU-Net). For example, instead of training BraTS with the
+'edema', 'necrosis' and 'enhancing tumor' labels you can directly train it on the target areas 'whole tumor',
+'tumor core' and 'enhancing tumor'. See [here](region_based_training.md) for a detailed description + also have a look at the
+[BraTS 2021 conversion script](../nnunetv2/dataset_conversion/Dataset137_BraTS21.py).
+- Cross-platform support. Cuda, mps (Apple M1/M2) and of course CPU support! Simply select the device with
+`-device` in `nnUNetv2_train` and `nnUNetv2_predict`.
+- Unified trainer class: nnUNetTrainer. No messing around with cascaded trainer, DDP trainer, region-based trainer,
+ignore trainer etc. All default functionality is in there!
+- Supports more input/output data formats through ImageIO classes.
+- I/O formats can be extended by implementing new Adapters based on `BaseReaderWriter`.
+- The nnUNet_raw_cropped folder no longer exists -> saves disk space at no performance penalty. magic! (no jk the
+saving of cropped npz files was really slow, so it's actually faster to crop on the fly).
+- Preprocessed data and segmentation are stored in different files when unpacked. Seg is stored as int8 and thus
+takes 1/4 of the disk space per pixel (and I/O throughput) compared to v1.
+- Native support for multi-GPU (DDP) TRAINING.
+Multi-GPU INFERENCE should still be run with `CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] -num_parts Y -part_id X`.
+There is no cross-GPU communication in inference, so it doesn't make sense to add additional complexity with DDP.
+- All nnU-Net functionality is now also accessible via API. Check the corresponding entry point in `setup.py` to see
+what functions you need to call.
+- Dataset fingerprint is now explicitly created and saved in a json file (see nnUNet_preprocessed).
+
+- Complete overhaul of plans files (read also [this](explanation_plans_files.md)):
+ - Plans are now .json and can be opened and read more easily
+ - Configurations are explicitly named ("3d_fullres" , ...)
+ - Configurations can inherit from each other to make manual experimentation easier
+ - A ton of additional functionality is now included in and can be changed through the plans, for example normalization strategy, resampling etc.
+ - Stages of the cascade are now explicitly listed in the plans. 3d_lowres has 'next_stage' (which can also be a
+ list of configurations!). 3d_cascade_fullres has a 'previous_stage' entry. By manually editing plans files you can
+ now connect anything you want, for example 2d with 3d_fullres or whatever. Be wild! (But don't create cycles!)
+ - Multiple configurations can point to the same preprocessed data folder to save disk space. Careful! Only
+ configurations that use the same spacing, resampling, normalization etc. should share a data source! By default,
+ 3d_fullres and 3d_cascade_fullres share the same data
+ - Any number of configurations can be added to the plans (remember to give them a unique "data_identifier"!)
+
+Folder structures are different and more user-friendly:
+- nnUNet_preprocessed
+ - By default, preprocessed data is now saved as: `nnUNet_preprocessed/DATASET_NAME/PLANS_IDENTIFIER_CONFIGURATION` to clearly link them to their corresponding plans and configuration
+ - Name of the folder containing the preprocessed images can be adapted with the `data_identifier` key.
+- nnUNet_results
+ - Results are now sorted as follows: DATASET_NAME/TRAINERCLASS__PLANSIDENTIFIER__CONFIGURATION/FOLD
+
+## What other changes are planned and not yet implemented?
+- Integration into MONAI (together with our friends at Nvidia)
+- New pretrained weights for a large number of datasets (coming very soon)
+
+
+[//]: # (- nnU-Net now also natively supports an **ignore label**. Pixels with this label will not contribute to the loss. )
+
+[//]: # (Use this to learn from sparsely annotated data, or excluding irrelevant areas from training. Read more [here](ignore_label.md).)
\ No newline at end of file
diff --git a/nnUNet/documentation/convert_msd_dataset.md b/nnUNet/documentation/convert_msd_dataset.md
new file mode 100644
index 0000000..4c4ee48
--- /dev/null
+++ b/nnUNet/documentation/convert_msd_dataset.md
@@ -0,0 +1,3 @@
+Use `nnUNetv2_convert_MSD_dataset`.
+
+Read `nnUNetv2_convert_MSD_dataset -h` for usage instructions.
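+
+A minimal invocation might look like this (a sketch; the path is a placeholder for wherever you extracted the MSD
+download):
+
+```bash
+nnUNetv2_convert_MSD_dataset -i /path/to/Task05_Prostate
+```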
\ No newline at end of file
diff --git a/nnUNet/documentation/dataset_format.md b/nnUNet/documentation/dataset_format.md
new file mode 100644
index 0000000..e11d8b2
--- /dev/null
+++ b/nnUNet/documentation/dataset_format.md
@@ -0,0 +1,232 @@
+# nnU-Net dataset format
+The only way to bring your data into nnU-Net is by storing it in a specific format. Due to nnU-Net's roots in the
+[Medical Segmentation Decathlon](http://medicaldecathlon.com/) (MSD), its dataset format is heavily inspired by, but
+has since diverged from, the format used in the MSD (see also [here](#how-to-use-decathlon-datasets)).
+
+Datasets consist of three components: raw images, corresponding segmentation maps and a dataset.json file specifying
+some metadata.
+
+If you are migrating from nnU-Net v1, read [this](#how-to-use-nnu-net-v1-tasks) to convert your existing Tasks.
+
+
+## What do training cases look like?
+Each training case is associated with an identifier = a unique name for that case. This identifier is used by nnU-Net to
+connect images with the correct segmentation.
+
+A training case consists of images and their corresponding segmentation.
+
+**Images** is plural because nnU-Net supports arbitrarily many input channels. In order to be as flexible as possible,
+nnU-net requires each input channel to be stored in a separate image (with the sole exception being RGB natural
+images). So these images could for example be a T1 and a T2 MRI (or whatever else you want). The different input
+channels MUST have the same geometry (same shape, spacing (if applicable) etc.) and
+must be co-registered (if applicable). Input channels are identified by nnU-Net via a suffix: a four-digit integer at the end
+of the filename. Image files must therefore follow the naming convention: {CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}.
+Hereby, XXXX is the 4-digit modality/channel identifier (should be unique for each modality/channel, e.g., “0000” for T1, “0001” for
+T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.png, .nii.gz, ...). See below for concrete examples.
+The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details).
+
+Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier.
+Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the [road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example).
+
+**Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are
+integer maps with each value representing a semantic class. The background must be 0. If there is no background, then
+do not use the label 0 for something else! Integer values of your semantic classes must be consecutive (0, 1, 2, 3,
+...). Of course, not all labels have to be present in each training case. Segmentations are saved as {CASE_IDENTIFIER}.{FILE_ENDING}.
+
+Within a training case, all image geometries (input channels, corresponding segmentation) must match. Between training
+cases, they can of course differ. nnU-Net takes care of that.
+
+Important: The input channels must be consistent! Concretely, **all images need the same input channels in the same
+order and all input channels have to be present every time**. This is also true for inference!
+
+
+## Supported file formats
+nnU-Net expects the same file format for images and segmentations! These will also be used for inference. For now, it
+is thus not possible to train .png and then run inference on .jpg.
+
+One big change in nnU-Net V2 is the support of multiple input file types. Gone are the days of converting everything to .nii.gz!
+This is implemented by abstracting the input and output of images + segmentations through `BaseReaderWriter`. nnU-Net
+comes with a broad collection of Readers+Writers and you can even add your own to support your data format!
+See [here](../nnunetv2/imageio/readme.md).
+
+As a nice bonus, nnU-Net now also natively supports 2D input images and you no longer have to mess around with
+conversions to pseudo 3D niftis. Yuck. That was disgusting.
+
+Note that internally (for storing and accessing preprocessed images) nnU-Net will use its own file format, irrespective
+of what the raw data was provided in! This is for performance reasons.
+
+
+By default, the following file formats are supported:
+- NaturalImage2DIO: .png, .bmp, .tif
+- NibabelIO: .nii.gz, .nrrd, .mha
+- NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS!
+- SimpleITKIO: .nii.gz, .nrrd, .mha
+- Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information,
+nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains three numbers
+(no units, no comma. Just separated by whitespace), one for each dimension.
+
+
+The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK
+support more than the three given here. The file endings given here are just the ones we tested!
+
+IMPORTANT: nnU-Net can only be used with file formats that use lossless (or no) compression! Because the file
+format is defined for an entire dataset (and not separately for images and segmentations, this could be a todo for
+the future), we must ensure that there are no compression artifacts that destroy the segmentation maps. So no .jpg and
+the likes!
+
+## Dataset folder structure
+Datasets must be located in the `nnUNet_raw` folder (which you either define when installing nnU-Net or export/set every
+time you intend to run nnU-Net commands!).
+Each segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three digit
+integer, and a dataset name (which you can freely choose): For example, Dataset005_Prostate has 'Prostate' as dataset name and
+the dataset id is 5. Datasets are stored in the `nnUNet_raw` folder like this:
+
+ nnUNet_raw/
+ ├── Dataset001_BrainTumour
+ ├── Dataset002_Heart
+ ├── Dataset003_Liver
+ ├── Dataset004_Hippocampus
+ ├── Dataset005_Prostate
+ ├── ...
+
+Within each dataset folder, the following structure is expected:
+
+ Dataset001_BrainTumour/
+ ├── dataset.json
+ ├── imagesTr
+ ├── imagesTs # optional
+ └── labelsTr
+
+
+When adding your custom dataset, take a look at the [dataset_conversion](../nnunetv2/dataset_conversion) folder and
+pick an id that is not already taken. IDs 001-010 are for the Medical Segmentation Decathlon.
+
+- **imagesTr** contains the images belonging to the training cases. nnU-Net will perform pipeline configuration, training with
+cross-validation, as well as finding postprocessing and the best ensemble using this data.
+- **imagesTs** (optional) contains the images that belong to the test cases. nnU-Net does not use them! This could just
+be a convenient location for you to store these images. Remnant of the Medical Segmentation Decathlon folder structure.
+- **labelsTr** contains the images with the ground truth segmentation maps for the training cases.
+- **dataset.json** contains metadata of the dataset.
+
+The scheme introduced [above](#what-do-training-cases-look-like) results in the following folder structure. Given
+is an example for the first Dataset of the MSD: BrainTumour. This dataset has four input channels: FLAIR (0000),
+T1w (0001), T1gd (0002) and T2w (0003). Note that the imagesTs folder is optional and does not have to be present.
+
+ nnUNet_raw/Dataset001_BrainTumour/
+ ├── dataset.json
+ ├── imagesTr
+ │ ├── BRATS_001_0000.nii.gz
+ │ ├── BRATS_001_0001.nii.gz
+ │ ├── BRATS_001_0002.nii.gz
+ │ ├── BRATS_001_0003.nii.gz
+ │ ├── BRATS_002_0000.nii.gz
+ │ ├── BRATS_002_0001.nii.gz
+ │ ├── BRATS_002_0002.nii.gz
+ │ ├── BRATS_002_0003.nii.gz
+ │ ├── ...
+ ├── imagesTs
+ │ ├── BRATS_485_0000.nii.gz
+ │ ├── BRATS_485_0001.nii.gz
+ │ ├── BRATS_485_0002.nii.gz
+ │ ├── BRATS_485_0003.nii.gz
+ │ ├── BRATS_486_0000.nii.gz
+ │ ├── BRATS_486_0001.nii.gz
+ │ ├── BRATS_486_0002.nii.gz
+ │ ├── BRATS_486_0003.nii.gz
+ │ ├── ...
+ └── labelsTr
+ ├── BRATS_001.nii.gz
+ ├── BRATS_002.nii.gz
+ ├── ...
+
+Here is another example of the second dataset of the MSD, which has only one input channel:
+
+ nnUNet_raw/Dataset002_Heart/
+ ├── dataset.json
+ ├── imagesTr
+ │ ├── la_003_0000.nii.gz
+ │ ├── la_004_0000.nii.gz
+ │ ├── ...
+ ├── imagesTs
+ │ ├── la_001_0000.nii.gz
+ │ ├── la_002_0000.nii.gz
+ │ ├── ...
+ └── labelsTr
+ ├── la_003.nii.gz
+ ├── la_004.nii.gz
+ ├── ...
+
+Remember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also
+make sure that all your data is co-registered!
+
+See also [dataset format inference](dataset_format_inference.md)!!
+
+## dataset.json
+The dataset.json contains metadata that nnU-Net needs for training. We have greatly reduced the number of required
+fields since version 1!
+
+Here is what the dataset.json should look like, using Dataset005_Prostate from the MSD as an example:
+
+ {
+ "channel_names": { # formerly modalities
+ "0": "T2",
+ "1": "ADC"
+ },
+ "labels": { # THIS IS DIFFERENT NOW!
+ "background": 0,
+ "PZ": 1,
+ "TZ": 2
+ },
+ "numTraining": 32,
+ "file_ending": ".nii.gz"
+ "overwrite_image_reader_writer": "SimpleITKIO" # optional! If not provided nnU-Net will automatically determine the ReaderWriter
+ }
+
+The channel_names determine the normalization used by nnU-Net. If a channel is marked as 'CT', then a global
+normalization based on the intensities in the foreground pixels will be used. If it is something else, per-channel
+z-scoring will be used. Refer to the methods section in [our paper](https://www.nature.com/articles/s41592-020-01008-z)
+for more details. nnU-Net v2 introduces a few more normalization schemes to
+choose from and allows you to define your own, see [here](explanation_normalization.md) for more information.
+
+Important changes relative to nnU-Net v1:
+- "modality" is now called "channel_names" to remove strong bias to medical images
+- labels are structured differently (name -> int instead of int -> name). This was needed to support [region-based training](region_based_training.md)
+- "file_ending" is added to support different input file types
+- "overwrite_image_reader_writer" optional! Can be used to specify a certain (custom) ReaderWriter class that should
+be used with this dataset. If not provided, nnU-Net will automatically determine the ReaderWriter
+- "regions_class_order" only used in [region-based training](region_based_training.md)
+
+There is a utility with which you can generate the dataset.json automatically. You can find it
+[here](../nnunetv2/dataset_conversion/generate_dataset_json.py).
+See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation!
+
+## How to use nnU-Net v1 Tasks
+If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`!
+
+Example for migrating a nnU-Net v1 Task:
+```bash
+nnUNetv2_convert_old_nnUNet_dataset /media/isensee/raw_data/nnUNet_raw_data_base/nnUNet_raw_data/Task027_ACDC Dataset027_ACDC
+```
+Use `nnUNetv2_convert_old_nnUNet_dataset -h` for detailed usage instructions.
+
+
+## How to use decathlon datasets
+See [convert_msd_dataset.md](convert_msd_dataset.md)
+
+## How to use 2D data with nnU-Net
+2D is now natively supported (yay!). See [here](#supported-file-formats) as well as the example dataset in this
+[script](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py).
+
+
+## How to update an existing dataset
+When updating a dataset it is best practice to remove the preprocessed data in `nnUNet_preprocessed/DatasetXXX_NAME`
+to ensure a fresh start. Then replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess`. Optionally,
+also remove the results from old trainings.
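+
+In shell terms this could look roughly as follows (a sketch; the dataset name/id are placeholders and
+`--verify_dataset_integrity` is optional but useful after changing the raw data):
+
+```bash
+rm -rf $nnUNet_preprocessed/Dataset123_MyDataset
+# optionally also: rm -rf $nnUNet_results/Dataset123_MyDataset
+# replace the files in $nnUNet_raw/Dataset123_MyDataset, then:
+nnUNetv2_plan_and_preprocess -d 123 --verify_dataset_integrity
+```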
+
+# Example dataset conversion scripts
+In the `dataset_conversion` folder (see [here](../nnunetv2/dataset_conversion)) are multiple example scripts for
+converting datasets into nnU-Net format. These scripts cannot be run as they are (you need to open them and change
+some paths) but they are excellent examples for you to learn how to convert your own datasets into nnU-Net format.
+Just pick the dataset that is closest to yours as a starting point.
+The list of dataset conversion scripts is continually updated. If you find that some publicly available dataset is
+missing, feel free to open a PR to add it!
diff --git a/nnUNet/documentation/dataset_format_inference.md b/nnUNet/documentation/dataset_format_inference.md
new file mode 100644
index 0000000..503d431
--- /dev/null
+++ b/nnUNet/documentation/dataset_format_inference.md
@@ -0,0 +1,39 @@
+# Data format for Inference
+Read the documentation on the overall [data format](dataset_format.md) first!
+
+The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly
+the same format as in the imagesTr folder**). As before, the filenames must start with a
+unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets:
+
+1) Task005_Prostate:
+
+ This task has 2 modalities, so the files in the input folder must look like this:
+
+ input_folder
+ ├── prostate_03_0000.nii.gz
+ ├── prostate_03_0001.nii.gz
+ ├── prostate_05_0000.nii.gz
+ ├── prostate_05_0001.nii.gz
+ ├── prostate_08_0000.nii.gz
+ ├── prostate_08_0001.nii.gz
+ ├── ...
+
+ _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the
+dataset.json), exactly the same as was used for training.
+
+2) Task002_Heart:
+
+ imagesTs
+ ├── la_001_0000.nii.gz
+ ├── la_002_0000.nii.gz
+ ├── la_006_0000.nii.gz
+ ├── ...
+
+ Task002 only has one modality, so each case only has one _0000.nii.gz file.
+
+
+The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier).
+
+Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training
+(and as was specified in 'file_ending' in the dataset.json)!
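+
+With the input folder prepared like this, inference might be launched along these lines (a sketch; dataset id,
+configuration and fold are placeholders):
+
+```bash
+nnUNetv2_predict -i /path/to/input_folder -o /path/to/output_folder -d 5 -c 3d_fullres -f 0
+```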
+
\ No newline at end of file
diff --git a/nnUNet/documentation/explanation_normalization.md b/nnUNet/documentation/explanation_normalization.md
new file mode 100644
index 0000000..ed2b897
--- /dev/null
+++ b/nnUNet/documentation/explanation_normalization.md
@@ -0,0 +1,45 @@
+# Intensity normalization in nnU-Net
+
+The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`)
+entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on
+foreground intensities are supported. However, there have been a few additions as well.
+
+Reminder: The `channel_names` entry typically looks like this:
+
+ "channel_names": {
+ "0": "T2",
+ "1": "ADC"
+ },
+
+It has as many entries as there are input channels for the given dataset.
+
+To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization
+scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels!
+If you enter a channel name that is not in the following list, the default (`zscore`) will be used.
+
+Here is a list of currently available normalization schemes:
+
+- `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the
+background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and
+99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division with the
+standard deviation. The normalization that is applied is the same for each training case (for this input channel).
+The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the
+corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT
+images and ADC maps.
+- `noNorm` : do not perform any normalization at all
+- `rescale_to_0_1`: rescale the intensities to [0, 1]
+- `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1]
+- `zscore`/anything else: perform z-scoring (subtract mean and standard deviation) separately for each train case
+
+**Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If
+you deviate from that path, make sure to benchmark whether that actually improves results!
+
+# How to implement custom normalization strategies?
+- Head over to nnunetv2/preprocessing/normalization
+- implement a new image normalization class by deriving from ImageNormalization
+- register it in nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping.
+This is where you specify a channel name that should be associated with it
+- use it by specifying the correct channel_name
+
+Normalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme
+that gets multiple channels as input to be used jointly!
\ No newline at end of file
diff --git a/nnUNet/documentation/explanation_plans_files.md b/nnUNet/documentation/explanation_plans_files.md
new file mode 100644
index 0000000..00f1216
--- /dev/null
+++ b/nnUNet/documentation/explanation_plans_files.md
@@ -0,0 +1,185 @@
+# Modifying the nnU-Net Configurations
+
+nnU-Net provides unprecedented out-of-the-box segmentation performance for essentially any dataset we have evaluated
+it on. That said, there is always room for improvements. A fool-proof strategy for squeezing out the last bit of
+performance is to start with the default nnU-Net, and then further tune it manually to a concrete dataset at hand.
+**This guide is about changes to the nnU-Net configuration you can make via the plans files. It does not cover code
+extensions of nnU-Net. For that, take a look [here](extending_nnunet.md)**
+
+In nnU-Net V2, plans files are SO MUCH MORE powerful than they were in v1. There are a lot more knobs that you can
+turn without resorting to hacky solutions or even having to touch the nnU-Net code at all! And as an added bonus:
+plans files are now also .json files and no longer require users to fiddle with pickle. Just open them in your text
+editor of choice!
+
+If overwhelmed, look at our [Examples](#examples)!
+
+# plans.json structure
+
+Plans have global and local settings. Global settings are applied to all configurations in that plans file while
+local settings are attached to a specific configuration.
+
+## Global settings
+
+- `foreground_intensity_properties_by_modality`: Intensity statistics of the foreground regions (all labels except
+background and ignore label), computed over all training cases. Used by [CT normalization scheme](explanation_normalization.md).
+- `image_reader_writer`: Name of the image reader/writer class that should be used with this dataset. You might want
+to change this if, for example, you would like to run inference with files that have a different file format. The
+class that is named here must be located in nnunetv2.imageio!
+- `label_manager`: The name of the class that does label handling. Take a look at
+nnunetv2.utilities.label_handling.LabelManager to see what it does. If you decide to change it, place your version
+in nnunetv2.utilities.label_handling!
+- `transpose_forward`: nnU-Net transposes the input data so that the axes with the highest resolution (lowest spacing)
+come last. This is because the 2D U-Net operates on the trailing dimensions (more efficient slicing due to internal
+memory layout of arrays). Future work might move this setting to affect only individual configurations.
+- `transpose_backward`: the axis ordering (as passed to numpy.transpose) that inverts `transpose_forward`
+- \[`original_median_shape_after_transp`\]: just here for your information
+- \[`original_median_spacing_after_transp`\]: just here for your information
+- \[`plans_name`\]: do not change. Used internally
+- \[`experiment_planner_used`\]: just here as metadata so that we know what planner originally generated this file
+- \[`dataset_name`\]: do not change. This is the dataset these plans are intended for
+
+## Local settings
+Plans also have a `configurations` key in which the actual configurations are stored. `configurations` are again a
+dictionary, where the keys are the configuration names and the values are the local settings for each configuration.
+
+To better understand the components describing the network topology in our plans files, please read section 6.2
+in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf)
+(page 13) of our paper!
+
+Local settings:
+- `spacing`: the target spacing used in this configuration
+- `patch_size`: the patch size used for training this configuration
+- `data_identifier`: the preprocessed data for this configuration will be saved in
+ nnUNet_preprocessed/DATASET_NAME/_data_identifier_. If you add a new configuration, remember to set a unique
+ data_identifier in order to not create conflicts with other configurations (unless you plan to reuse the data from
+ another configuration, for example as is done in the cascade)
+- `batch_size`: batch size used for training
+- `batch_dice`: whether to use batch dice (pretend all samples in the batch are one image, compute dice loss over that)
+or not (each sample in the batch is a separate image, compute dice loss for each sample and average over samples)
+- `preprocessor_name`: Name of the preprocessor class used for running preprocessing. Class must be located in
+nnunetv2.preprocessing.preprocessors
+- `use_mask_for_norm`: whether to use the nonzero mask for normalization or not (relevant for BraTS and the like,
+probably False for all other datasets). Interacts with ImageNormalization class
+- `normalization_schemes`: mapping of channel identifier to ImageNormalization class name. ImageNormalization
+classes must be located in nnunetv2.preprocessing.normalization. Also see [here](explanation_normalization.md)
+- `resampling_fn_data`: name of resampling function to be used for resizing image data. resampling function must be
+callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling
+- `resampling_fn_data_kwargs`: kwargs for resampling_fn_data
+- `resampling_fn_probabilities`: name of resampling function to be used for resizing predicted class probabilities/logits.
+resampling function must be `callable(data: Union[np.ndarray, torch.Tensor], current_spacing, new_spacing, **kwargs)`. It must be located in
+nnunetv2.preprocessing.resampling
+- `resampling_fn_probabilities_kwargs`: kwargs for resampling_fn_probabilities
+- `resampling_fn_seg`: name of resampling function to be used for resizing segmentation maps (integer: 0, 1, 2, 3, etc).
+resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in
+nnunetv2.preprocessing.resampling
+- `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg
+- `UNet_class_name`: UNet class name, can be used to integrate custom dynamic architectures
+- `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Features
+are doubled with each downsampling
+- `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2D). The purpose is to
+prevent the number of parameters from exploding too much.
+- `conv_kernel_sizes`: the convolutional kernel sizes used by nnU-Net in each stage of the encoder. The decoder
+ mirrors the encoder and is therefore not explicitly listed here! The list is as long as `n_conv_per_stage_encoder` has
+ entries
+- `n_conv_per_stage_encoder`: number of convolutions used per stage (=at a feature map resolution in the encoder) in the encoder.
+ Default is 2. The list has as many entries as the encoder has stages
+- `n_conv_per_stage_decoder`: number of convolutions used per stage in the decoder. Also see `n_conv_per_stage_encoder`
+- `num_pool_per_axis`: number of times each of the spatial axes is pooled in the network. Needed to know how to pad
+ image sizes during inference (num_pool = 5 means input must be divisible by 2**5=32)
+- `pool_op_kernel_sizes`: the pooling kernel sizes (and at the same time strides) for each stage of the encoder
+- \[`median_image_size_in_voxels`\]: the median size of the images of the training set at the current target spacing.
+Do not modify this as this is not used. It is just here for your information.
+
+Special local settings:
+- `inherits_from`: configurations can inherit from each other. This makes it easy to add new configurations that only
+differ in a few local settings from another. If using this, remember to set a new `data_identifier` (if needed)!
+- `previous_stage`: if this configuration is part of a cascade, we need to know what the previous stage (for example
+the low resolution configuration) was. This needs to be specified here.
+- `next_stage`: if this configuration is part of a cascade, we need to know what possible subsequent stages are! This
+is because we need to export predictions in the correct spacing when running the validation. `next_stage` can either
+be a string or a list of strings
+
+# Examples
+
+## Increasing the batch size for large datasets
+If your dataset is large the training can benefit from larger batch_sizes. To do this, simply create a new
+configuration in the `configurations` dict
+
+ "configurations": {
+ "3d_fullres_bs40": {
+ "inherits_from": "3d_fullres",
+ "batch_size": 40
+ }
+ }
+
+No need to change the data_identifier. `3d_fullres_bs40` will just use the preprocessed data from `3d_fullres`.
+No need to rerun `nnUNetv2_preprocess` because we can use already existing data (if available) from `3d_fullres`.
+
+## Using custom preprocessors
+If you would like to use a different preprocessor class then this can be specified as follows:
+
+ "configurations": {
+ "3d_fullres_my_preprocesor": {
+ "inherits_from": "3d_fullres",
+ "preprocessor_name": MY_PREPROCESSOR,
+ "data_identifier": "3d_fullres_my_preprocesor"
+ }
+ }
+
+You need to run preprocessing for this new configuration:
+`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_preprocessor` because it changes the preprocessing. Remember to
+set a unique `data_identifier` whenever you make modifications to the preprocessed data!
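+
+As a rough sketch, such a preprocessor could look like the following (module path, base class and method name are taken
+from the current nnU-Net v2 code base; please verify them against your installed version):
+
+```python
+from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
+
+
+class MyPreprocessor(DefaultPreprocessor):
+    def run_case_npy(self, *args, **kwargs):
+        # run the default preprocessing first, then apply custom modifications
+        data, seg = super().run_case_npy(*args, **kwargs)
+        # ... e.g. additional clipping, masking or channel-wise tweaks on data/seg ...
+        return data, seg
+```
+
+The class name (here `MyPreprocessor`) is what goes into `preprocessor_name`, and the class must be located in
+nnunetv2.preprocessing.preprocessors so that nnU-Net can find it.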
+
+## Change target spacing
+
+ "configurations": {
+ "3d_fullres_my_spacing": {
+ "inherits_from": "3d_fullres",
+ "spacing": [X, Y, Z],
+ "data_identifier": "3d_fullres_my_spacing"
+ }
+ }
+
+You need to run preprocessing for this new configuration:
+`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_spacing` because it changes the preprocessing. Remember to
+set a unique `data_identifier` whenever you make modifications to the preprocessed data!
+
+## Adding a cascade to a dataset where it does not exist
+Hippocampus is small. It doesn't have a cascade. It also doesn't really make sense to add a cascade here but hey for
+the sake of demonstration we can do that.
+We change the following things here:
+
+- `spacing`: The lowres stage should operate at a lower resolution
+- we modify the `median_image_size_in_voxels` entry as a guide for what original image sizes we deal with
+- we set some patch size that is inspired by `median_image_size_in_voxels`
+- we need to remember that the patch size must be divisible by 2**num_pool in each axis!
+- network parameters such as kernel sizes, pooling operations are changed accordingly
+- we need to specify the name of the next stage
+- we need to add the highres stage
+
+This is how it would look (comparisons with 3d_fullres given as reference):
+
+ "configurations": {
+ "3d_lowres": {
+ "inherits_from": "3d_fullres",
+ "data_identifier": "3d_lowres"
+ "spacing": [2.0, 2.0, 2.0], # from [1.0, 1.0, 1.0] in 3d_fullres
+ "median_image_size_in_voxels": [18, 25, 18], # from [36, 50, 35]
+ "patch_size": [20, 28, 20], # from [40, 56, 40]
+ "n_conv_per_stage_encoder": [2, 2, 2], # one less entry than 3d_fullres ([2, 2, 2, 2])
+ "n_conv_per_stage_decoder": [2, 2], # one less entry than 3d_fullres
+ "num_pool_per_axis": [2, 2, 2], # one less pooling than 3d_fullres in each dimension (3d_fullres: [3, 3, 3])
+ "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], # one less [2, 2, 2]
+ "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], # one less [3, 3, 3]
+ "next_stage": "3d_cascade_fullres" # name of the next stage in the cascade
+ },
+ "3d_cascade_fullres": { # does not need a data_identifier because we can use the data of 3d_fullres
+ "inherits_from": "3d_fullres",
+ "previous_stage": "3d_lowres" # name of the previous stage
+ }
+ }
+
+To better understand the components describing the network topology in our plans files, please read section 6.2
+in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf)
+(page 13) of our paper!
\ No newline at end of file
diff --git a/nnUNet/documentation/extending_nnunet.md b/nnUNet/documentation/extending_nnunet.md
new file mode 100644
index 0000000..2924f4f
--- /dev/null
+++ b/nnUNet/documentation/extending_nnunet.md
@@ -0,0 +1,37 @@
+# Extending nnU-Net
+We hope that the new structure of nnU-Net v2 makes it much more intuitive to modify! We cannot give an
+extensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position
+in the repository where the thing you intend to change is implemented and start working your way through the code from
+there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the
+necessary modifications!
+
+Here are some things you might want to read before you start:
+- Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding
+preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)!
+- [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend!
+- Manual data splits can be defined as described [here](manual_data_splits.md)
+- You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)
+- Read about our support for [region-based training](region_based_training.md)
+- If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc), then you need
+to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and
+implements the necessary changes. Head over to our [trainer classes folder](../nnunetv2/training/nnUNetTrainer) for
+inspiration! There will often be an existing trainer similar to what you intend to change that you can use as a guide
+(a minimal sketch is also given after this list). nnUNetTrainers are structured similarly to PyTorch Lightning
+trainers, which should also make things easier!
+- Integrating new network architectures can be done in two ways:
+ - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function.
+ Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`
+ as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any
+ nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that!
+ - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class
+ used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether
+ certain patch sizes and
+ topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can configure your new
+ class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess` while specifying your
+ custom `ExperimentPlanner` and a custom `plans_name`. Implement a nnUNetTrainer that can use the plans generated by
+ your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and trainer when running `nnUNetv2_train`.
+ It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your implementation!
+- Remember that multi-GPU training, region-based training, ignore label and cascaded training are now all integrated
+into one unified nnUNetTrainer class. No separate classes are needed. Keep this in mind when implementing your own
+trainer classes: either support all of these features or raise `NotImplementedError`!
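+
+As a minimal sketch of the trainer route (the base class import path and the `num_epochs` attribute are taken from the
+current nnU-Net v2 code; double-check them in your version), a trainer that only shortens the training could look like this:
+
+```python
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainer_100epochs(nnUNetTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # only the training length is changed here; override the corresponding methods
+        # for changes to the loss, data augmentation, lr schedule, etc.
+        self.num_epochs = 100
+```
+
+Place the class in the trainer classes folder mentioned above and select it with `-tr nnUNetTrainer_100epochs` when
+running `nnUNetv2_train`.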
+
+[//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md))
diff --git a/nnUNet/documentation/how_to_use_nnunet.md b/nnUNet/documentation/how_to_use_nnunet.md
new file mode 100644
index 0000000..d962768
--- /dev/null
+++ b/nnUNet/documentation/how_to_use_nnunet.md
@@ -0,0 +1,297 @@
+## How to run nnU-Net on a new dataset
+Given some dataset, nnU-Net fully automatically configures an entire segmentation pipeline that matches its properties.
+nnU-Net covers the entire pipeline, from preprocessing to model configuration, model training, postprocessing
+all the way to ensembling. After running nnU-Net, the trained model(s) can be applied to the test cases for inference.
+
+### Dataset Format
+nnU-Net expects datasets in a structured format. This format is inspired by the data structure of
+the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). Please read
+[this](dataset_format.md) for information on how to set up datasets to be compatible with nnU-Net.
+
+**Since version 2 we support multiple image file formats (.nii.gz, .png, .tif, ...)! Read the dataset_format
+documentation to learn more!**
+
+**Datasets from nnU-Net v1 can be converted to V2 by running `nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER
+OUTPUT_DATASET_NAME`.** Remember that v2 calls datasets DatasetXXX_Name (not Task) where XXX is a 3-digit number.
+Please provide the **path** to the old task, not just the Task name. nnU-Net V2 doesn't know where v1 tasks were!
+
+### Experiment planning and preprocessing
+Given a new dataset, nnU-Net will extract a dataset fingerprint (a set of dataset-specific properties such as
+image sizes, voxel spacings, intensity information etc). This information is used to design three U-Net configurations.
+Each of these pipelines operates on its own preprocessed version of the dataset.
+
+The easiest way to run fingerprint extraction, experiment planning and preprocessing is to use:
+
+```bash
+nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity
+```
+
+Where `DATASET_ID` is the dataset id (duh). We recommend `--verify_dataset_integrity` whenever it's the first time
+you run this command. This will check for some of the most common error sources!
+
+You can also process several datasets at once by giving `-d 1 2 3 [...]`. If you already know what U-Net configuration
+you need you can also specify that with `-c 3d_fullres` (make sure to adapt -np in this case!). For more information
+about all the options available to you please run `nnUNetv2_plan_and_preprocess -h`.
+
+nnUNetv2_plan_and_preprocess will create a new subfolder in your nnUNet_preprocessed folder named after the dataset.
+Once the command is completed there will be a dataset_fingerprint.json file as well as a nnUNetPlans.json file for you to look at
+(in case you are interested!). There will also be subfolders containing the preprocessed data for your UNet configurations.
+
+[Optional]
+If you prefer to keep things separate, you can also use `nnUNetv2_extract_fingerprint`, `nnUNetv2_plan_experiment`
+and `nnUNetv2_preprocess` (in that order).
+
+### Model training
+#### Overview
+You pick which configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) should be trained! If you have no idea
+what performs best on your data, just run all of them and let nnU-Net identify the best one. It's up to you!
+
+nnU-Net trains all configurations in a 5-fold cross-validation over the training cases. This is 1) needed so that
+nnU-Net can estimate the performance of each configuration and tell you which one should be used for your
+segmentation problem and 2) a natural way of obtaining a good model ensemble (average the output of these 5 models
+for prediction) to boost performance.
+
+You can influence the splits nnU-Net uses for 5-fold cross-validation (see [here](manual_data_splits.md)). If you
+prefer to train a single model on all training cases, this is also possible (see below).
+
+**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net
+cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net
+already covers a large part of the input images.**
+
+Training models is done with the `nnUNetv2_train` command. The general structure of the command is:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
+```
+
+UNET_CONFIGURATION is a string that identifies the requested U-Net configuration (defaults: 2d, 3d_fullres, 3d_lowres,
+3d_cascade_fullres). DATASET_NAME_OR_ID specifies what dataset should be trained on and FOLD specifies which fold of
+the 5-fold-cross-validation is trained.
+
+nnU-Net stores a checkpoint every 50 epochs. If you need to continue a previous training, just add a `--c` to the
+training command.
+
+IMPORTANT: If you plan to use `nnUNetv2_find_best_configuration` (see below) add the `--npz` flag. This makes
+nnU-Net save the softmax outputs during the final validation. They are needed for that. Exported softmax
+predictions are very large and therefore can take up a lot of disk space, which is why this is not enabled by default.
+If you ran initially without the `--npz` flag but now require the softmax predictions, simply rerun the validation with:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz
+```
+
+You can specify the device nnU-Net should use with `-device DEVICE`. DEVICE can only be cpu, cuda or mps. If
+you have multiple GPUs, please select the gpu id using `CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...]` (requires device to be cuda).
+
+See `nnUNetv2_train -h` for additional options.
+
+### 2D U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD [--npz]
+```
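+
+If you prefer, the five folds of a configuration can also be launched sequentially with a small shell loop (a sketch;
+the same pattern applies to the other configurations below):
+
+```bash
+for FOLD in 0 1 2 3 4; do
+    nnUNetv2_train DATASET_NAME_OR_ID 2d $FOLD --npz
+done
+```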
+
+### 3D full resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [--npz]
+```
+
+### 3D U-Net cascade
+#### 3D low resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_lowres FOLD [--npz]
+```
+
+#### 3D full resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_cascade_fullres FOLD [--npz]
+```
+**Note that the 3D full resolution U-Net of the cascade requires the five folds of the low resolution U-Net to be
+completed!**
+
+The trained models will be written to the nnUNet_results folder. Each training obtains an automatically generated
+output folder name:
+
+nnUNet_results/DatasetXXX_MYNAME/TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION/FOLD
+
+For Dataset002_Heart (from the MSD), for example, this looks like this:
+
+    nnUNet_results/
+    └── Dataset002_Heart
+        ├── nnUNetTrainer__nnUNetPlans__2d
+        │   ├── fold_0
+        │   ├── fold_1
+        │   ├── fold_2
+        │   ├── fold_3
+        │   ├── fold_4
+        │   ├── dataset.json
+        │   ├── dataset_fingerprint.json
+        │   └── plans.json
+        └── nnUNetTrainer__nnUNetPlans__3d_fullres
+            ├── fold_0
+            ├── fold_1
+            ├── fold_2
+            ├── fold_3
+            ├── fold_4
+            ├── dataset.json
+            ├── dataset_fingerprint.json
+            └── plans.json
+
+Note that 3d_lowres and 3d_cascade_fullres do not exist here because this dataset did not trigger the cascade. In each
+model training output folder (each of the fold_x folders), the following files will be created:
+- debug.json: Contains a summary of blueprint and inferred parameters used for training this model as well as a
+bunch of additional stuff. Not easy to read, but very useful for debugging ;-)
+- checkpoint_best.pth: checkpoint files of the best model identified during training. Not used right now unless you
+explicitly tell nnU-Net to use it.
+- checkpoint_final.pth: checkpoint file of the final model (after training has ended). This is what is used for both
+validation and inference.
+- network_architecture.pdf (only if hiddenlayer is installed!): a pdf document with a figure of the network architecture in it.
+- progress.png: Shows losses, pseudo dice, learning rate and epoch times over the course of the training. At the top is
+a plot of the training (blue) and validation (red) loss during training. Also shows an approximation of
+ the dice (green) as well as a moving average of it (dotted green line). This approximation is the average Dice score
+ of the foreground classes. **It needs to be taken with a big (!)
+ grain of salt** because it is computed on randomly drawn patches from the validation
+ data at the end of each epoch, and the aggregation of TP, FP and FN for the Dice computation treats the patches as if
+ they all originate from the same volume ('global Dice'; we do not compute a Dice for each validation case and then
+ average over all cases but pretend that there is only one validation case from which we sample patches). The reason for
+ this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model
+ is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.
+- validation_raw: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
+ contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then
+the compressed softmax outputs (saved as .npz files) are in here as well.
+
+During training it is often useful to watch the progress. We therefore recommend that you have a look at the generated
+progress.png when running the first training. It will be updated after each epoch.
+
+Training times largely depend on the GPU. The smallest GPU we recommend for training is the Nvidia RTX 2080ti. With
+that all network trainings take less than 2 days. Refer to our [benchmarks](benchmarking.md) to see if your system is
+performing as expected.
+
+### Using multiple GPUs for training
+
+If multiple GPUs are at your disposal, the best way of using them is to run multiple nnU-Net trainings at once, one
+on each GPU. This is because data parallelism never scales perfectly linearly, especially not with small networks such
+as the ones used by nnU-Net.
+
+Example:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] & # train on GPU 0
+CUDA_VISIBLE_DEVICES=1 nnUNetv2_train DATASET_NAME_OR_ID 2d 1 [--npz] & # train on GPU 1
+CUDA_VISIBLE_DEVICES=2 nnUNetv2_train DATASET_NAME_OR_ID 2d 2 [--npz] & # train on GPU 2
+CUDA_VISIBLE_DEVICES=3 nnUNetv2_train DATASET_NAME_OR_ID 2d 3 [--npz] & # train on GPU 3
+CUDA_VISIBLE_DEVICES=4 nnUNetv2_train DATASET_NAME_OR_ID 2d 4 [--npz] & # train on GPU 4
+...
+wait
+```
+
+**Important: The first time a training is run nnU-Net will extract the preprocessed data into uncompressed numpy
+arrays for speed reasons! This operation must be completed before starting more than one training of the same
+configuration! Wait to start subsequent folds until the first training is using the GPU! Depending on the
+dataset size and your system this should only take a couple of minutes at most.**
+
+If you insist on running DDP multi-GPU training, we got you covered:
+
+`nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] -num_gpus X`
+
+Again, note that this will be slower than running separate trainings on separate GPUs. DDP only makes sense if you have
+manually interfered with the nnU-Net configuration and are training larger models with larger patch and/or batch sizes!
+
+Important when using `-num_gpus`:
+1) If you train using, say, 2 GPUs but have more GPUs in the system you need to specify which GPUs should be used via
+CUDA_VISIBLE_DEVICES=0,1 (or whatever your ids are).
+2) You cannot specify more GPUs than you have samples in your minibatches. If the batch size is 2, 2 GPUs is the maximum!
+3) Make sure your batch size is divisible by the number of GPUs you use or you will not make good use of your hardware.
+
+In contrast to the old nnU-Net, DDP is now completely hassle free. Enjoy!
+
+### Automatically determine the best configuration
+Once the desired configurations were trained (full cross-validation) you can tell nnU-Net to automatically identify
+the best combination for you:
+
+```commandline
+nnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS
+```
+
+`CONFIGURATIONS` here is the list of configurations you would like to explore. Per default, ensembling is enabled,
+meaning that nnU-Net will generate all possible combinations of ensembles (2 configurations per ensemble). This requires
+the .npz files containing the predicted probabilities of the validation set to be present (use `nnUNetv2_train` with
+`--npz` flag, see above). You can disable ensembling by setting the `--disable_ensembling` flag.
+
+See `nnUNetv2_find_best_configuration -h` for more options.
+
+nnUNetv2_find_best_configuration will also automatically determine the postprocessing that should be used.
+Postprocessing in nnU-Net only considers the removal of all but the largest component in the prediction (once for
+foreground vs background and once for each label/region).
+
+Once completed, the command will print to your console exactly what commands you need to run to make predictions. It
+will also create two files in the `nnUNet_results/DATASET_NAME` folder for you to inspect:
+- `inference_instructions.txt` again contains the exact commands you need to use for predictions
+- `inference_information.json` can be inspected to see the performance of all configurations and ensembles, as well
+as the effect of the postprocessing plus some debug information.
+
+### Run inference
+Remember that the data located in the input folder must have the same file endings as the dataset you trained the model on
+and must adhere to the nnU-Net naming scheme for image files (see [dataset format](dataset_format.md) and
+[inference data format](dataset_format_inference.md)!)
+
+`nnUNetv2_find_best_configuration` (see above) will print a string to the terminal with the inference commands you need to use.
+The easiest way to run inference is to simply use these commands.
+
+If you wish to manually specify the configuration(s) used for inference, use the following commands:
+
+#### Run prediction
+For each of the desired configurations, run:
+```
+nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities
+```
+
+Only specify `--save_probabilities` if you intend to use ensembling. `--save_probabilities` will make the command save the predicted
+probabilities alongside the predicted segmentation masks, which requires a lot of disk space.
+
+Please select a separate `OUTPUT_FOLDER` for each configuration!
+
+Note that per default, inference will be done with all 5 folds from the cross-validation as an ensemble. We very
+strongly recommend you use all 5 folds. Thus, all 5 folds must have been trained prior to running inference.
+
+If you wish to make predictions with a single model, train the `all` fold and specify it in `nnUNetv2_predict`
+with `-f all`.
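+
+A minimal sketch of this single-model workflow (placeholder names as above):
+
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres all
+nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c 3d_fullres -f all
+```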
+
+#### Ensembling multiple configurations
+If you wish to ensemble multiple predictions (typically from different configurations), you can do so with the following command:
+```bash
+nnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES
+```
+
+You can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were
+generated by `nnUNetv2_predict`. Again, `nnUNetv2_ensemble -h` will tell you more about additional options.
+
+#### Apply postprocessing
+Finally, apply the previously determined postprocessing to the (ensembled) predictions:
+
+```commandline
+nnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE
+```
+
+`nnUNetv2_find_best_configuration` (or its generated `inference_instructions.txt` file) will tell you where to find
+the postprocessing file. If not, you can just look for it in your results folder (it's creatively named
+`postprocessing.pkl`). If your source folder is from an ensemble, you also need to specify a `-plans_json` file and
+a `-dataset_json` file that should be used (for single configuration predictions these are automatically copied
+from the respective training). You can pick these files from any of the ensemble members.
+
+
+## How to run inference with pretrained models
+See [here](run_inference_with_pretrained_models.md)
+
+[//]: # (## Examples)
+
+[//]: # ()
+[//]: # (To get you started we compiled two simple to follow examples:)
+
+[//]: # (- run a training with the 3d full resolution U-Net on the Hippocampus dataset. See [here](documentation/training_example_Hippocampus.md).)
+
+[//]: # (- run inference with nnU-Net's pretrained models on the Prostate dataset. See [here](documentation/inference_example_Prostate.md).)
+
+[//]: # ()
+[//]: # (Usability not good enough? Let us know!)
diff --git a/nnUNet/documentation/installation_instructions.md b/nnUNet/documentation/installation_instructions.md
new file mode 100644
index 0000000..409e5eb
--- /dev/null
+++ b/nnUNet/documentation/installation_instructions.md
@@ -0,0 +1,87 @@
+# System requirements
+
+## Operating System
+nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; CentOS, RHEL), Windows and MacOS! It should work out of the box!
+
+## Hardware requirements
+We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D
+convolutions, so you might have to use the CPU on those devices).
+
+### Hardware requirements for Training
+We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2).
+For training, a GPU with at least 10 GB of VRAM is required (popular non-datacenter options are the RTX 2080ti,
+RTX 3080/3090 or RTX 4080/4090). We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads)
+are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of
+input channels and target structures. Plus, the faster the GPU, the better the CPU should be!
+
+### Hardware Requirements for inference
+Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However,
+inference times are typically still manageable on CPU and MPS (Apple M1/M2). If using a GPU, it should have at least
+4 GB of available (unused) VRAM.
+
+### Example hardware configurations
+Example workstation configurations for training:
+- CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well.
+- GPU: RTX 3090 or RTX 4090
+- RAM: 64GB
+- Storage: SSD (M.2 PCIe Gen 3 or better!)
+
+Example Server configuration for training:
+- CPU: 2x AMD EPYC7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100!
+- GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power)
+- RAM: 1 TB
+- Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage
+
+(nnU-Net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously)
+
+### Setting the correct number of Workers for data augmentation (training only)
+Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your
+CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by
+setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`).
+Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for
+RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes.
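+
+For example, on the server configuration above you could add the following to your environment (pick the value that
+matches your own CPU/GPU ratio):
+
+```bash
+export nnUNet_n_proc_DA=28
+```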
+
+# Installation instructions
+We strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to
+compile PyTorch from source (see below), you will need to use conda instead of pip.
+
+Use a recent version of Python! 3.9 or newer is guaranteed to work!
+
+**nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.**
+
+1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please
+install the latest version with support for your hardware (cuda, mps, cpu).
+**DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider
+[compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!).
+2) Install nnU-Net depending on your use case:
+ 1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running
+ **inference with pretrained models**:
+
+ ```pip install nnunetv2```
+
+ 2) For use as integrative **framework** (this will create a copy of the nnU-Net code on your computer so that you
+ can modify it as needed):
+ ```bash
+ git clone https://github.com/MIC-DKFZ/nnUNet.git
+ cd nnUNet
+ pip install -e .
+ ```
+3) nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to
+ set a few environment variables. Please follow the instructions [here](setting_up_paths.md).
+4) (OPTIONAL) Install [hiddenlayer](https://github.com/waleedka/hiddenlayer). hiddenlayer enables nnU-net to generate
+ plots of the network topologies it generates (see [Model training](how_to_use_nnunet.md#model-training)).
+To install hiddenlayer,
+ run the following command:
+ ```bash
+ pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git
+ ```
+
+Installing nnU-Net will add several new commands to your terminal. These commands are used to run the entire nnU-Net
+pipeline. You can execute them from any location on your system. All nnU-Net commands have the prefix `nnUNetv2_` for
+easy identification.
+
+Note that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this
+environment must be activated when executing the commands. You can see what scripts/functions are executed by
+checking the entry_points in the setup.py file.
+
+All nnU-Net commands have a `-h` option which gives information on how to use them.
diff --git a/nnUNet/documentation/manual_data_splits.md b/nnUNet/documentation/manual_data_splits.md
new file mode 100644
index 0000000..f299c02
--- /dev/null
+++ b/nnUNet/documentation/manual_data_splits.md
@@ -0,0 +1,46 @@
+# How to generate custom splits in nnU-Net
+
+Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold
+cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification.
+Fear not, for nnU-Net has got you covered (it really can do anything <3).
+
+The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for
+existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split,
+manually creating a split file that will then be recognized and used is the way to go!
+
+The split file is located in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first
+populate this folder by running `nnUNetv2_plan_and_preprocess`.
+
+Splits are stored as a .json file. They are a simple python list. The length of that list is the number of splits it
+contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are
+again simply lists with the train identifiers in each set. To illustrate this, I am just using the splits file of
+Dataset002 as an example:
+
+```commandline
+In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json
+
+In [2]: splits = load_json('splits_final.json')
+
+In [3]: len(splits)
+Out[3]: 5
+
+In [4]: splits[0].keys()
+Out[4]: dict_keys(['train', 'val'])
+
+In [5]: len(splits[0]['train'])
+Out[5]: 16
+
+In [6]: len(splits[0]['val'])
+Out[6]: 4
+
+In [7]: print(splits[0])
+{'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'],
+'val': ['la_007', 'la_016', 'la_021', 'la_024']}
+```
+
+If you are still not sure what splits are supposed to look like, simply download some reference dataset from the
+[Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect
+the .json file with your text editor of choice!
+
+In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as
+`splits_final.json` in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual.
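+
+A minimal sketch of how such a file could be generated (the case identifiers and the simple 3-fold assignment are just
+placeholders; use whatever stratification your project requires):
+
+```python
+from batchgenerators.utilities.file_and_folder_operations import join, save_json
+
+all_cases = ['case_000', 'case_001', 'case_002', 'case_003', 'case_004', 'case_005']
+
+splits = []
+for fold in range(3):
+    val = all_cases[fold::3]                        # every third case goes into validation
+    train = [c for c in all_cases if c not in val]  # the remaining cases are used for training
+    splits.append({'train': train, 'val': val})
+
+save_json(splits, join('nnUNet_preprocessed', 'DATASETXXX_NAME', 'splits_final.json'))
+```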
\ No newline at end of file
diff --git a/nnUNet/documentation/pretraining_and_finetuning.md b/nnUNet/documentation/pretraining_and_finetuning.md
new file mode 100644
index 0000000..44b46dc
--- /dev/null
+++ b/nnUNet/documentation/pretraining_and_finetuning.md
@@ -0,0 +1,82 @@
+# Pretraining with nnU-Net
+
+## Intro
+
+So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some source dataset
+and then use the final network weights as initialization for your target dataset.
+
+As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
+result of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not
+possible to simply take the network weights from some dataset and then reuse them for another.
+
+Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and
+how the resulting weights can then be used for initialization.
+
+### Terminology
+
+Throughout this README we use the following terminology:
+
+- `source dataset` is the dataset you intend to run the pretraining on
+- `target dataset` is the dataset you are interested in; the one you wish to fine tune on
+
+
+## Pretraining on the source dataset
+
+In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
+only interested in the target dataset, we first need to run experiment planning (and preprocessing) for it:
+
+```bash
+nnUNetv2_plan_and_preprocess -d TARGET_DATASET
+```
+
+Then we need to extract the dataset fingerprint of the source dataset, if not yet available:
+
+```bash
+nnUNetv2_extract_fingerprint -d SOURCE_DATASET
+```
+
+Now we can take the plans from the target dataset and transfer them to the source:
+
+```bash
+nnUNetv2_move_plans_between_datasets -s SOURCE_DATASET -t TARGET_DATASET -sp SOURCE_PLANS_IDENTIFIER -tp TARGET_PLANS_IDENTIFIER
+```
+
+`SOURCE_PLANS_IDENTIFIER` is most likely nnUNetPlans unless you changed the experiment planner in
+nnUNetv2_plan_and_preprocess. For `TARGET_PLANS_IDENTIFIER` we recommend you set something custom in order to not
+overwrite default plans.
+
+Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but
+also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not
+work well (but it could, depending on the schemes!).
+
+Note on CT normalization: Yes, also the clip values, mean and std are transferred!
+
+Now you can run the preprocessing on the source task:
+
+```bash
+nnUNetv2_preprocess -d SOURCE_DATASET -plans_name TARGET_PLANS_IDENTIFIER
+```
+
+And run the training as usual:
+
+```bash
+nnUNetv2_train SOURCE_DATASET CONFIG all -p TARGET_PLANS_IDENTIFIER
+```
+
+Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
+
+## Using pretrained weights
+
+Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:
+
+```bash
+nnUNetv2_train TARGET_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT
+```
+
+Specify the checkpoint in PATH_TO_CHECKPOINT.
+
+When loading pretrained weights, all layers except the segmentation layers will be used!
+
+So far there are no specific nnU-Net trainers for fine-tuning, so the current recommendation is to just use
+nnUNetTrainer. You can, however, easily write your own trainers with learning rate ramp up, fine-tuning of
+segmentation heads or shorter training time.
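+
+A rough sketch of what such a fine-tuning trainer could look like (attribute names follow the current nnUNetTrainer;
+verify them for your version):
+
+```python
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainer_finetune(nnUNetTrainer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # a smaller initial learning rate and a shorter schedule are common choices
+        # when starting from pretrained weights; adjust to your needs
+        self.initial_lr = 1e-3
+        self.num_epochs = 250
+```
+
+You would then combine it with the pretrained weights via
+`nnUNetv2_train TARGET_DATASET CONFIG FOLD -tr nnUNetTrainer_finetune -pretrained_weights PATH_TO_CHECKPOINT`.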
\ No newline at end of file
diff --git a/nnUNet/documentation/region_based_training.md b/nnUNet/documentation/region_based_training.md
new file mode 100644
index 0000000..265c890
--- /dev/null
+++ b/nnUNet/documentation/region_based_training.md
@@ -0,0 +1,75 @@
+# Region-based training
+
+## What is this about?
+In some segmentation tasks, most prominently the
+[Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric
+will be computed) are different from the labels provided in the training data. This is the case because for some
+clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the
+individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor).
+
+
+
+The figure shows an example BraTS case along with label-based representation of the task (top) and region-based
+representation (bottom). The challenge evaluation is done on the regions. As we have shown in our
+[BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those
+overlapping areas over the individual labels yields better scoring models!
+
+## What can nnU-Net do?
+nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For
+some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training.
+Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs +
+substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be
+segmented. The first target region could thus be the entire liver (including the substructures), while the remaining
+targets are the individual substructures.
+
+Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output!
+Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps
+for nnU-Net to work (see below how this is done).
+
+## How do you use it?
+
+When declaring the labels in the `dataset.json` file, BraTS would typically look like this:
+
+```python
+...
+"labels": {
+ "background": 0,
+ "edema": 1,
+ "non_enhancing_and_necrosis": 2,
+ "enhancing_tumor": 3
+},
+...
+```
+(we use different int values than the challenge because nnU-Net needs consecutive integers!)
+
+This representation corresponds to the upper row in the figure above.
+
+For region-based training, the labels need to be changed to the following:
+
+```python
+...
+"labels": {
+ "background": 0,
+ "whole_tumor": [1, 2, 3],
+ "tumor_core": [2, 3],
+ "enhancing_tumor": 3 # or [3]
+},
+"regions_class_order": [1, 2, 3],
+...
+```
+This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is
+required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map.
+It essentially just tells nnU-Net what labels to place for which region in what order. Concretely, for the example
+given here, nnU-Net will place the label 1 for the 'whole_tumor' region, then place the label 2 where the "tumor_core"
+is and finally place the label 3 in the 'enhancing_tumor' area. With each step, part of the previously set pixels
+will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions
+(like whole tumor etc) first, followed by substructures.
+
+**IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are
+declared ("place label X in the first region") you need to make sure that this order is not perturbed! When
+automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! Set
+`sort_keys=False` in `json.dump()`!!!
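+
+For example, when writing the dataset.json yourself, something along these lines preserves the declaration order
+(a sketch using the BraTS-style labels from above):
+
+```python
+import json
+
+dataset_json = {
+    # ... other required dataset.json entries ...
+    "labels": {
+        "background": 0,
+        "whole_tumor": [1, 2, 3],
+        "tumor_core": [2, 3],
+        "enhancing_tumor": [3],
+    },
+    "regions_class_order": [1, 2, 3],
+}
+
+with open("dataset.json", "w") as f:
+    json.dump(dataset_json, f, sort_keys=False, indent=4)
+```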
+
+nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels!
+
+That's all. Easy, huh?
\ No newline at end of file
diff --git a/nnUNet/documentation/run_inference_with_pretrained_models.md b/nnUNet/documentation/run_inference_with_pretrained_models.md
new file mode 100644
index 0000000..d2698a2
--- /dev/null
+++ b/nnUNet/documentation/run_inference_with_pretrained_models.md
@@ -0,0 +1,7 @@
+# How to run inference with pretrained models
+**Important:** Pretrained weights from nnU-Net v1 are NOT compatible with V2. You will need to retrain with the new
+version. But honestly, you already have a fully trained model with which you can run inference (in v1), so
+just continue using that!
+
+Not yet available for V2 :-(
+If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam!
diff --git a/nnUNet/documentation/set_environment_variables.md b/nnUNet/documentation/set_environment_variables.md
new file mode 100644
index 0000000..2dc22a5
--- /dev/null
+++ b/nnUNet/documentation/set_environment_variables.md
@@ -0,0 +1,78 @@
+# How to set environment variables
+
+nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained
+models are. Depending on the operating system, these environment variables need to be set in different ways.
+
+Variables can either be set permanently (recommended!) or you can decide to set them every time you call nnU-Net.
+
+# Linux & MacOS
+
+## Permanent
+Locate the `.bashrc` file in your home folder and add the following lines to the bottom:
+
+```bash
+export nnUNet_raw="/media/fabian/nnUNet_raw"
+export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
+export nnUNet_results="/media/fabian/nnUNet_results"
+```
+
+(of course you need to adapt the paths to the actual folders you intend to use).
+If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`.
+
+## Temporary
+Just execute the following lines whenever you run nnU-Net:
+```bash
+export nnUNet_raw="/media/fabian/nnUNet_raw"
+export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
+export nnUNet_results="/media/fabian/nnUNet_results"
+```
+(of course you need to adapt the paths to the actual folders you intend to use).
+
+Important: These variables will be deleted if you close your terminal! They will also only apply to the current
+terminal window and DO NOT transfer to other terminals!
+
+Alternatively you can also just prefix them to your nnU-Net commands:
+
+`nnUNet_results="/media/fabian/nnUNet_results" nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" nnUNetv2_train[...]`
+
+## Verify that environment parameters are set
+You can always execute `echo ${nnUNet_raw}` etc to print the environment variables. This will return an empty string if
+they were not set.
+
+# Windows
+Useful links:
+- [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html#:~:text=To%20set%20(or%20change)%20a,it%20to%20an%20empty%20string.)
+- [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable)
+
+## Permanent
+See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable).
+Or read about setx (command prompt).
+
+## Temporary
+Just execute the following before you run nnU-Net:
+
+(powershell)
+```powershell
+$Env:nnUNet_raw = "/media/fabian/nnUNet_raw"
+$Env:nnUNet_preprocessed = "/media/fabian/nnUNet_preprocessed"
+$Env:nnUNet_results = "/media/fabian/nnUNet_results"
+```
+
+(command prompt)
+```commandline
+set nnUNet_raw="/media/fabian/nnUNet_raw"
+set nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed"
+set nnUNet_results="/media/fabian/nnUNet_results"
+```
+
+(of course you need to adapt the paths to the actual folders you intend to use).
+
+Important: These variables will be deleted if you close your session! They will also only apply to the current
+window and DO NOT transfer to other sessions!
+
+## Verify that environment parameters are set
+Printing in Windows works differently depending on the environment you are in:
+
+powershell: `echo $Env:[variable_name]`
+
+command prompt: `echo %[variable_name]%`
\ No newline at end of file
diff --git a/nnUNet/documentation/setting_up_paths.md b/nnUNet/documentation/setting_up_paths.md
new file mode 100644
index 0000000..87f9f8c
--- /dev/null
+++ b/nnUNet/documentation/setting_up_paths.md
@@ -0,0 +1,38 @@
+# Setting up Paths
+
+nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored.
+To use the full functionality of nnU-Net, the following three environment variables must be set:
+
+1) `nnUNet_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset, named
+DatasetXXX_YYY, where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique)
+dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md).
+
+ Example tree structure:
+ ```
+ nnUNet_raw/Dataset001_NAME1
+ ├── dataset.json
+ ├── imagesTr
+ │ ├── ...
+ ├── imagesTs
+ │ ├── ...
+ └── labelsTr
+ ├── ...
+ nnUNet_raw/Dataset002_NAME2
+ ├── dataset.json
+ ├── imagesTr
+ │ ├── ...
+ ├── imagesTs
+ │ ├── ...
+ └── labelsTr
+ ├── ...
+ ```
+
+2) `nnUNet_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from
+this folder during training. It is important that this folder is located on a drive with low access latency and high
+throughput (such as an NVMe SSD; PCIe Gen 3 is sufficient).
+
+3) `nnUNet_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this
+is where it will save them.
+
+### How to set environment variables
+See [here](set_environment_variables.md).
\ No newline at end of file
diff --git a/nnUNet/documentation/tldr_migration_guide_from_v1.md b/nnUNet/documentation/tldr_migration_guide_from_v1.md
new file mode 100644
index 0000000..f9ec951
--- /dev/null
+++ b/nnUNet/documentation/tldr_migration_guide_from_v1.md
@@ -0,0 +1,20 @@
+# TLDR Migration Guide from nnU-Net V1
+
+- nnU-Net V2 can be installed simultaneously with V1. They won't get in each other's way
+- The environment variables needed for V2 have slightly different names. Read [this](setting_up_paths.md).
+- nnU-Net V2 datasets are called DatasetXXX_NAME. Not Task.
+- Datasets have the same structure (imagesTr, labelsTr, dataset.json) but we now support more
+[file types](dataset_format.md#supported-file-formats). The dataset.json is simplified. Use `generate_dataset_json`
+from nnunetv2.dataset_conversion.generate_dataset_json.py.
+- Careful: labels are no longer declared as value:name but as name:value. This has to do with [hierarchical labels](region_based_training.md).
+- nnU-Net v2 commands start with `nnUNetv2...`. They work mostly (but not entirely) the same. Just use the `-h` option.
+- You can transfer your V1 raw datasets to V2 with `nnUNetv2_convert_old_nnUNet_dataset`. You cannot transfer trained
+models. Continue to use the old nnU-Net Version for making inference with those.
+- These are the commands you are most likely to be using (in that order)
+ - `nnUNetv2_plan_and_preprocess`. Example: `nnUNetv2_plan_and_preprocess -d 2`
+ - `nnUNetv2_train`. Example: `nnUNetv2_train 2 3d_fullres 0`
+ - `nnUNetv2_find_best_configuration`. Example: `nnUNetv2_find_best_configuration 2 -c 2d 3d_fullres`. This command
+    will now create an `inference_instructions.txt` file in your `nnUNet_results/DatasetXXX_NAME/` folder which
+ tells you exactly how to do inference.
+ - `nnUNetv2_predict`. Example: `nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -c 3d_fullres -d 2`
+ - `nnUNetv2_apply_postprocessing` (see inference_instructions.txt)
diff --git a/nnUNet/nnunetv2/__init__.py b/nnUNet/nnunetv2/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/batch_running/__init__.py b/nnUNet/nnunetv2/batch_running/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/__init__.py b/nnUNet/nnunetv2/batch_running/benchmarking/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py b/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py
new file mode 100644
index 0000000..ca37206
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py
@@ -0,0 +1,41 @@
+if __name__ == '__main__':
+ """
+ This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler!
+ """
+ gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB',
+ 'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB']
+ datasets = [2, 3, 4, 5]
+ trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
+ plans = ['nnUNetPlans']
+ configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
+ num_gpus = 1
+
+ benchmark_configurations = {d: configs for d in datasets}
+
+ exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']'\""
+ resources = "-R \"tensorcore\""
+ queue = "-q gpu"
+ preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && "
+ train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train'
+
+ folds = (0, )
+
+ use_these_modules = {
+ tr: plans for tr in trainers
+ }
+
+ additional_arguments = f' -num_gpus {num_gpus}' # ''
+
+ output_file = "/home/isensee/deleteme.txt"
+ with open(output_file, 'w') as f:
+ for g in gpu_models:
+ gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}"
+ for tr in use_these_modules.keys():
+ for p in use_these_modules[tr]:
+ for dataset in benchmark_configurations.keys():
+ for config in benchmark_configurations[dataset]:
+ for fl in folds:
+ command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
+ if additional_arguments is not None and len(additional_arguments) > 0:
+ command += f' {additional_arguments}'
+ f.write(f'{command}\"\n')
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py b/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py
new file mode 100644
index 0000000..d966321
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py
@@ -0,0 +1,70 @@
+from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.paths import nnUNet_results
+from nnunetv2.utilities.file_path_utilities import get_output_folder
+
+if __name__ == '__main__':
+ trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading']
+ datasets = [2, 3, 4, 5]
+ plans = ['nnUNetPlans']
+ configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x']
+ output_file = join(nnUNet_results, 'benchmark_results.csv')
+
+ torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" #
+ cudnn_version = 8700 # 8302 #
+ num_gpus = 1
+
+ unique_gpus = set()
+
+ # collect results in the most janky way possible. Amazing coding skills!
+ all_results = {}
+ for tr in trainers:
+ all_results[tr] = {}
+ for p in plans:
+ all_results[tr][p] = {}
+ for c in configs:
+ all_results[tr][p][c] = {}
+ for d in datasets:
+ dataset_name = maybe_convert_to_dataset_name(d)
+ output_folder = get_output_folder(dataset_name, tr, p, c, fold=0)
+ expected_benchmark_file = join(output_folder, 'benchmark_result.json')
+ all_results[tr][p][c][d] = {}
+ if isfile(expected_benchmark_file):
+ # filter results for what we want
+ results = [i for i in load_json(expected_benchmark_file).values()
+ if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and
+ i['torch_version'] == torch_version]
+ for r in results:
+ all_results[tr][p][c][d][r['gpu_name']] = r
+ unique_gpus.add(r['gpu_name'])
+
+ # haha. Fuck this. Collect GPUs in the code above.
+ # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]])
+
+ unique_gpus = list(unique_gpus)
+ unique_gpus.sort()
+
+ with open(output_file, 'w') as f:
+ f.write('Dataset,Trainer,Plans,Config')
+ for g in unique_gpus:
+ f.write(f",{g}")
+ f.write("\n")
+ for d in datasets:
+ for tr in trainers:
+ for p in plans:
+ for c in configs:
+ gpu_results = []
+ for g in unique_gpus:
+ if g in all_results[tr][p][c][d].keys():
+ gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2))
+ else:
+ gpu_results.append("MISSING")
+ # skip if all are missing
+ if all([i == 'MISSING' for i in gpu_results]):
+ continue
+ f.write(f"{d},{tr},{p},{c}")
+ for g in gpu_results:
+ f.write(f",{g}")
+ f.write("\n")
+ f.write("\n")
+
diff --git a/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py
new file mode 100644
index 0000000..e5079bd
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py
@@ -0,0 +1,114 @@
+from typing import Tuple
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.evaluation.evaluate_predictions import load_summary_json
+from nnunetv2.paths import nnUNet_results
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id
+from nnunetv2.utilities.file_path_utilities import get_output_folder
+
+
+def collect_results(trainers: dict, datasets: List, output_file: str,
+ configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"),
+ folds=tuple(np.arange(5))):
+ results_dirs = (nnUNet_results,)
+ datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]
+ with open(output_file, 'w') as f:
+ for i, d in zip(datasets, datasets_names):
+ for c in configurations:
+ for module in trainers.keys():
+ for plans in trainers[module]:
+ for r in results_dirs:
+ expected_output_folder = get_output_folder(d, module, plans, c)
+ if isdir(expected_output_folder):
+ results_folds = []
+ f.write("%s,%s,%s,%s,%s" % (d, c, module, plans, r))
+ for fl in folds:
+ expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)
+ expected_summary_file = join(expected_output_folder_fold, "validation",
+ "summary.json")
+ if not isfile(expected_summary_file):
+ print('expected output file not found:', expected_summary_file)
+ f.write(",")
+ results_folds.append(np.nan)
+ else:
+ foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][
+ 'Dice']
+ results_folds.append(foreground_mean)
+ f.write(",%02.4f" % foreground_mean)
+ f.write(",%02.4f\n" % np.nanmean(results_folds))
+
+
+def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):
+ txt = np.loadtxt(input_file, dtype=str, delimiter=',')
+ num_folds = txt.shape[1] - 6
+ valid_configs = {}
+ for d in datasets:
+ if isinstance(d, int):
+ d = maybe_convert_to_dataset_name(d)
+ configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])
+ valid_configs[d] = [i for i in configs_in_txt if i in configs]
+ assert max(folds) < num_folds
+
+ with open(output_file, 'w') as f:
+ f.write("name")
+ for d in valid_configs.keys():
+ for c in valid_configs[d]:
+ f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4]))
+ f.write(',mean\n')
+ valid_entries = txt[:, 4] == nnUNet_results
+ for t in trainers.keys():
+ trainer_locs = valid_entries & (txt[:, 2] == t)
+ for pl in trainers[t]:
+ f.write("%s__%s" % (t, pl))
+ trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)
+ r = []
+ for d in valid_configs.keys():
+ trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)
+ for v in valid_configs[d]:
+ trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)
+ if np.any(trainer_plan_d_config_locs):
+ # we cannot have more than one row
+ assert np.sum(trainer_plan_d_config_locs) == 1
+
+ # now check that we have all folds
+ selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]
+
+ fold_results = selected_row[[i + 5 for i in folds]]
+
+ if '' in fold_results:
+ print('missing fold in', t, pl, d, v)
+ f.write(",nan")
+ r.append(np.nan)
+ else:
+ mean_dice = np.mean([float(i) for i in fold_results])
+ f.write(",%02.4f" % mean_dice)
+ r.append(mean_dice)
+ else:
+ print('missing:', t, pl, d, v)
+ f.write(",nan")
+ r.append(np.nan)
+ f.write(",%02.4f\n" % np.mean(r))
+
+
+if __name__ == '__main__':
+ use_these_trainers = {
+ 'nnUNetTrainer': ('nnUNetPlans',),
+ 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',),
+ 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',),
+ }
+ all_results_file = join(nnUNet_results, 'customDecResults.csv')
+ datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82]
+ collect_results(use_these_trainers, datasets, all_results_file)
+
+ folds = (0, 1, 2, 3, 4)
+ configs = ("3d_fullres", "3d_lowres")
+ output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')
+ summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
+
+ folds = (0, )
+ configs = ("3d_fullres", "3d_lowres")
+ output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv')
+ summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
+
diff --git a/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py
new file mode 100644
index 0000000..2795d3d
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py
@@ -0,0 +1,18 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize
+from nnunetv2.paths import nnUNet_results
+
+if __name__ == '__main__':
+ use_these_trainers = {
+ 'nnUNetTrainer': ('nnUNetPlans', ),
+ }
+ all_results_file = join(nnUNet_results, 'hrnet_results.csv')
+ datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82]
+ collect_results(use_these_trainers, datasets, all_results_file)
+
+ folds = (0, )
+ configs = ('2d', )
+ output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv')
+ summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
+
diff --git a/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py
new file mode 100644
index 0000000..0a75fbd
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py
@@ -0,0 +1,86 @@
+from copy import deepcopy
+import numpy as np
+
+
+def merge(dict1, dict2):
+ keys = np.unique(list(dict1.keys()) + list(dict2.keys()))
+ res = {}
+ for k in keys:
+ all_configs = []
+ if dict1.get(k) is not None:
+ all_configs += list(dict1[k])
+ if dict2.get(k) is not None:
+ all_configs += list(dict2[k])
+ if len(all_configs) > 0:
+ res[k] = tuple(np.unique(all_configs))
+ return res
+
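+ # quick illustration of merge (added for clarity, not part of the original script):
+ # merge({2: ("3d_fullres",)}, {2: ("3d_lowres",), 5: ("2d",)})
+ # -> roughly {2: ('3d_fullres', '3d_lowres'), 5: ('2d',)}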
+
+if __name__ == "__main__":
+ # after the Nature Methods paper we switched to a different (more stable, higher-quality) set of
+ # datasets for evaluation and future development
+ configurations_all = {
+ 2: ("3d_fullres", "2d"),
+ 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ 4: ("2d", "3d_fullres"),
+ 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ 20: ("2d", "3d_fullres"),
+ 24: ("2d", "3d_fullres"),
+ 27: ("2d", "3d_fullres"),
+ 38: ("2d", "3d_fullres"),
+ 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ 82: ("2d", "3d_fullres"),
+ # 83: ("2d", "3d_fullres"),
+ }
+
+ configurations_3d_fr_only = {
+ i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i]
+ }
+
+ configurations_3d_c_only = {
+ i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i]
+ }
+
+ configurations_3d_lr_only = {
+ i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i]
+ }
+
+ configurations_2d_only = {
+ i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i]
+ }
+
+ num_gpus = 1
+ exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R \"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\""
+ resources = "-R \"tensorcore\""
+ gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G"
+ queue = "-q gpu-lowprio"
+ preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && "
+ train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train'
+
+ folds = (0, )
+ # use_this = configurations_2d_only
+ use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only)
+ # use_this = merge(use_this, configurations_3d_c_only)
+
+ use_these_modules = {
+ 'nnUNetTrainer': ('nnUNetPlans',),
+ 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',),
+ # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',),
+ }
+
+ additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # ''
+
+ output_file = "/home/isensee/deleteme.txt"
+ with open(output_file, 'w') as f:
+ for tr in use_these_modules.keys():
+ for p in use_these_modules[tr]:
+ for dataset in use_this.keys():
+ for config in use_this[dataset]:
+ for fl in folds:
+ command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
+ if additional_arguments is not None and len(additional_arguments) > 0:
+ command += f' {additional_arguments}'
+ f.write(f'{command}\"\n')
+
diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/__init__.py b/nnUNet/nnunetv2/batch_running/release_trainings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py
new file mode 100644
index 0000000..f934186
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py
@@ -0,0 +1,113 @@
+from typing import Tuple
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.evaluation.evaluate_predictions import load_summary_json
+from nnunetv2.paths import nnUNet_results
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id
+from nnunetv2.utilities.file_path_utilities import get_output_folder
+
+
+def collect_results(trainers: dict, datasets: List, output_file: str,
+ configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"),
+ folds=tuple(np.arange(5))):
+ results_dirs = (nnUNet_results,)
+ datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets]
+ with open(output_file, 'w') as f:
+ for i, d in zip(datasets, datasets_names):
+ for c in configurations:
+ for module in trainers.keys():
+ for plans in trainers[module]:
+ for r in results_dirs:
+ expected_output_folder = get_output_folder(d, module, plans, c)
+ if isdir(expected_output_folder):
+ results_folds = []
+ f.write("%s,%s,%s,%s,%s" % (d, c, module, plans, r))
+ for fl in folds:
+ expected_output_folder_fold = get_output_folder(d, module, plans, c, fl)
+ expected_summary_file = join(expected_output_folder_fold, "validation",
+ "summary.json")
+ if not isfile(expected_summary_file):
+ print('expected output file not found:', expected_summary_file)
+ f.write(",")
+ results_folds.append(np.nan)
+ else:
+ foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][
+ 'Dice']
+ results_folds.append(foreground_mean)
+ f.write(",%02.4f" % foreground_mean)
+ f.write(",%02.4f\n" % np.nanmean(results_folds))
+
+
+def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers):
+ txt = np.loadtxt(input_file, dtype=str, delimiter=',')
+ num_folds = txt.shape[1] - 6
+ valid_configs = {}
+ for d in datasets:
+ if isinstance(d, int):
+ d = maybe_convert_to_dataset_name(d)
+ configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d])
+ valid_configs[d] = [i for i in configs_in_txt if i in configs]
+ assert max(folds) < num_folds
+
+ with open(output_file, 'w') as f:
+ f.write("name")
+ for d in valid_configs.keys():
+ for c in valid_configs[d]:
+ f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4]))
+ f.write(',mean\n')
+ valid_entries = txt[:, 4] == nnUNet_results
+ for t in trainers.keys():
+ trainer_locs = valid_entries & (txt[:, 2] == t)
+ for pl in trainers[t]:
+ f.write("%s__%s" % (t, pl))
+ trainer_plan_locs = trainer_locs & (txt[:, 3] == pl)
+ r = []
+ for d in valid_configs.keys():
+ trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d)
+ for v in valid_configs[d]:
+ trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v)
+ if np.any(trainer_plan_d_config_locs):
+ # we cannot have more than one row
+ assert np.sum(trainer_plan_d_config_locs) == 1
+
+ # now check that we have all folds
+ selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]]
+
+ fold_results = selected_row[[i + 5 for i in folds]]
+
+ if '' in fold_results:
+ print('missing fold in', t, pl, d, v)
+ f.write(",nan")
+ r.append(np.nan)
+ else:
+ mean_dice = np.mean([float(i) for i in fold_results])
+ f.write(",%02.4f" % mean_dice)
+ r.append(mean_dice)
+ else:
+ print('missing:', t, pl, d, v)
+ f.write(",nan")
+ r.append(np.nan)
+ f.write(",%02.4f\n" % np.mean(r))
+
+
+if __name__ == '__main__':
+ use_these_trainers = {
+ 'nnUNetTrainer': ('nnUNetPlans',),
+ 'nnUNetTrainer_v1loss': ('nnUNetPlans',),
+ }
+ all_results_file = join(nnUNet_results, 'customDecResults.csv')
+ datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82]
+ collect_results(use_these_trainers, datasets, all_results_file)
+
+ folds = (0, 1, 2, 3, 4)
+ configs = ("3d_fullres", "3d_lowres")
+ output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv')
+ summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
+
+ folds = (0, )
+ configs = ("3d_fullres", "3d_lowres")
+ output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv')
+ summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers)
+
diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py
new file mode 100644
index 0000000..7c5934f
--- /dev/null
+++ b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py
@@ -0,0 +1,93 @@
+from copy import deepcopy
+import numpy as np
+
+
+def merge(dict1, dict2):
+ keys = np.unique(list(dict1.keys()) + list(dict2.keys()))
+ res = {}
+ for k in keys:
+ all_configs = []
+ if dict1.get(k) is not None:
+ all_configs += list(dict1[k])
+ if dict2.get(k) is not None:
+ all_configs += list(dict2[k])
+ if len(all_configs) > 0:
+ res[k] = tuple(np.unique(all_configs))
+ return res
+
+
+if __name__ == "__main__":
+ # after the Nature Methods paper we switched to a different (more stable, higher-quality) set of
+ # datasets for evaluation and future development
+ configurations_all = {
+ # 1: ("3d_fullres", "2d"),
+ 2: ("3d_fullres", "2d"),
+ # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 4: ("2d", "3d_fullres"),
+ 5: ("2d", "3d_fullres"),
+ # 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ 20: ("2d", "3d_fullres"),
+ 24: ("2d", "3d_fullres"),
+ 27: ("2d", "3d_fullres"),
+ 35: ("2d", "3d_fullres"),
+ 38: ("2d", "3d_fullres"),
+ # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ # 82: ("2d", "3d_fullres"),
+ # 83: ("2d", "3d_fullres"),
+ }
+
+ configurations_3d_fr_only = {
+ i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i]
+ }
+
+ configurations_3d_c_only = {
+ i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i]
+ }
+
+ configurations_3d_lr_only = {
+ i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i]
+ }
+
+ configurations_2d_only = {
+ i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i]
+ }
+
+ num_gpus = 1
+ exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\""
+ resources = "-R \"tensorcore\""
+ gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G"
+ queue = "-q gpu-lowprio"
+ preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && "
+ train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train'
+
+ folds = (0, 1, 2, 3, 4)
+ # use_this = configurations_2d_only
+ # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only)
+ # use_this = merge(use_this, configurations_3d_c_only)
+ use_this = configurations_all
+
+ use_these_modules = {
+ 'nnUNetTrainer': ('nnUNetPlans',),
+ }
+
+ additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # ''
+
+ output_file = "/home/isensee/deleteme.txt"
+ with open(output_file, 'w') as f:
+ for tr in use_these_modules.keys():
+ for p in use_these_modules[tr]:
+ for dataset in use_this.keys():
+ for config in use_this[dataset]:
+ for fl in folds:
+ command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}'
+ if additional_arguments is not None and len(additional_arguments) > 0:
+ command += f' {additional_arguments}'
+ f.write(f'{command}\"\n')
+
diff --git a/nnUNet/nnunetv2/configuration.py b/nnUNet/nnunetv2/configuration.py
new file mode 100644
index 0000000..cdc8cb6
--- /dev/null
+++ b/nnUNet/nnunetv2/configuration.py
@@ -0,0 +1,10 @@
+import os
+
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+
+default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])
+
+ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low
+# resolution axis must be 3x as large as the next largest spacing)
+
+default_n_proc_DA = get_allowed_n_proc_DA()
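+
+# illustration (not part of the original file): the default can be overridden via the environment, e.g.
+# `nnUNet_def_n_proc=4 nnUNetv2_train ...` makes default_num_processes == 4 in that process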
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py b/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py
new file mode 100644
index 0000000..569ff6f
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+from pathlib import Path
+
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def make_out_dirs(dataset_id: int, task_name="ACDC"):
+ dataset_name = f"Dataset{dataset_id:03d}_{task_name}"
+
+ out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name
+ out_train_dir = out_dir / "imagesTr"
+ out_labels_dir = out_dir / "labelsTr"
+ out_test_dir = out_dir / "imagesTs"
+
+ os.makedirs(out_dir, exist_ok=True)
+ os.makedirs(out_train_dir, exist_ok=True)
+ os.makedirs(out_labels_dir, exist_ok=True)
+ os.makedirs(out_test_dir, exist_ok=True)
+
+ return out_dir, out_train_dir, out_labels_dir, out_test_dir
+
+
+def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path):
+ """Copy files from the ACDC dataset to the nnUNet dataset folder. Returns the number of training cases."""
+ patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()])
+ patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()])
+
+ num_training_cases = 0
+ # Copy training files and corresponding labels.
+ for patient_dir in patients_train:
+ for file in patient_dir.iterdir():
+ if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name:
+ # The stem is 'patient.nii', and the suffix is '.gz'.
+ # We split the stem and append _0000 to the patient part.
+ shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz")
+ num_training_cases += 1
+ elif file.suffix == ".gz" and "_gt" in file.name:
+ shutil.copy(file, labels_dir / file.name.replace("_gt", ""))
+
+ # Copy test files.
+ for patient_dir in patients_test:
+ for file in patient_dir.iterdir():
+ if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name:
+ shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz")
+
+ return num_training_cases
+
+
+def convert_acdc(src_data_folder: str, dataset_id=27):
+ out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id)
+ num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir)
+
+ generate_dataset_json(
+ str(out_dir),
+ channel_names={
+ 0: "cineMRI",
+ },
+ labels={
+ "background": 0,
+ "RV": 1,
+ "MLV": 2,
+ "LVC": 3,
+ },
+ file_ending=".nii.gz",
+ num_training_cases=num_training_cases,
+ )
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-i",
+ "--input_folder",
+ type=str,
+ help="The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.",
+ )
+ parser.add_argument(
+ "-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27"
+ )
+ args = parser.parse_args()
+ print("Converting...")
+ convert_acdc(args.input_folder, args.dataset_id)
+ print("Done!")
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py b/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py
new file mode 100644
index 0000000..eca22d0
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py
@@ -0,0 +1,85 @@
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
+import tifffile
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+
+
+if __name__ == '__main__':
+ """
+ This is going to be my test dataset for working with tif as input and output images
+
+ All we do here is copy the files and rename them. No file conversion takes place.
+ """
+ dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM'
+
+ imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')
+ imagests = join(nnUNet_raw, dataset_name, 'imagesTs')
+ labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ # we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train
+ # and Fluo-C3DH-A549-SIM_test
+ train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train'
+ test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test'
+
+ # with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the
+ # tif files
+
+ # tif has no standard way of storing spacing information, so nnU-Net expects a separate json file that
+ # specifies the spacing. This file must exist for EVERY training/test case to allow for different spacings
+ # between files. Important! The spacing must align with the axes.
+ # Here print(tifffile.imread('IMAGE').shape) gives (29, 300, 350), so the low-resolution axis comes first.
+ # The spacing given on the challenge website is listed in the opposite axis order.
+ spacing = (1, 0.126, 0.126)
+
+ # train set
+ for seq in ['01', '02']:
+ images_dir = join(train_source, seq)
+ seg_dir = join(train_source, seq + '_GT', 'SEG')
+ # if we wanted to be extra careful we would match by IDs, but here we just trust that the files are sorted
+ # consistently. Simpler filenames in the cell tracking challenge would be nice.
+ images = subfiles(images_dir, suffix='.tif', sort=True, join=False)
+ segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False)
+ for i, (im, se) in enumerate(zip(images, segs)):
+ target_name = f'{seq}_image_{i:03d}'
+ # we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input
+ # channels distributed over separate files
+ shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif'))
+ # spacing file!
+ save_json({'spacing': spacing}, join(imagestr, target_name + '.json'))
+ shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif'))
+ # spacing file!
+ save_json({'spacing': spacing}, join(labelstr, target_name + '.json'))
+
+ # test set, same as the train set, just without the segmentations
+ for seq in ['01', '02']:
+ images_dir = join(test_source, seq)
+ images = subfiles(images_dir, suffix='.tif', sort=True, join=False)
+ for i, im in enumerate(images):
+ target_name = f'{seq}_image_{i:03d}'
+ shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif'))
+ # spacing file!
+ save_json({'spacing': spacing}, join(imagests, target_name + '.json'))
+
+ # now we generate the dataset json
+ generate_dataset_json(
+ join(nnUNet_raw, dataset_name),
+ {0: 'fluorescence_microscopy'},
+ {'background': 0, 'cell': 1},
+ 60,
+ '.tif'
+ )
+
+ # custom split to ensure we are stratifying properly. This dataset only has 2 folds
+ caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)]
+ splits = []
+ splits.append(
+ {'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]}
+ )
+ splits.append(
+ {'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]}
+ )
+ save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json'))
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py b/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py
new file mode 100644
index 0000000..90dcc6c
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py
@@ -0,0 +1,87 @@
+import multiprocessing
+import shutil
+from multiprocessing import Pool
+
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+from skimage import io
+from acvl_utils.morphology.morphology_helper import generic_filter_components
+from scipy.ndimage import binary_fill_holes
+
+
+def load_and_convert_case(input_image: str, input_seg: str, output_image: str, output_seg: str,
+ min_component_size: int = 50):
+ seg = io.imread(input_seg)
+ seg[seg == 255] = 1
+ image = io.imread(input_image)
+ image = image.sum(2)
+ mask = image == (3 * 255)
+ # the dataset has large white areas in which road segmentations can exist but no image information is available.
+ # Remove the road label in these areas
+ mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if
+ sizes[j] > min_component_size])
+ mask = binary_fill_holes(mask)
+ seg[mask] = 0
+ io.imsave(output_seg, seg, check_contrast=False)
+ shutil.copy(input_image, output_image)
+
+
+if __name__ == "__main__":
+ # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download
+ source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal'
+
+ dataset_name = 'Dataset120_RoadSegmentation'
+
+ imagestr = join(nnUNet_raw, dataset_name, 'imagesTr')
+ imagests = join(nnUNet_raw, dataset_name, 'imagesTs')
+ labelstr = join(nnUNet_raw, dataset_name, 'labelsTr')
+ labelsts = join(nnUNet_raw, dataset_name, 'labelsTs')
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+ maybe_mkdir_p(labelsts)
+
+ train_source = join(source, 'training')
+ test_source = join(source, 'testing')
+
+ with multiprocessing.get_context("spawn").Pool(8) as p:
+
+ # not all training images have a segmentation
+ valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png')
+ num_train = len(valid_ids)
+ r = []
+ for v in valid_ids:
+ r.append(
+ p.starmap_async(
+ load_and_convert_case,
+ ((
+ join(train_source, 'input', v),
+ join(train_source, 'output', v),
+ join(imagestr, v[:-4] + '_0000.png'),
+ join(labelstr, v),
+ 50
+ ),)
+ )
+ )
+
+ # test set
+ valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png')
+ for v in valid_ids:
+ r.append(
+ p.starmap_async(
+ load_and_convert_case,
+ ((
+ join(test_source, 'input', v),
+ join(test_source, 'output', v),
+ join(imagests, v[:-4] + '_0000.png'),
+ join(labelsts, v),
+ 50
+ ),)
+ )
+ )
+ _ = [i.get() for i in r]
+
+ generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1},
+ num_train, '.png', dataset_name=dataset_name)
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py b/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py
new file mode 100644
index 0000000..b4817d2
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py
@@ -0,0 +1,98 @@
+import multiprocessing
+import shutil
+from multiprocessing import Pool
+
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None:
+ # use this for segmentation only!!!
+ # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
+ img = sitk.ReadImage(in_file)
+ img_npy = sitk.GetArrayFromImage(img)
+
+ uniques = np.unique(img_npy)
+ for u in uniques:
+ if u not in [0, 1, 2, 4]:
+ raise RuntimeError('unexpected label')
+
+ seg_new = np.zeros_like(img_npy)
+ seg_new[img_npy == 4] = 3
+ seg_new[img_npy == 2] = 1
+ seg_new[img_npy == 1] = 2
+ img_corr = sitk.GetImageFromArray(seg_new)
+ img_corr.CopyInformation(img)
+ sitk.WriteImage(img_corr, out_file)
+
+
+def convert_labels_back_to_BraTS(seg: np.ndarray):
+ new_seg = np.zeros_like(seg)
+ new_seg[seg == 1] = 2
+ new_seg[seg == 3] = 4
+ new_seg[seg == 2] = 1
+ return new_seg
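+ # label mapping used here and (inverted) in copy_BraTS_segmentation_and_convert_labels_to_nnUNet:
+ # nnU-Net labels 0, 1, 2, 3 correspond to BraTS labels 0, 2, 1, 4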
+
+
+def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder):
+ a = sitk.ReadImage(join(input_folder, filename))
+ b = sitk.GetArrayFromImage(a)
+ c = convert_labels_back_to_BraTS(b)
+ d = sitk.GetImageFromArray(c)
+ d.CopyInformation(a)
+ sitk.WriteImage(d, join(output_folder, filename))
+
+
+def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12):
+ """
+ reads all prediction files (nifti) in the input folder, converts the labels back to the BraTS convention and saves them in the output folder
+ """
+ maybe_mkdir_p(output_folder)
+ nii = subfiles(input_folder, suffix='.nii.gz', join=False)
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))
+
+
+if __name__ == '__main__':
+ brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training'
+
+ task_id = 137
+ task_name = "BraTS2021"
+
+ foldername = "Dataset%03.0d_%s" % (task_id, task_name)
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw, foldername)
+ imagestr = join(out_base, "imagesTr")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(labelstr)
+
+ case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False)
+
+ for c in case_ids:
+ shutil.copy(join(brats_data_dir, c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
+ shutil.copy(join(brats_data_dir, c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
+ shutil.copy(join(brats_data_dir, c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
+ shutil.copy(join(brats_data_dir, c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))
+
+ copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + "_seg.nii.gz"),
+ join(labelstr, c + '.nii.gz'))
+
+ generate_dataset_json(out_base,
+ channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'},
+ labels={
+ 'background': 0,
+ 'whole tumor': (1, 2, 3),
+ 'tumor core': (2, 3),
+ 'enhancing tumor': (3, )
+ },
+ num_training_cases=len(case_ids),
+ file_ending='.nii.gz',
+ regions_class_order=(1, 2, 3),
+ license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
+ reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863',
+ dataset_release='1.0')
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py b/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py
new file mode 100644
index 0000000..1f33cd7
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py
@@ -0,0 +1,70 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218):
+ """
+ AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
+ the train set. Having a 5-fold cross-validation is superior to a single train:val split
+ """
+ task_name = "AMOS2022_postChallenge_task1"
+
+ foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
+
+ training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
+ tr_ctr = 0
+ for tr in training_identifiers:
+ if int(tr.split("_")[-1]) <= 410: # these are the CT images
+ tr_ctr += 1
+ shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
+ shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
+
+ test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
+ for ts in test_identifiers:
+ if int(ts.split("_")[-1]) <= 500: # these are the CT images
+ shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
+
+ val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
+ for vl in val_identifiers:
+ if int(vl.split("_")[-1]) <= 409: # these are the CT images
+ tr_ctr += 1
+ shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
+ shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
+
+ generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},
+ num_training_cases=tr_ctr, file_ending='.nii.gz',
+ dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
+ release='https://zenodo.org/record/7262581',
+ overwrite_image_reader_writer='NibabelIOWithReorient',
+ description="This is the dataset as released AFTER the challenge event. It has the "
+ "validation set gt in it! We just use the validation images as additional "
+ "training cases because AMOS doesn't specify how they should be used. nnU-Net's"
+ " 5-fold CV is better than some random train:val split.")
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_folder', type=str,
+ help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
+ "Use this link: https://zenodo.org/record/7262581."
+ "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!")
+ parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218')
+ args = parser.parse_args()
+ amos_base = args.input_folder
+ convert_amos_task1(amos_base, args.d)
+
+
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py b/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py
new file mode 100644
index 0000000..9a5e2c6
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py
@@ -0,0 +1,65 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219):
+ """
+ AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into
+ the train set. Having a 5-fold cross-validation is superior to a single train:val split
+ """
+ task_name = "AMOS2022_postChallenge_task2"
+
+ foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ dataset_json_source = load_json(join(amos_base_dir, 'dataset.json'))
+
+ training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']]
+ for tr in training_identifiers:
+ shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
+ shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
+
+ test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']]
+ for ts in test_identifiers:
+ shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz'))
+
+ val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']]
+ for vl in val_identifiers:
+ shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz'))
+ shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz'))
+
+ generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()},
+ num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz',
+ dataset_name=task_name, reference='https://amos22.grand-challenge.org/',
+ release='https://zenodo.org/record/7262581',
+ overwrite_image_reader_writer='NibabelIOWithReorient',
+ description="This is the dataset as released AFTER the challenge event. It has the "
+ "validation set gt in it! We just use the validation images as additional "
+ "training cases because AMOS doesn't specify how they should be used. nnU-Net's"
+ " 5-fold CV is better than some random train:val split.")
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_folder', type=str,
+ help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. "
+ "Use this link: https://zenodo.org/record/7262581."
+ "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!")
+ parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219')
+ args = parser.parse_args()
+ amos_base = args.input_folder
+ convert_amos_task2(amos_base, args.d)
+
+ # /home/isensee/Downloads/amos22/amos22/
+
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py b/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py
new file mode 100644
index 0000000..20a794c
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py
@@ -0,0 +1,50 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220):
+ task_name = "KiTS2023"
+
+ foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name)
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw, foldername)
+ imagestr = join(out_base, "imagesTr")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(labelstr)
+
+ cases = subdirs(kits_base_dir, prefix='case_', join=False)
+ for tr in cases:
+ shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz'))
+ shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz'))
+
+ generate_dataset_json(out_base, {0: "CT"},
+ labels={
+ "background": 0,
+ "kidney": (1, 2, 3),
+ "masses": (2, 3),
+ "tumor": 2
+ },
+ regions_class_order=(1, 3, 2),
+ num_training_cases=len(cases), file_ending='.nii.gz',
+ dataset_name=task_name, reference='none',
+ release='prerelease',
+ overwrite_image_reader_writer='NibabelIOWithReorient',
+ description="KiTS2023")
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input_folder', type=str,
+ help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)")
+ parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220')
+ args = parser.parse_args()
+ kits_base = args.input_folder
+ convert_kits2023(kits_base, args.d)
+
+ # /media/isensee/raw_data/raw_datasets/kits23/dataset
+
diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py b/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py
new file mode 100644
index 0000000..80b295d
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py
@@ -0,0 +1,32 @@
+import os
+
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
+
+if __name__ == '__main__':
+ # creates a dummy dataset where there are no files in imagestr and labelstr
+ source_dataset = 'Dataset004_Hippocampus'
+
+ target_dataset = 'Dataset987_dummyDataset4'
+ target_dataset_dir = join(nnUNet_raw, target_dataset)
+ maybe_mkdir_p(target_dataset_dir)
+
+ dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset))
+
+ # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy
+ # datasets around between systems. As long as the source dataset is there, everything will keep working even if
+ # nnUNet_raw is in a different location
+
+ # paths must be relative to target_dataset_dir!!!
+ for k in dataset.keys():
+ dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir)
+ dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']]
+
+ # load old dataset.json
+ dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json'))
+ dataset_json['dataset'] = dataset
+
+ # save
+ save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False)
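+ # the resulting dataset.json then contains entries roughly like (illustration only, ids depend on the source dataset):
+ # "hippocampus_001": {"images": ["../Dataset004_Hippocampus/imagesTr/hippocampus_001_0000.nii.gz"],
+ # "label": "../Dataset004_Hippocampus/labelsTr/hippocampus_001.nii.gz"}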
diff --git a/nnUNet/nnunetv2/dataset_conversion/__init__.py b/nnUNet/nnunetv2/dataset_conversion/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py b/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py
new file mode 100644
index 0000000..97aac29
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py
@@ -0,0 +1,132 @@
+import argparse
+import multiprocessing
+import shutil
+from multiprocessing import Pool
+from typing import Optional
+import SimpleITK as sitk
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets
+from nnunetv2.configuration import default_num_processes
+import numpy as np
+
+
+def split_4d_nifti(filename, output_folder):
+ img_itk = sitk.ReadImage(filename)
+ dim = img_itk.GetDimension()
+ file_base = os.path.basename(filename)
+ if dim == 3:
+ shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
+ return
+ elif dim != 4:
+ raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
+ else:
+ img_npy = sitk.GetArrayFromImage(img_itk)
+ spacing = img_itk.GetSpacing()
+ origin = img_itk.GetOrigin()
+ direction = np.array(img_itk.GetDirection()).reshape(4,4)
+ # now modify these to remove the fourth dimension
+ spacing = tuple(list(spacing[:-1]))
+ origin = tuple(list(origin[:-1]))
+ direction = tuple(direction[:-1, :-1].reshape(-1))
+ for i in range(img_npy.shape[0]):
+ img = img_npy[i]
+ img_itk_new = sitk.GetImageFromArray(img)
+ img_itk_new.SetSpacing(spacing)
+ img_itk_new.SetOrigin(origin)
+ img_itk_new.SetDirection(direction)
+ sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i))
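+ # e.g. an MSD Prostate 4D image with two modalities ends up as <case>_0000.nii.gz and <case>_0001.nii.gz
+ # (illustrative example; the 4th dimension is assumed to index the modalities/time points)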
+
+
+def convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None,
+ num_processes: int = default_num_processes) -> None:
+ if source_folder.endswith('/') or source_folder.endswith('\\'):
+ source_folder = source_folder[:-1]
+
+ labelsTr = join(source_folder, 'labelsTr')
+ imagesTs = join(source_folder, 'imagesTs')
+ imagesTr = join(source_folder, 'imagesTr')
+ assert isdir(labelsTr), f"labelsTr subfolder missing in source folder"
+ assert isdir(imagesTs), f"imagesTs subfolder missing in source folder"
+ assert isdir(imagesTr), f"imagesTr subfolder missing in source folder"
+ dataset_json = join(source_folder, 'dataset.json')
+ assert isfile(dataset_json), f"dataset.json missing in source_folder"
+
+ # infer source dataset id and name
+ task, dataset_name = os.path.basename(source_folder).split('_')
+ task_id = int(task[4:])
+
+ # check if target dataset id is taken
+ target_id = task_id if overwrite_target_id is None else overwrite_target_id
+ existing_datasets = find_candidate_datasets(target_id)
+ assert len(existing_datasets) == 0, f"Target dataset id {target_id} is already taken, please consider changing " \
+ f"it using overwrite_target_id. Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)"
+
+ target_dataset_name = f"Dataset{target_id:03d}_{dataset_name}"
+ target_folder = join(nnUNet_raw, target_dataset_name)
+ target_imagesTr = join(target_folder, 'imagesTr')
+ target_imagesTs = join(target_folder, 'imagesTs')
+ target_labelsTr = join(target_folder, 'labelsTr')
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_imagesTs)
+ maybe_mkdir_p(target_labelsTr)
+
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ results = []
+
+ # convert 4d train images
+ source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if
+ not i.startswith('.') and not i.startswith('_')]
+ source_images = [join(imagesTr, i) for i in source_images]
+
+ results.append(
+ p.starmap_async(
+ split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images))
+ )
+ )
+
+ # convert 4d test images
+ source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if
+ not i.startswith('.') and not i.startswith('_')]
+ source_images = [join(imagesTs, i) for i in source_images]
+
+ results.append(
+ p.starmap_async(
+ split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images))
+ )
+ )
+
+ # copy segmentations
+ source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if
+ not i.startswith('.') and not i.startswith('_')]
+ for s in source_images:
+ shutil.copy(join(labelsTr, s), join(target_labelsTr, s))
+
+ [i.get() for i in results]
+
+ dataset_json = load_json(dataset_json)
+ dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}
+ dataset_json['file_ending'] = ".nii.gz"
+ dataset_json["channel_names"] = dataset_json["modality"]
+ del dataset_json["modality"]
+ del dataset_json["training"]
+ del dataset_json["test"]
+ save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)
+
+
+def entry_point():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-i', type=str, required=True,
+ help='Downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset! Example: '
+ '/home/fabian/Downloads/Task05_Prostate')
+ parser.add_argument('-overwrite_id', type=int, required=False, default=None,
+ help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from '
+ 'folder name). Only use this if you already have an equivalently numbered dataset!')
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f'Number of processes used. Default: {default_num_processes}')
+ args = parser.parse_args()
+ convert_msd_dataset(args.i, args.overwrite_id, args.np)
+
+
+if __name__ == '__main__':
+ convert_msd_dataset('/home/fabian/Downloads/Task05_Prostate', overwrite_target_id=201)
diff --git a/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py b/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py
new file mode 100644
index 0000000..fb77533
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py
@@ -0,0 +1,53 @@
+import shutil
+from copy import deepcopy
+
+from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json
+from nnunetv2.paths import nnUNet_raw
+
+
+def convert(source_folder, target_dataset_name):
+ """
+ remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY
+ source_folder: path to the old TaskXXX_YYY raw data folder (the one containing the imagesTr, labelsTr etc subfolders)
+ """
+ if isdir(join(nnUNet_raw, target_dataset_name)):
+ raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... '
+ f'(we might break something). If you are sure you want to proceed, please manually '
+ f'delete {join(nnUNet_raw, target_dataset_name)}')
+ maybe_mkdir_p(join(nnUNet_raw, target_dataset_name))
+ shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr'))
+ shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr'))
+ if isdir(join(source_folder, 'imagesTs')):
+ shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs'))
+ if isdir(join(source_folder, 'labelsTs')):
+ shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs'))
+ if isdir(join(source_folder, 'imagesVal')):
+ shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal'))
+ if isdir(join(source_folder, 'labelsVal')):
+ shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal'))
+ shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name))
+
+ dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json'))
+ del dataset_json['tensorImageSize']
+ del dataset_json['numTest']
+ del dataset_json['training']
+ del dataset_json['test']
+ dataset_json['channel_names'] = deepcopy(dataset_json['modality'])
+ del dataset_json['modality']
+
+ dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()}
+ dataset_json['file_ending'] = ".nii.gz"
+ save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False)
+
+
+def convert_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("input_folder", type=str,
+ help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! '
+ 'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not '
+ 'know where v1 tasks are.')
+ parser.add_argument("output_dataset_name", type=str,
+ help='New dataset NAME (not path!). Must follow the DatasetXXX_NAME convention!')
+ args = parser.parse_args()
+ convert(args.input_folder, args.output_dataset_name)
diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
new file mode 100644
index 0000000..d59fc8e
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
@@ -0,0 +1,75 @@
+import SimpleITK as sitk
+import shutil
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files
+
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.label_handling.label_handling import LabelManager
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+
+
+def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray:
+ assert label_manager.has_ignore_label, "This preprocessor only works with datasets that have an ignore label!"
+ seg_new = np.ones_like(seg) * label_manager.ignore_label
+ x, y, z = seg.shape
+ # x
+ num_slices = max(1, round(x * percent_of_slices))
+ selected_slices = np.random.choice(x, num_slices, replace=False)
+ seg_new[selected_slices] = seg[selected_slices]
+ # y
+ num_slices = max(1, round(y * percent_of_slices))
+ selected_slices = np.random.choice(y, num_slices, replace=False)
+ seg_new[:, selected_slices] = seg[:, selected_slices]
+ # z
+ num_slices = max(1, round(z * percent_of_slices))
+ selected_slices = np.random.choice(z, num_slices, replace=False)
+ seg_new[:, :, selected_slices] = seg[:, :, selected_slices]
+ return seg_new
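+ # net effect: only a random fraction (~percent_of_slices) of slices along each axis keeps its original labels;
+ # all remaining voxels are set to the ignore label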
+
+
+if __name__ == '__main__':
+ dataset_name = 'IntegrationTest_Hippocampus_regions_ignore'
+ dataset_id = 996
+ dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
+
+ try:
+ existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
+ if existing_dataset_name != dataset_name:
+ raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
+ f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
+ f"nnUNet_results!")
+ except RuntimeError:
+ pass
+
+ if isdir(join(nnUNet_raw, dataset_name)):
+ shutil.rmtree(join(nnUNet_raw, dataset_name))
+
+ source_dataset = maybe_convert_to_dataset_name(4)
+ shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
+
+ # additionally optimize entire hippocampus region, remove Posterior
+ dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
+ dj['labels'] = {
+ 'background': 0,
+ 'hippocampus': (1, 2),
+ 'anterior': 1,
+ 'ignore': 3
+ }
+ dj['regions_class_order'] = (2, 1)
+ save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
+
+ # now add ignore label to segmentation images
+ np.random.seed(1234)
+ lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order'))
+
+ segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr'))
+ for s in segs:
+ seg_itk = sitk.ReadImage(s)
+ seg_npy = sitk.GetArrayFromImage(seg_itk)
+ seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3)
+ seg_itk_new = sitk.GetImageFromArray(seg_npy)
+ seg_itk_new.CopyInformation(seg_itk)
+ sitk.WriteImage(seg_itk_new, s)
+
diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py
new file mode 100644
index 0000000..b40c534
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py
@@ -0,0 +1,37 @@
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
+
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.paths import nnUNet_raw
+
+if __name__ == '__main__':
+ dataset_name = 'IntegrationTest_Hippocampus_regions'
+ dataset_id = 997
+ dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
+
+ try:
+ existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
+ if existing_dataset_name != dataset_name:
+ raise FileExistsError(
+ f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
+ f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
+ f"nnUNet_results!")
+ except RuntimeError:
+ pass
+
+ if isdir(join(nnUNet_raw, dataset_name)):
+ shutil.rmtree(join(nnUNet_raw, dataset_name))
+
+ source_dataset = maybe_convert_to_dataset_name(4)
+ shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
+
+ # additionally optimize entire hippocampus region, remove Posterior
+ dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
+ dj['labels'] = {
+ 'background': 0,
+ 'hippocampus': (1, 2),
+ 'anterior': 1
+ }
+ dj['regions_class_order'] = (2, 1)
+ save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py
new file mode 100644
index 0000000..1781a27
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py
@@ -0,0 +1,33 @@
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json
+
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.paths import nnUNet_raw
+
+
+if __name__ == '__main__':
+ dataset_name = 'IntegrationTest_Hippocampus_ignore'
+ dataset_id = 998
+ dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
+
+ try:
+ existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
+ if existing_dataset_name != dataset_name:
+ raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
+ f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and "
+ f"nnUNet_results!")
+ except RuntimeError:
+ pass
+
+ if isdir(join(nnUNet_raw, dataset_name)):
+ shutil.rmtree(join(nnUNet_raw, dataset_name))
+
+ source_dataset = maybe_convert_to_dataset_name(4)
+ shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
+
+ # set class 2 to ignore label
+ dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
+ dj['labels']['ignore'] = 2
+ del dj['labels']['Posterior']
+ save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False)
diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py
new file mode 100644
index 0000000..33075da
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py
@@ -0,0 +1,27 @@
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import isdir, join
+
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.paths import nnUNet_raw
+
+
+if __name__ == '__main__':
+ dataset_name = 'IntegrationTest_Hippocampus'
+ dataset_id = 999
+ dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}"
+
+ try:
+ existing_dataset_name = maybe_convert_to_dataset_name(dataset_id)
+ if existing_dataset_name != dataset_name:
+ raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If "
+                                  f"you intend to delete it, remember to also remove it in nnUNet_preprocessed and "
+ f"nnUNet_results!")
+ except RuntimeError:
+ pass
+
+ if isdir(join(nnUNet_raw, dataset_name)):
+ shutil.rmtree(join(nnUNet_raw, dataset_name))
+
+ source_dataset = maybe_convert_to_dataset_name(4)
+ shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name))
diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py b/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py
new file mode 100644
index 0000000..2f2e115
--- /dev/null
+++ b/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py
@@ -0,0 +1,103 @@
+from typing import Tuple
+
+from batchgenerators.utilities.file_and_folder_operations import save_json, join
+
+
+def generate_dataset_json(output_folder: str,
+ channel_names: dict,
+ labels: dict,
+ num_training_cases: int,
+ file_ending: str,
+ regions_class_order: Tuple[int, ...] = None,
+ dataset_name: str = None, reference: str = None, release: str = None, license: str = None,
+ description: str = None,
+ overwrite_image_reader_writer: str = None, **kwargs):
+ """
+ Generates a dataset.json file in the output folder
+
+ channel_names:
+ Channel names must map the index to the name of the channel, example:
+ {
+ 0: 'T1',
+ 1: 'CT'
+ }
+ Note that the channel names may influence the normalization scheme!! Learn more in the documentation.
+
+ labels:
+ This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not.
+ Example regular labels:
+ {
+ 'background': 0,
+ 'left atrium': 1,
+ 'some other label': 2
+ }
+ Example region-based training:
+ {
+ 'background': 0,
+ 'whole tumor': (1, 2, 3),
+ 'tumor core': (2, 3),
+ 'enhancing tumor': 3
+ }
+
+ Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background!
+
+ num_training_cases: is used to double check all cases are there!
+
+ file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and
+ segmentations!
+
+ dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for
+ completeness and as a reminder that these would be great!
+
+ overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from
+    BaseReaderWriter, place it into nnunetv2.imageio and reference it here by name
+
+ kwargs: whatever you put here will be placed in the dataset.json as well
+
+ """
+ has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()])
+ if has_regions:
+ assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \
+ f"You need that."
+ # channel names need strings as keys
+ keys = list(channel_names.keys())
+ for k in keys:
+ if not isinstance(k, str):
+ channel_names[str(k)] = channel_names[k]
+ del channel_names[k]
+
+ # labels need ints as values
+ for l in labels.keys():
+ value = labels[l]
+ if isinstance(value, (tuple, list)):
+ value = tuple([int(i) for i in value])
+ labels[l] = value
+ else:
+ labels[l] = int(labels[l])
+
+ dataset_json = {
+        'channel_names': channel_names,  # previously this was called 'modality'. I didn't like this so this is
+ # channel_names now. Live with it.
+ 'labels': labels,
+ 'numTraining': num_training_cases,
+ 'file_ending': file_ending,
+ }
+
+ if dataset_name is not None:
+ dataset_json['name'] = dataset_name
+ if reference is not None:
+ dataset_json['reference'] = reference
+ if release is not None:
+ dataset_json['release'] = release
+ if license is not None:
+ dataset_json['licence'] = license
+ if description is not None:
+ dataset_json['description'] = description
+ if overwrite_image_reader_writer is not None:
+ dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer
+ if regions_class_order is not None:
+ dataset_json['regions_class_order'] = regions_class_order
+
+ dataset_json.update(kwargs)
+
+ save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
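+
+
+# Example usage (not part of nnU-Net, added for illustration): a minimal sketch of generating a
+# dataset.json for a hypothetical two-channel, region-based dataset. The output folder, channel
+# names, labels and case count below are placeholders.
+if __name__ == '__main__':
+    generate_dataset_json(
+        output_folder='/path/to/nnUNet_raw/Dataset123_Example',  # hypothetical raw dataset folder
+        channel_names={0: 'CT', 1: 'MR'},
+        labels={'background': 0, 'whole tumor': (1, 2), 'enhancing tumor': 2},
+        num_training_cases=42,
+        file_ending='.nii.gz',
+        regions_class_order=(1, 2),  # required because region-based labels are used above
+        dataset_name='Dataset123_Example',
+        description='Toy example; all values are placeholders'
+    )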
diff --git a/nnUNet/nnunetv2/ensembling/__init__.py b/nnUNet/nnunetv2/ensembling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/ensembling/ensemble.py b/nnUNet/nnunetv2/ensembling/ensemble.py
new file mode 100644
index 0000000..68b378b
--- /dev/null
+++ b/nnUNet/nnunetv2/ensembling/ensemble.py
@@ -0,0 +1,206 @@
+import argparse
+import multiprocessing
+import shutil
+from copy import deepcopy
+from multiprocessing import Pool
+from typing import List, Union, Tuple
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \
+ maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.utilities.label_handling.label_handling import LabelManager
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+
+def average_probabilities(list_of_files: List[str]) -> np.ndarray:
+ assert len(list_of_files), 'At least one file must be given in list_of_files'
+ avg = None
+ for f in list_of_files:
+ if avg is None:
+ avg = np.load(f)['probabilities']
+ # maybe increase precision to prevent rounding errors
+ if avg.dtype != np.float32:
+ avg = avg.astype(np.float32)
+ else:
+ avg += np.load(f)['probabilities']
+ avg /= len(list_of_files)
+ return avg
+
+
+def merge_files(list_of_files,
+ output_filename_truncated: str,
+ output_file_ending: str,
+ image_reader_writer: BaseReaderWriter,
+ label_manager: LabelManager,
+ save_probabilities: bool = False):
+ # load the pkl file associated with the first file in list_of_files
+ properties = load_pickle(list_of_files[0][:-4] + '.pkl')
+ # load and average predictions
+ probabilities = average_probabilities(list_of_files)
+ segmentation = label_manager.convert_logits_to_segmentation(probabilities)
+ image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties)
+ if save_probabilities:
+ np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities)
+ save_pickle(probabilities, output_filename_truncated + '.pkl')
+
+
+def ensemble_folders(list_of_input_folders: List[str],
+ output_folder: str,
+ save_merged_probabilities: bool = False,
+ num_processes: int = default_num_processes,
+ dataset_json_file_or_dict: str = None,
+ plans_json_file_or_dict: str = None):
+ """we need too much shit for this function. Problem is that we now have to support region-based training plus
+ multiple input/output formats so there isn't really a way around this.
+
+ If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json
+ and/or dataset.json in it. These are usually copied into those folders by nnU-Net during prediction.
+    We just pick the dataset.json and plans.json from the first of the folders and we DON'T check whether all the
+ folders contain the same plans etc! This can be a feature if results from different datasets are to be merged (only
+ works if label dict in dataset.json is the same between these datasets!!!)"""
+ if dataset_json_file_or_dict is not None:
+ if isinstance(dataset_json_file_or_dict, str):
+ dataset_json = load_json(dataset_json_file_or_dict)
+ else:
+ dataset_json = dataset_json_file_or_dict
+ else:
+ dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json'))
+
+ if plans_json_file_or_dict is not None:
+ if isinstance(plans_json_file_or_dict, str):
+ plans = load_json(plans_json_file_or_dict)
+ else:
+ plans = plans_json_file_or_dict
+ else:
+ plans = load_json(join(list_of_input_folders[0], 'plans.json'))
+
+ plans_manager = PlansManager(plans)
+
+ # now collect the files in each of the folders and enforce that all files are present in all folders
+ files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in list_of_input_folders]
+ # first build a set with all files
+ s = deepcopy(files_per_folder[0])
+ for f in files_per_folder[1:]:
+ s.update(f)
+ for f in files_per_folder:
+ assert len(s.difference(f)) == 0, "Not all folders contain the same files for ensembling. Please only " \
+ "provide folders that contain the predictions"
+ lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s]
+ output_files_truncated = [join(output_folder, fi[:-4]) for fi in s]
+
+ image_reader_writer = plans_manager.image_reader_writer_class()
+ label_manager = plans_manager.get_label_manager(dataset_json)
+
+ maybe_mkdir_p(output_folder)
+ shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder)
+
+ with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
+ num_preds = len(s)
+ _ = pool.starmap(
+ merge_files,
+ zip(
+ lists_of_lists_of_files,
+ output_files_truncated,
+ [dataset_json['file_ending']] * num_preds,
+ [image_reader_writer] * num_preds,
+ [label_manager] * num_preds,
+ [save_merged_probabilities] * num_preds
+ )
+ )
+
+
+def entry_point_ensemble_folders():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-i', nargs='+', type=str, required=True,
+ help='list of input folders')
+ parser.add_argument('-o', type=str, required=True, help='output folder')
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f"Numbers of processes used for ensembling. Default: {default_num_processes}")
+ parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output '
+ 'probabilities in separate .npz files')
+
+ args = parser.parse_args()
+ ensemble_folders(args.i, args.o, args.save_npz, args.np)
+
+
+def ensemble_crossvalidations(list_of_trained_model_folders: List[str],
+ output_folder: str,
+ folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4),
+ num_processes: int = default_num_processes,
+ overwrite: bool = True) -> None:
+ """
+ Feature: different configurations can now have different splits
+ """
+ dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json'))
+ plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json'))
+
+ # first collect all unique filenames
+ files_per_folder = {}
+ unique_filenames = set()
+ for tr in list_of_trained_model_folders:
+ files_per_folder[tr] = {}
+ for f in folds:
+ if not isdir(join(tr, f'fold_{f}', 'validation')):
+ raise RuntimeError(f'Expected model output directory does not exist. You must train all requested '
+                                     f'folds of the specified model.\nModel: {tr}\nFold: {f}')
+ files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False)
+ if len(files_here) == 0:
+ raise RuntimeError(f"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your "
+ f"validation with the --npz flag. Use nnUNetv2_train [...] --val --npz.")
+ files_per_folder[tr][f] = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False)
+ unique_filenames.update(files_per_folder[tr][f])
+
+ # verify that all trained_model_folders have all predictions
+ ok = True
+ for tr, fi in files_per_folder.items():
+ all_files_here = set()
+ for f in folds:
+ all_files_here.update(fi[f])
+ diff = unique_filenames.difference(all_files_here)
+ if len(diff) > 0:
+ ok = False
+ print(f'model {tr} does not seem to contain all predictions. Missing: {diff}')
+ if not ok:
+ raise RuntimeError('There were missing files, see print statements above this one')
+
+ # now we need to collect where these files are
+ file_mapping = []
+ for tr in list_of_trained_model_folders:
+ file_mapping.append({})
+ for f in folds:
+ for fi in files_per_folder[tr][f]:
+ # check for duplicates
+ assert fi not in file_mapping[-1].keys(), f"Duplicate detected. Case {fi} is present in more than " \
+ f"one fold of model {tr}."
+ file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi)
+
+ lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames]
+ output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames]
+
+ image_reader_writer = plans_manager.image_reader_writer_class()
+ maybe_mkdir_p(output_folder)
+ label_manager = plans_manager.get_label_manager(dataset_json)
+
+ if not overwrite:
+ tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated]
+ lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]]
+ output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]]
+
+ with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
+ num_preds = len(lists_of_lists_of_files)
+ _ = pool.starmap(
+ merge_files,
+ zip(
+ lists_of_lists_of_files,
+ output_files_truncated,
+ [dataset_json['file_ending']] * num_preds,
+ [image_reader_writer] * num_preds,
+ [label_manager] * num_preds,
+ [False] * num_preds
+ )
+ )
+
+ shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json'))
+ shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json'))
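+
+
+# Example usage (not part of nnU-Net, added for illustration): a minimal sketch of ensembling two
+# prediction folders. The paths are placeholders; both folders are assumed to contain .npz
+# probability files (written by nnUNetv2_predict with --save_probabilities) plus the dataset.json
+# and plans.json that prediction copies alongside them.
+if __name__ == '__main__':
+    ensemble_folders(
+        ['/path/to/predictions_3d_fullres', '/path/to/predictions_2d'],  # hypothetical input folders
+        '/path/to/ensembled_output',
+        save_merged_probabilities=False,
+        num_processes=4
+    )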
diff --git a/nnUNet/nnunetv2/evaluation/__init__.py b/nnUNet/nnunetv2/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py b/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py
new file mode 100644
index 0000000..6db2129
--- /dev/null
+++ b/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py
@@ -0,0 +1,55 @@
+import shutil
+from typing import Union, List, Tuple
+
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile
+
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+
+def accumulate_cv_results(trained_model_folder,
+ merged_output_folder: str,
+ folds: Union[List[int], Tuple[int, ...]],
+ num_processes: int = default_num_processes,
+ overwrite: bool = True):
+ """
+ There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to
+ collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files!
+ """
+
+ if overwrite and isdir(merged_output_folder):
+ shutil.rmtree(merged_output_folder)
+ maybe_mkdir_p(merged_output_folder)
+
+ dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
+ plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
+ rw = plans_manager.image_reader_writer_class()
+ shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json'))
+ shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json'))
+
+ did_we_copy_something = False
+ for f in folds:
+ expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation')
+ if not isdir(expected_validation_folder):
+ raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. Please train it!")
+ predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False)
+ for pf in predicted_files:
+ if overwrite and isfile(join(merged_output_folder, pf)):
+ raise RuntimeError(f'More than one of your folds has a prediction for case {pf}')
+ if overwrite or not isfile(join(merged_output_folder, pf)):
+ shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf))
+ did_we_copy_something = True
+
+ if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')):
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ compute_metrics_on_folder(join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr'),
+ merged_output_folder,
+ join(merged_output_folder, 'summary.json'),
+ rw,
+ dataset_json['file_ending'],
+ label_manager.foreground_regions if label_manager.has_regions else
+ label_manager.foreground_labels,
+ label_manager.ignore_label,
+ num_processes)
diff --git a/nnUNet/nnunetv2/evaluation/evaluate_predictions.py b/nnUNet/nnunetv2/evaluation/evaluate_predictions.py
new file mode 100644
index 0000000..e692f78
--- /dev/null
+++ b/nnUNet/nnunetv2/evaluation/evaluate_predictions.py
@@ -0,0 +1,264 @@
+import multiprocessing
+import os
+from copy import deepcopy
+from multiprocessing import Pool
+from typing import Tuple, List, Union, Optional
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \
+ isfile
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \
+ determine_reader_writer_from_file_ending
+from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. Keep it simple
+from nnunetv2.utilities.json_export import recursive_fix_for_json_export
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+
+def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]):
+ return str(label_or_region)
+
+
+def key_to_label_or_region(key: str):
+ try:
+ return int(key)
+ except ValueError:
+ key = key.replace('(', '')
+ key = key.replace(')', '')
+ splitted = key.split(',')
+ return tuple([int(i) for i in splitted])
+
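+# Illustration (not part of nnU-Net): these two helpers round-trip label/region keys through their
+# string form so they can be used as JSON keys, e.g.
+#   label_or_region_to_key(3)         -> '3'
+#   label_or_region_to_key((1, 2))    -> '(1, 2)'
+#   key_to_label_or_region('3')       -> 3
+#   key_to_label_or_region('(1, 2)')  -> (1, 2)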
+
+def save_summary_json(results: dict, output_file: str):
+ """
+ stupid json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit
+ ourselves
+ """
+ results_converted = deepcopy(results)
+ # convert keys in mean metrics
+ results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()}
+ # convert metric_per_case
+ for i in range(len(results_converted["metric_per_case"])):
+ results_converted["metric_per_case"][i]['metrics'] = \
+ {label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k]
+ for k in results["metric_per_case"][i]['metrics'].keys()}
+ # sort_keys=True will make foreground_mean the first entry and thus easy to spot
+ save_json(results_converted, output_file, sort_keys=True)
+
+
+def load_summary_json(filename: str):
+ results = load_json(filename)
+ # convert keys in mean metrics
+ results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()}
+ # convert metric_per_case
+ for i in range(len(results["metric_per_case"])):
+ results["metric_per_case"][i]['metrics'] = \
+ {key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k]
+ for k in results["metric_per_case"][i]['metrics'].keys()}
+ return results
+
+
+def labels_to_list_of_regions(labels: List[int]):
+ return [(i,) for i in labels]
+
+
+def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray:
+ if np.isscalar(region_or_label):
+ return segmentation == region_or_label
+ else:
+ mask = np.zeros_like(segmentation, dtype=bool)
+ for r in region_or_label:
+ mask[segmentation == r] = True
+ return mask
+
+
+def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None):
+ if ignore_mask is None:
+ use_mask = np.ones_like(mask_ref, dtype=bool)
+ else:
+ use_mask = ~ignore_mask
+ tp = np.sum((mask_ref & mask_pred) & use_mask)
+ fp = np.sum(((~mask_ref) & mask_pred) & use_mask)
+ fn = np.sum((mask_ref & (~mask_pred)) & use_mask)
+ tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask)
+ return tp, fp, fn, tn
+
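+# Illustration (not part of nnU-Net): on two toy 1D masks,
+#   mask_ref  = np.array([True, True, False, False])
+#   mask_pred = np.array([True, False, True, False])
+# compute_tp_fp_fn_tn(mask_ref, mask_pred) returns (1, 1, 1, 1), from which compute_metrics below
+# derives Dice = 2*tp / (2*tp + fp + fn) = 2/4 = 0.5 and IoU = tp / (tp + fp + fn) = 1/3.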
+
+def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter,
+ labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]],
+ ignore_label: int = None) -> dict:
+ # load images
+ seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file)
+ seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file)
+ # spacing = seg_ref_dict['spacing']
+
+ ignore_mask = seg_ref == ignore_label if ignore_label is not None else None
+
+ results = {}
+ results['reference_file'] = reference_file
+ results['prediction_file'] = prediction_file
+ results['metrics'] = {}
+ for r in labels_or_regions:
+ results['metrics'][r] = {}
+ mask_ref = region_or_label_to_mask(seg_ref, r)
+ mask_pred = region_or_label_to_mask(seg_pred, r)
+ tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask)
+ if tp + fp + fn == 0:
+ results['metrics'][r]['Dice'] = np.nan
+ results['metrics'][r]['IoU'] = np.nan
+ else:
+ results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn)
+ results['metrics'][r]['IoU'] = tp / (tp + fp + fn)
+ results['metrics'][r]['FP'] = fp
+ results['metrics'][r]['TP'] = tp
+ results['metrics'][r]['FN'] = fn
+ results['metrics'][r]['TN'] = tn
+ results['metrics'][r]['n_pred'] = fp + tp
+ results['metrics'][r]['n_ref'] = fn + tp
+ return results
+
+
+def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str,
+ image_reader_writer: BaseReaderWriter,
+ file_ending: str,
+ regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]],
+ ignore_label: int = None,
+ num_processes: int = default_num_processes,
+ chill: bool = True) -> dict:
+ """
+ output_file must end with .json; can be None
+ """
+ if output_file is not None:
+ assert output_file.endswith('.json'), 'output_file should end with .json'
+ files_pred = subfiles(folder_pred, suffix=file_ending, join=False)
+ files_ref = subfiles(folder_ref, suffix=file_ending, join=False)
+ if not chill:
+ present = [isfile(join(folder_pred, i)) for i in files_ref]
+        assert all(present), "Not all files in folder_ref have a corresponding prediction in folder_pred"
+ files_ref = [join(folder_ref, i) for i in files_pred]
+ files_pred = [join(folder_pred, i) for i in files_pred]
+ with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
+ # for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))):
+ # compute_metrics(*i)
+ results = pool.starmap(
+ compute_metrics,
+ list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred),
+ [ignore_label] * len(files_pred)))
+ )
+
+ # mean metric per class
+ metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys())
+ means = {}
+ for r in regions_or_labels:
+ means[r] = {}
+ for m in metric_list:
+ means[r][m] = np.nanmean([i['metrics'][r][m] for i in results])
+
+ # foreground mean
+ foreground_mean = {}
+ for m in metric_list:
+ values = []
+ for k in means.keys():
+ if k == 0 or k == '0':
+ continue
+ values.append(means[k][m])
+ foreground_mean[m] = np.mean(values)
+
+ [recursive_fix_for_json_export(i) for i in results]
+ recursive_fix_for_json_export(means)
+ recursive_fix_for_json_export(foreground_mean)
+ result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean}
+ if output_file is not None:
+ save_summary_json(result, output_file)
+ return result
+ # print('DONE')
+
+
+def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str,
+ output_file: str = None,
+ num_processes: int = default_num_processes,
+ chill: bool = False):
+ dataset_json = load_json(dataset_json_file)
+ # get file ending
+ file_ending = dataset_json['file_ending']
+
+ # get reader writer class
+ example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0]
+ rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)()
+
+ # maybe auto set output file
+ if output_file is None:
+ output_file = join(folder_pred, 'summary.json')
+
+ lm = PlansManager(plans_file).get_label_manager(dataset_json)
+ compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
+ lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label,
+ num_processes, chill=chill)
+
+
+def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, labels: Union[Tuple[int, ...], List[int]],
+ output_file: str = None,
+ num_processes: int = default_num_processes,
+ ignore_label: int = None,
+ chill: bool = False):
+ example_file = subfiles(folder_ref, join=True)[0]
+ file_ending = os.path.splitext(example_file)[-1]
+ rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True,
+ verbose=False)()
+ # maybe auto set output file
+ if output_file is None:
+ output_file = join(folder_pred, 'summary.json')
+ compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending,
+ labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill)
+
+
+def evaluate_folder_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
+ parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
+ parser.add_argument('-djfile', type=str, required=True,
+ help='dataset.json file')
+ parser.add_argument('-pfile', type=str, required=True,
+ help='plans.json file')
+ parser.add_argument('-o', type=str, required=False, default=None,
+ help='Output file. Optional. Default: pred_folder/summary.json')
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f'number of processes used. Optional. Default: {default_num_processes}')
+    parser.add_argument('--chill', action='store_true', help="don't crash if folder_pred doesn't have all files that are present in folder_gt")
+ args = parser.parse_args()
+ compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill)
+
+
+def evaluate_simple_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('gt_folder', type=str, help='folder with gt segmentations')
+ parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations')
+ parser.add_argument('-l', type=int, nargs='+', required=True,
+ help='list of labels')
+ parser.add_argument('-il', type=int, required=False, default=None,
+ help='ignore label')
+ parser.add_argument('-o', type=str, required=False, default=None,
+ help='Output file. Optional. Default: pred_folder/summary.json')
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f'number of processes used. Optional. Default: {default_num_processes}')
+    parser.add_argument('--chill', action='store_true', help="don't crash if folder_pred doesn't have all files that are present in folder_gt")
+
+ args = parser.parse_args()
+ compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill)
+
+
+if __name__ == '__main__':
+ folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr'
+ folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation'
+ output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json'
+ image_reader_writer = SimpleITKIO()
+ file_ending = '.nii.gz'
+ regions = labels_to_list_of_regions([1, 2])
+ ignore_label = None
+ num_processes = 12
+ compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label,
+ num_processes)
diff --git a/nnUNet/nnunetv2/evaluation/find_best_configuration.py b/nnUNet/nnunetv2/evaluation/find_best_configuration.py
new file mode 100644
index 0000000..c36008b
--- /dev/null
+++ b/nnUNet/nnunetv2/evaluation/find_best_configuration.py
@@ -0,0 +1,333 @@
+import argparse
+import os.path
+from copy import deepcopy
+from typing import Union, List, Tuple
+
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json
+
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.ensembling.ensemble import ensemble_crossvalidations
+from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results
+from nnunetv2.postprocessing.remove_connected_components import determine_postprocessing
+from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \
+ convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+default_trained_models = tuple([
+ {'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'},
+ {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'},
+ {'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'},
+ {'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'},
+])
+
+
+def filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]):
+ valid = []
+ for trained_model in model_dict:
+ plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id),
+ trained_model['plans'] + '.json'))
+ # check if configuration exists
+ # 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not
+ # specified in the plans file
+ if trained_model['configuration'] not in plans_manager.available_configurations:
+ print(f"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\n"
+ f"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.")
+ continue
+
+ # check if trained model output folder exists. This is a requirement. No mercy here.
+ expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'],
+ trained_model['configuration'], fold=None)
+ if not isdir(expected_output_folder):
+ raise RuntimeError(f"Trained model {trained_model} does not have an output folder. "
+ f"Expected: {expected_output_folder}. Please run the training for this model! (don't forget "
+ f"the --npz flag if you want to ensemble multiple configurations)")
+
+ valid.append(trained_model)
+ return valid
+
+
+def generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str,
+ plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer',
+ folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),
+ folder_with_segs_from_prev_stage: str = None,
+ input_folder: str = 'INPUT_FOLDER',
+ output_folder: str = 'OUTPUT_FOLDER',
+ save_npz: bool = False):
+ fold_str = ''
+ for f in folds:
+ fold_str += f' {f}'
+
+ predict_command = ''
+ trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None)
+ plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
+ configuration_manager = plans_manager.get_configuration(configuration_name)
+ if 'previous_stage' in plans_manager.available_configurations:
+ prev_stage = configuration_manager.previous_stage_name
+ predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name,
+ folds, None, output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\n'
+ folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE'
+
+ predict_command = f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \
+ f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}'
+ if folder_with_segs_from_prev_stage is not None:
+ predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}'
+ if save_npz:
+ predict_command += ' --save_probabilities'
+ return predict_command
+
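+# Illustration (not part of nnU-Net): assuming a plain (non-cascade) configuration, a call such as
+#   generate_inference_command(4, '3d_fullres', folds=(0, 1, 2, 3, 4), save_npz=True)
+# returns a command string along the lines of
+#   nnUNetv2_predict -d 4 -i INPUT_FOLDER -o OUTPUT_FOLDER -f  0 1 2 3 4 -tr nnUNetTrainer -c 3d_fullres -p nnUNetPlans --save_probabilities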
+
+def find_best_configuration(dataset_name_or_id,
+ allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models,
+ allow_ensembling: bool = True,
+ num_processes: int = default_num_processes,
+ overwrite: bool = True,
+ folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4),
+ strict: bool = False):
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+ all_results = {}
+
+ allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id)
+
+ for m in allowed_trained_models:
+ output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None)
+ if not isdir(output_folder) and strict:
+ raise RuntimeError(f'{dataset_name}: The output folder of plans {m["plans"]} configuration '
+ f'{m["configuration"]} is missing. Please train the model (all requested folds!) first!')
+ identifier = os.path.basename(output_folder)
+ merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
+ accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite)
+ all_results[identifier] = {
+ 'source': merged_output_folder,
+ 'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice']
+ }
+
+ if allow_ensembling:
+ for i in range(len(allowed_trained_models)):
+ for j in range(i + 1, len(allowed_trained_models)):
+ m1, m2 = allowed_trained_models[i], allowed_trained_models[j]
+
+ output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None)
+ output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None)
+ identifier = get_ensemble_name(output_folder_1, output_folder_2, folds)
+
+ output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier)
+
+ ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds,
+ num_processes, overwrite=overwrite)
+
+ # evaluate ensembled predictions
+ plans_manager = PlansManager(join(output_folder_1, 'plans.json'))
+ dataset_json = load_json(join(output_folder_1, 'dataset.json'))
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ rw = plans_manager.image_reader_writer_class()
+
+ compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),
+ output_folder_ensemble,
+ join(output_folder_ensemble, 'summary.json'),
+ rw,
+ dataset_json['file_ending'],
+ label_manager.foreground_regions if label_manager.has_regions else
+ label_manager.foreground_labels,
+ label_manager.ignore_label,
+ num_processes)
+ all_results[identifier] = \
+ {
+ 'source': output_folder_ensemble,
+ 'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice']
+ }
+
+ # pick best and report inference command
+ best_score = max([i['result'] for i in all_results.values()])
+ best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score] # may never happen but theoretically
+ # there can be a tie. Let's pick the first model in this case because it's going to be the simpler one (ensembles
+ # come after single configs)
+ best_key = best_keys[0]
+
+ print()
+ print('***All results:***')
+ for k, v in all_results.items():
+ print(f'{k}: {v["result"]}')
+ print(f'\n*Best*: {best_key}: {all_results[best_key]["result"]}')
+ print()
+
+ print('***Determining postprocessing for best model/ensemble***')
+ determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'),
+ plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'),
+ dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'),
+ num_processes=num_processes, keep_postprocessed_files=True)
+
+ # in addition to just reading the console output (how it was previously) we should return the information
+ # needed to run the full inference via API
+ return_dict = {
+ 'folds': folds,
+ 'dataset_name_or_id': dataset_name_or_id,
+ 'considered_models': allowed_trained_models,
+ 'ensembling_allowed': allow_ensembling,
+ 'all_results': {i: j['result'] for i, j in all_results.items()},
+ 'best_model_or_ensemble': {
+ 'result_on_crossval_pre_pp': all_results[best_key]["result"],
+ 'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'],
+ 'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'),
+ 'some_plans_file': join(all_results[best_key]['source'], 'plans.json'),
+ # just needed for label handling, can
+ # come from any of the ensemble members (if any)
+ 'selected_model_or_models': []
+ }
+ }
+ # convert best key to inference command:
+ if best_key.startswith('ensemble___'):
+ prefix, m1, m2, folds_string = best_key.split('___')
+ tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1)
+ tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2)
+ return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
+ {
+ 'configuration': c1,
+ 'trainer': tr1,
+ 'plans_identifier': pl1,
+ })
+ return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
+ {
+ 'configuration': c2,
+ 'trainer': tr2,
+ 'plans_identifier': pl2,
+ })
+ else:
+ tr, pl, c = convert_identifier_to_trainer_plans_config(best_key)
+ return_dict['best_model_or_ensemble']['selected_model_or_models'].append(
+ {
+ 'configuration': c,
+ 'trainer': tr,
+ 'plans_identifier': pl,
+ })
+
+ save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json')) # save this so that we don't have to run this
+    # every time someone wants to be reminded of the inference commands. They can just load this and give it to
+ # print_inference_instructions
+
+ # print it
+ print_inference_instructions(return_dict, instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt'))
+ return return_dict
+
+
+def print_inference_instructions(inference_info_dict: dict, instructions_file: str = None):
+ def _print_and_maybe_write_to_file(string):
+ print(string)
+ if f_handle is not None:
+ f_handle.write(f'{string}\n')
+
+ f_handle = open(instructions_file, 'w') if instructions_file is not None else None
+ print()
+ _print_and_maybe_write_to_file('***Run inference like this:***\n')
+ output_folders = []
+
+ dataset_name_or_id = inference_info_dict['dataset_name_or_id']
+ if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1:
+ is_ensemble = True
+ _print_and_maybe_write_to_file('An ensemble won! What a surprise! Run the following commands to run predictions with the ensemble members:\n')
+ else:
+ is_ensemble = False
+
+ for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']):
+ tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier']
+ if is_ensemble:
+ output_folder_name = f"OUTPUT_FOLDER_MODEL_{j+1}"
+ else:
+ output_folder_name = f"OUTPUT_FOLDER"
+ output_folders.append(output_folder_name)
+
+ _print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr, inference_info_dict['folds'],
+ save_npz=is_ensemble, output_folder=output_folder_name))
+
+ if is_ensemble:
+ output_folder_str = output_folders[0]
+ for o in output_folders[1:]:
+ output_folder_str += f' {o}'
+ output_ensemble = f"OUTPUT_FOLDER"
+        _print_and_maybe_write_to_file('\nThen run ensembling with:\n')
+ _print_and_maybe_write_to_file(f"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} -np {default_num_processes}")
+
+ _print_and_maybe_write_to_file("\n***Once inference is completed, run postprocessing like this:***\n")
+ _print_and_maybe_write_to_file(f"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP "
+ f"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} "
+ f"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}")
+
+
+def dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]):
+ """
+ function is called dumb because it's dumb
+ """
+ ret = []
+ for t in trainers:
+ for c in configs:
+ for p in plans:
+ ret.append(
+ {'plans': p, 'configuration': c, 'trainer': t}
+ )
+ return tuple(ret)
+
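+# Illustration (not part of nnU-Net): dumb_trainer_config_plans_to_trained_models_dict(
+#     ['nnUNetTrainer'], ['2d', '3d_fullres'], ['nnUNetPlans'])
+# returns the cartesian product
+#   ({'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'},
+#    {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'})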
+
+def find_best_configuration_entry_point():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')
+ parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'],
+ help='List of plan identifiers. Default: nnUNetPlans')
+ parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'],
+ help="List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']")
+ parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'],
+ help='List of trainers. Default: nnUNetTrainer')
+ parser.add_argument('-np', required=False, default=default_num_processes, type=int,
+ help='Number of processes to use for ensembling, postprocessing etc')
+ parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),
+ help='Folds to use. Default: 0 1 2 3 4')
+ parser.add_argument('--disable_ensembling', action='store_true', required=False,
+ help='Set this flag to disable ensembling')
+ parser.add_argument('--no_overwrite', action='store_true',
+                        help='If set we will not overwrite already ensembled files etc. May speed up consecutive '
+                             'runs of this command (why would you want to do that?) at the risk of not updating '
+ 'outdated results.')
+ args = parser.parse_args()
+
+ model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p)
+ dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id)
+
+ find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling,
+ num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f,
+ strict=False)
+
+
+def accumulate_crossval_results_entry_point():
+ parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint '
+ 'folder and evaluates them')
+ parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id')
+ parser.add_argument('-c', type=str, required=True,
+ default='3d_fullres',
+ help="Configuration")
+ parser.add_argument('-o', type=str, required=False, default=None,
+ help="Output folder. If not specified, the output folder will be located in the trained " \
+ "model directory (named crossval_results_folds_XXX).")
+ parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4),
+ help='Folds to use. Default: 0 1 2 3 4')
+ parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
+ help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans')
+ parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
+ help='Trainer class. Default: nnUNetTrainer')
+ args = parser.parse_args()
+ trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c)
+
+ if args.o is None:
+ merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}')
+ else:
+ merged_output_folder = args.o
+
+ accumulate_cv_results(trained_model_folder, merged_output_folder, args.f)
+
+
+if __name__ == '__main__':
+ find_best_configuration(4,
+ default_trained_models,
+ True,
+ 8,
+ False,
+ (0, 1, 2, 3, 4))
diff --git a/nnUNet/nnunetv2/experiment_planning/__init__.py b/nnUNet/nnunetv2/experiment_planning/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py
new file mode 100644
index 0000000..8280518
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py
@@ -0,0 +1,200 @@
+import multiprocessing
+import os
+from time import sleep
+from typing import List, Type, Union
+
+import nibabel as nib
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p
+from tqdm import tqdm
+
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
+from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
+from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
+
+
+class DatasetFingerprintExtractor(object):
+ def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False):
+ """
+ extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a
+ json file in the input_folder
+
+ Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere
+ else. Don't compute stuff we don't need (except for intensity_statistics_per_channel)
+ """
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+ self.verbose = verbose
+
+ self.dataset_name = dataset_name
+ self.input_folder = join(nnUNet_raw, dataset_name)
+ self.num_processes = num_processes
+ self.dataset_json = load_json(join(self.input_folder, 'dataset.json'))
+ self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json)
+
+ # We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is
+ # also not critically important to get all pixels as long as there are enough. Let's use 10e7 voxels in total
+ # (for the entire dataset)
+ self.num_foreground_voxels_for_intensitystats = 10e7
+
+ @staticmethod
+ def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234,
+ num_samples: int = 10000):
+ """
+ images=image with multiple channels = shape (c, x, y(, z))
+ """
+ assert len(images.shape) == 4
+ assert len(segmentation.shape) == 4
+
+ assert not np.any(np.isnan(segmentation)), "Segmentation contains NaN values. grrrr.... :-("
+        assert not np.any(np.isnan(images)), "Images contain NaN values. grrrr.... :-("
+
+ rs = np.random.RandomState(seed)
+
+ intensities_per_channel = []
+ # we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have
+ intensity_statistics_per_channel = []
+
+ # segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work
+ foreground_mask = segmentation[0] > 0
+
+ for i in range(len(images)):
+ foreground_pixels = images[i][foreground_mask]
+ num_fg = len(foreground_pixels)
+ # sample with replacement so that we don't get issues with cases that have less than num_samples
+            # foreground_pixels. We could also just sample fewer in those cases but that would then cause these
+ # training cases to be underrepresented
+ intensities_per_channel.append(
+ rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else [])
+ intensity_statistics_per_channel.append({
+ 'mean': np.mean(foreground_pixels) if num_fg > 0 else np.nan,
+ 'median': np.median(foreground_pixels) if num_fg > 0 else np.nan,
+ 'min': np.min(foreground_pixels) if num_fg > 0 else np.nan,
+ 'max': np.max(foreground_pixels) if num_fg > 0 else np.nan,
+ 'percentile_99_5': np.percentile(foreground_pixels, 99.5) if num_fg > 0 else np.nan,
+ 'percentile_00_5': np.percentile(foreground_pixels, 0.5) if num_fg > 0 else np.nan,
+
+ })
+
+ return intensities_per_channel, intensity_statistics_per_channel
+
+ @staticmethod
+ def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter],
+ num_samples: int = 10000):
+ rw = reader_writer_class()
+ images, properties_images = rw.read_images(image_files)
+ segmentation, properties_seg = rw.read_seg(segmentation_file)
+
+ # we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly.
+ # Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't
+ # need to save the cropped data anymore. Given that cropping is not too expensive it makes sense to do it this
+ # way. This is only possible because we are now using our new input/output interface.
+ data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation)
+
+ foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \
+ DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped,
+ num_samples=num_samples)
+
+ spacing = properties_images['spacing']
+
+ shape_before_crop = images.shape[1:]
+ shape_after_crop = data_cropped.shape[1:]
+ relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop)
+ return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \
+ relative_size_after_cropping
+
+ def run(self, overwrite_existing: bool = False) -> dict:
+ # we do not save the properties file in self.input_folder because that folder might be read-only. We can only
+ # reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is
+ preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name)
+ maybe_mkdir_p(preprocessed_output_folder)
+ properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json')
+
+ if not isfile(properties_file) or overwrite_existing:
+ reader_writer_class = determine_reader_writer_from_dataset_json(self.dataset_json,
+ # yikes. Rip the following line
+ self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0])
+
+ # determine how many foreground voxels we need to sample per training case
+ num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats //
+ len(self.dataset))
+
+ r = []
+ with multiprocessing.get_context("spawn").Pool(self.num_processes) as p:
+ for k in self.dataset.keys():
+ r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case,
+ ((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class,
+ num_foreground_samples_per_case),)))
+ remaining = list(range(len(self.dataset)))
+ # p is pretty nifti. If we kill workers they just respawn but don't do any work.
+ # So we need to store the original pool of workers.
+ workers = [j for j in p._pool]
+ with tqdm(desc=None, total=len(self.dataset), disable=self.verbose) as pbar:
+ while len(remaining) > 0:
+ all_alive = all([j.is_alive() for j in workers])
+ if not all_alive:
+ raise RuntimeError('Some background worker is 6 feet under. Yuck. \n'
+ 'OK jokes aside.\n'
+ 'One of your background processes is missing. This could be because of '
+ 'an error (look for an error message) or because it was killed '
+ 'by your OS due to running out of RAM. If you don\'t see '
+ 'an error message, out of RAM is likely the problem. In that case '
+ 'reducing the number of workers might help')
+ done = [i for i in remaining if r[i].ready()]
+ for _ in done:
+ pbar.update()
+ remaining = [i for i in remaining if i not in done]
+ sleep(0.1)
+
+ # results = ptqdm(DatasetFingerprintExtractor.analyze_case,
+ # (training_images_per_case, training_labels_per_case),
+ # processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class,
+ # num_samples=num_foreground_samples_per_case, disable=self.verbose)
+ results = [i.get()[0] for i in r]
+
+ shapes_after_crop = [r[0] for r in results]
+ spacings = [r[1] for r in results]
+ foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in
+ range(len(results[0][2]))]
+ # we drop this so that the json file is somewhat human readable
+ # foreground_intensity_stats_by_case_and_modality = [r[3] for r in results]
+ median_relative_size_after_cropping = np.median([r[4] for r in results], 0)
+
+ num_channels = len(self.dataset_json['channel_names'].keys()
+ if 'channel_names' in self.dataset_json.keys()
+ else self.dataset_json['modality'].keys())
+ intensity_statistics_per_channel = {}
+ for i in range(num_channels):
+ intensity_statistics_per_channel[i] = {
+ 'mean': float(np.mean(foreground_intensities_per_channel[i])),
+ 'median': float(np.median(foreground_intensities_per_channel[i])),
+ 'std': float(np.std(foreground_intensities_per_channel[i])),
+ 'min': float(np.min(foreground_intensities_per_channel[i])),
+ 'max': float(np.max(foreground_intensities_per_channel[i])),
+ 'percentile_99_5': float(np.percentile(foreground_intensities_per_channel[i], 99.5)),
+ 'percentile_00_5': float(np.percentile(foreground_intensities_per_channel[i], 0.5)),
+ }
+
+ fingerprint = {
+ "spacings": spacings,
+ "shapes_after_crop": shapes_after_crop,
+ 'foreground_intensity_properties_per_channel': intensity_statistics_per_channel,
+ "median_relative_size_after_cropping": median_relative_size_after_cropping
+ }
+
+ try:
+ save_json(fingerprint, properties_file)
+ except Exception as e:
+ if isfile(properties_file):
+ os.remove(properties_file)
+ raise e
+ else:
+ fingerprint = load_json(properties_file)
+ return fingerprint
+
+
+if __name__ == '__main__':
+ dfe = DatasetFingerprintExtractor(2, 8)
+ dfe.run(overwrite_existing=False)
diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/__init__.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
new file mode 100644
index 0000000..55d841e
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py
@@ -0,0 +1,534 @@
+import os.path
+import shutil
+from copy import deepcopy
+from functools import lru_cache
+from typing import List, Union, Tuple, Type
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p
+from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
+from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm
+
+from nnunetv2.configuration import ANISO_THRESHOLD
+from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
+from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
+from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
+from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.json_export import recursive_fix_for_json_export
+from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
+ get_filenames_of_train_images_and_targets
+
+
+class ExperimentPlanner(object):
+ def __init__(self, dataset_name_or_id: Union[str, int],
+ gpu_memory_target_in_gb: float = 8,
+ preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans',
+ overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
+ suppress_transpose: bool = False):
+ """
+        overwrite_target_spacing only affects 3d_fullres (and, by extension, 3d_lowres, which is derived from
+        3d_fullres and may therefore also be affected).
+ """
+
+ self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+ self.suppress_transpose = suppress_transpose
+ self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name)
+ preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name)
+ self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json'))
+ self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json)
+
+ # load dataset fingerprint
+ if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')):
+            raise RuntimeError('Fingerprint missing for this dataset. Please run nnUNetv2_extract_fingerprint first')
+
+ self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json'))
+
+ self.anisotropy_threshold = ANISO_THRESHOLD
+
+ self.UNet_base_num_features = 32
+ self.UNet_class = PlainConvUNet
+ # the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as
+ # much as possible
+ self.UNet_reference_val_3d = 560000000 # 455600128 550000000
+ self.UNet_reference_val_2d = 85000000 # 83252480
+ self.UNet_reference_com_nfeatures = 32
+ self.UNet_reference_val_corresp_GB = 8
+ self.UNet_reference_val_corresp_bs_2d = 12
+ self.UNet_reference_val_corresp_bs_3d = 2
+ self.UNet_vram_target_GB = gpu_memory_target_in_gb
+ self.UNet_featuremap_min_edge_length = 4
+ self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
+ self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
+ self.UNet_min_batch_size = 2
+ self.UNet_max_features_2d = 512
+ self.UNet_max_features_3d = 320
+
+ self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the
+ # median shape then we need a lowres config as well
+
+ self.preprocessor_name = preprocessor_name
+ self.plans_identifier = plans_name
+ self.overwrite_target_spacing = overwrite_target_spacing
+        assert overwrite_target_spacing is None or len(overwrite_target_spacing) == 3, \
+            'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'
+ assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \
+ 'if overwrite_target_spacing is used then three floats must be given (as list or tuple)'
+
+ self.plans = None
+
+ def determine_reader_writer(self):
+        example_image = self.dataset[next(iter(self.dataset.keys()))]['images'][0]
+ return determine_reader_writer_from_dataset_json(self.dataset_json, example_image)
+
+ @staticmethod
+ @lru_cache(maxsize=None)
+ def static_estimate_VRAM_usage(patch_size: Tuple[int],
+ n_stages: int,
+ strides: Union[int, List[int], Tuple[int, ...]],
+ UNet_class: Union[Type[PlainConvUNet], Type[ResidualEncoderUNet]],
+ num_input_channels: int,
+ features_per_stage: Tuple[int],
+ blocks_per_stage_encoder: Union[int, Tuple[int]],
+ blocks_per_stage_decoder: Union[int, Tuple[int]],
+ num_labels: int):
+ """
+ Works for PlainConvUNet, ResidualEncoderUNet
+ """
+ dim = len(patch_size)
+ conv_op = convert_dim_to_conv_op(dim)
+ norm_op = get_matching_instancenorm(conv_op)
+ net = UNet_class(num_input_channels, n_stages,
+ features_per_stage,
+ conv_op,
+                         3,  # kernel_sizes: 3 along every axis
+ strides,
+ blocks_per_stage_encoder,
+ num_labels,
+ blocks_per_stage_decoder,
+ norm_op=norm_op)
+ return net.compute_conv_feature_map_size(patch_size)
+
+ def determine_resampling(self, *args, **kwargs):
+ """
+ returns what functions to use for resampling data and seg, respectively. Also returns kwargs
+ resampling function must be callable(data, current_spacing, new_spacing, **kwargs)
+
+ determine_resampling is called within get_plans_for_configuration to allow for different functions for each
+ configuration
+ """
+ resampling_data = resample_data_or_seg_to_shape
+ resampling_data_kwargs = {
+ "is_seg": False,
+ "order": 3,
+ "order_z": 0,
+ "force_separate_z": None,
+ }
+ resampling_seg = resample_data_or_seg_to_shape
+ resampling_seg_kwargs = {
+ "is_seg": True,
+ "order": 1,
+ "order_z": 0,
+ "force_separate_z": None,
+ }
+ return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs
+
+ def determine_segmentation_softmax_export_fn(self, *args, **kwargs):
+ """
+ function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be
+ used as target. current_spacing and new_spacing are merely there in case we want to use it somehow
+
+ determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different
+ functions for each configuration
+
+ """
+ resampling_fn = resample_data_or_seg_to_shape
+ resampling_fn_kwargs = {
+ "is_seg": False,
+ "order": 1,
+ "order_z": 0,
+ "force_separate_z": None,
+ }
+ return resampling_fn, resampling_fn_kwargs
+
+ def determine_fullres_target_spacing(self) -> np.ndarray:
+ """
+        By default we use the 50th percentile (the median) as the target spacing. A larger spacing results in smaller
+        images and thus faster and easier training; a smaller spacing results in larger images and thus longer and
+        harder training.
+
+ For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic
+ (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low
+ resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially
+ impact performance (due to the low number of slices).
+ """
+ if self.overwrite_target_spacing is not None:
+ return np.array(self.overwrite_target_spacing)
+
+ spacings = self.dataset_fingerprint['spacings']
+ sizes = self.dataset_fingerprint['shapes_after_crop']
+
+ target = np.percentile(np.vstack(spacings), 50, 0)
+
+ # todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)]
+
+ target_size = np.percentile(np.vstack(sizes), 50, 0)
+ # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
+ # the following properties:
+        # - one axis with a much lower resolution than the others
+        # - the lowres axis has far fewer voxels than the others
+ # - (the size in mm of the lowres axis is also reduced)
+ worst_spacing_axis = np.argmax(target)
+ other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
+ other_spacings = [target[i] for i in other_axes]
+ other_sizes = [target_size[i] for i in other_axes]
+
+ has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
+ has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
+
+ if has_aniso_spacing and has_aniso_voxels:
+ spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
+ target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
+ # don't let the spacing of that axis get higher than the other axes
+ if target_spacing_of_that_axis < max(other_spacings):
+ target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
+ target[worst_spacing_axis] = target_spacing_of_that_axis
+ return target
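+
+    # Illustrative sketch (numbers assumed, not taken from this repo): with per-case spacings clustered around
+    # (3.0, 0.6, 0.6) and an anisotropy threshold of 3, axis 0 is anisotropic in both spacing and (typically) voxel
+    # count, so its target spacing is taken from the 10th percentile of the per-case spacings along that axis
+    # (clipped so it does not drop below the largest of the other axes' spacings) instead of the median.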
+
+ def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]:
+ if 'channel_names' not in self.dataset_json.keys():
+ print('WARNING: "modalities" should be renamed to "channel_names" in dataset.json. This will be '
+ 'enforced soon!')
+ modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \
+ self.dataset_json['modality']
+ normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()]
+ if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.):
+ use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in
+ normalization_schemes]
+ else:
+ use_nonzero_mask_for_norm = [False] * len(normalization_schemes)
+ assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \
+ 'True or False and cannot be None'
+ normalization_schemes = [i.__name__ for i in normalization_schemes]
+ return normalization_schemes, use_nonzero_mask_for_norm
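+
+    # Hedged note: with nnU-Net's default channel-name mapping, a channel named 'CT' typically resolves to a
+    # CT-specific normalization (clipping to foreground percentiles plus z-scoring with dataset-wide statistics),
+    # while most other channel names fall back to per-image z-score normalization. The authoritative mapping lives
+    # in map_channel_name_to_normalization.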
+
+ def determine_transpose(self):
+ if self.suppress_transpose:
+ return [0, 1, 2], [0, 1, 2]
+
+ # todo we should use shapes for that as well. Not quite sure how yet
+ target_spacing = self.determine_fullres_target_spacing()
+
+ max_spacing_axis = np.argmax(target_spacing)
+ remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
+ transpose_forward = [max_spacing_axis] + remaining_axes
+ transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)]
+ return transpose_forward, transpose_backward
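+
+    # Example (spacing values assumed): a target spacing of (0.5, 0.5, 3.0) has its largest value on axis 2, so
+    # transpose_forward is [2, 0, 1] (lowres axis first) and transpose_backward [1, 2, 0] undoes it.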
+
+ def get_plans_for_configuration(self,
+ spacing: Union[np.ndarray, Tuple[float, ...], List[float]],
+ median_shape: Union[np.ndarray, Tuple[int, ...], List[int]],
+ data_identifier: str,
+ approximate_n_voxels_dataset: float) -> dict:
+ assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}"
+ # print(spacing, median_shape, approximate_n_voxels_dataset)
+ # find an initial patch size
+ # we first use the spacing to get an aspect ratio
+ tmp = 1 / np.array(spacing)
+
+ # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same
+ # volume as a patch of size 256 ** 3)
+ # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be
+ # ideal because large initial patch sizes increase computation time because more iterations in the while loop
+ # further down may be required.
+ if len(spacing) == 3:
+ initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)]
+ elif len(spacing) == 2:
+ initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)]
+        else:
+            raise RuntimeError(f"Unsupported patch dimensionality: {len(spacing)}. Only 2D and 3D are supported.")
+
+ # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that
+ # this is different from how nnU-Net v1 does it!
+ # todo patch size can still get too large because we pad the patch size to a multiple of 2**n
+ initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])])
+
+ # use that to get the network topology. Note that this changes the patch_size depending on the number of
+ # pooling operations (must be divisible by 2**num_pool in each axis)
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size,
+ self.UNet_featuremap_min_edge_length,
+ 999999)
+
+ # now estimate vram consumption
+ num_stages = len(pool_op_kernel_sizes)
+ estimate = self.static_estimate_VRAM_usage(tuple(patch_size),
+ num_stages,
+ tuple([tuple(i) for i in pool_op_kernel_sizes]),
+ self.UNet_class,
+ len(self.dataset_json['channel_names'].keys()
+ if 'channel_names' in self.dataset_json.keys()
+ else self.dataset_json['modality'].keys()),
+ tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else
+ self.UNet_max_features_3d,
+ self.UNet_reference_com_nfeatures * 2 ** i) for
+ i in range(len(pool_op_kernel_sizes))]),
+ self.UNet_blocks_per_stage_encoder[:num_stages],
+ self.UNet_blocks_per_stage_decoder[:num_stages - 1],
+ len(self.dataset_json['labels'].keys()))
+
+ # how large is the reference for us here (batch size etc)?
+ # adapt for our vram target
+ reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \
+ (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB)
+
+ while estimate > reference:
+ # print(patch_size)
+ # patch size seems to be too large, so we need to reduce it. Reduce the axis that currently violates the
+ # aspect ratio the most (that is the largest relative to median shape)
+ axis_to_be_reduced = np.argsort(patch_size / median_shape[:len(spacing)])[-1]
+
+ # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this
+ # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256.
+ # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size
+ # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first
+ # subtract shape_must_be_divisible_by, then recompute it and then subtract the
+ # recomputed shape_must_be_divisible_by. Annoying.
+ tmp = deepcopy(patch_size)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by = \
+ get_pool_and_conv_props(spacing, tmp,
+ self.UNet_featuremap_min_edge_length,
+ 999999)
+ patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+
+ # now recompute topology
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size,
+ self.UNet_featuremap_min_edge_length,
+ 999999)
+
+ num_stages = len(pool_op_kernel_sizes)
+ estimate = self.static_estimate_VRAM_usage(tuple(patch_size),
+ num_stages,
+ tuple([tuple(i) for i in pool_op_kernel_sizes]),
+ self.UNet_class,
+ len(self.dataset_json['channel_names'].keys()
+ if 'channel_names' in self.dataset_json.keys()
+ else self.dataset_json['modality'].keys()),
+ tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else
+ self.UNet_max_features_3d,
+ self.UNet_reference_com_nfeatures * 2 ** i) for
+ i in range(len(pool_op_kernel_sizes))]),
+ self.UNet_blocks_per_stage_encoder[:num_stages],
+ self.UNet_blocks_per_stage_decoder[:num_stages - 1],
+ len(self.dataset_json['labels'].keys()))
+
+ # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was
+ # executed. If not, additional vram headroom is used to increase batch size
+ ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d
+ batch_size = round((reference / estimate) * ref_bs)
+
+ # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. We cannot
+ # go smaller than self.UNet_min_batch_size though
+ bs_corresponding_to_5_percent = round(
+ approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64))
+ batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size)
+
+ resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling()
+ resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn()
+
+ normalization_schemes, mask_is_used_for_norm = \
+ self.determine_normalization_scheme_and_whether_mask_is_used_for_norm()
+ num_stages = len(pool_op_kernel_sizes)
+ plan = {
+ 'data_identifier': data_identifier,
+ 'preprocessor_name': self.preprocessor_name,
+ 'batch_size': batch_size,
+ 'patch_size': patch_size,
+ 'median_image_size_in_voxels': median_shape,
+ 'spacing': spacing,
+ 'normalization_schemes': normalization_schemes,
+ 'use_mask_for_norm': mask_is_used_for_norm,
+ 'UNet_class_name': self.UNet_class.__name__,
+ 'UNet_base_num_features': self.UNet_base_num_features,
+ 'n_conv_per_stage_encoder': self.UNet_blocks_per_stage_encoder[:num_stages],
+ 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1],
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ 'unet_max_num_features': self.UNet_max_features_3d if len(spacing) == 3 else self.UNet_max_features_2d,
+ 'resampling_fn_data': resampling_data.__name__,
+ 'resampling_fn_seg': resampling_seg.__name__,
+ 'resampling_fn_data_kwargs': resampling_data_kwargs,
+ 'resampling_fn_seg_kwargs': resampling_seg_kwargs,
+ 'resampling_fn_probabilities': resampling_softmax.__name__,
+ 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs,
+ }
+ return plan
+
+ def plan_experiment(self):
+ """
+ MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY
+
+        Ideally transpose_forward/backward would live in the configurations so that it could differ per
+        configuration, but that would make identifying the correct axes for 2d harder. There is surely a way around
+        that, but it is not worth the added complexity for now.
+
+        So, for now, if you want a different transpose_forward/backward you need to create a new planner. That is
+        not too hard either.
+ """
+
+ # first get transpose
+ transpose_forward, transpose_backward = self.determine_transpose()
+
+ # get fullres spacing and transpose it
+ fullres_spacing = self.determine_fullres_target_spacing()
+ fullres_spacing_transposed = fullres_spacing[transpose_forward]
+
+ # get transposed new median shape (what we would have after resampling)
+ new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in
+ zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])]
+ new_median_shape = np.median(new_shapes, 0)
+ new_median_shape_transposed = new_median_shape[transpose_forward]
+
+ approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) *
+ self.dataset_json['numTraining'])
+ # only run 3d if this is a 3d dataset
+ if new_median_shape_transposed[0] != 1:
+ plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed,
+ new_median_shape_transposed,
+ self.generate_data_identifier('3d_fullres'),
+ approximate_n_voxels_dataset)
+ # maybe add 3d_lowres as well
+ patch_size_fullres = plan_3d_fullres['patch_size']
+ median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64)
+ num_voxels_in_patch = np.prod(patch_size_fullres, dtype=np.float64)
+
+ plan_3d_lowres = None
+ lowres_spacing = deepcopy(plan_3d_fullres['spacing'])
+
+ spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation!
+
+ while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold:
+ # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they
+ # is/are similar (factor 2) to the other ax(i/e)s.
+ max_spacing = max(lowres_spacing)
+ if np.any((max_spacing / lowres_spacing) > 2):
+ lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor
+ else:
+ lowres_spacing *= spacing_increase_factor
+ median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed,
+ dtype=np.float64)
+ # print(lowres_spacing)
+ plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing,
+ [round(i) for i in plan_3d_fullres['spacing'] /
+ lowres_spacing * new_median_shape_transposed],
+ self.generate_data_identifier('3d_lowres'),
+ float(np.prod(median_num_voxels) *
+ self.dataset_json['numTraining']))
+ num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64)
+ print(f'Attempting to find 3d_lowres config. '
+ f'\nCurrent spacing: {lowres_spacing}. '
+ f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. '
+ f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}')
+ if plan_3d_lowres is not None:
+ plan_3d_lowres['batch_dice'] = False
+ plan_3d_fullres['batch_dice'] = True
+ else:
+ plan_3d_fullres['batch_dice'] = False
+ else:
+ plan_3d_fullres = None
+ plan_3d_lowres = None
+
+ # 2D configuration
+ plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:],
+ new_median_shape_transposed[1:],
+ self.generate_data_identifier('2d'), approximate_n_voxels_dataset)
+ plan_2d['batch_dice'] = True
+
+ print('2D U-Net configuration:')
+ print(plan_2d)
+ print()
+
+ # median spacing and shape, just for reference when printing the plans
+ median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward]
+ median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward]
+
+ # instead of writing all that into the plans we just copy the original file. More files, but less crowded
+ # per file.
+ shutil.copy(join(self.raw_dataset_folder, 'dataset.json'),
+ join(nnUNet_preprocessed, self.dataset_name, 'dataset.json'))
+
+        # json cannot serialize numpy scalars ("Object of type int64 is not JSON serializable"), so convert explicitly
+ plans = {
+ 'dataset_name': self.dataset_name,
+ 'plans_name': self.plans_identifier,
+ 'original_median_spacing_after_transp': [float(i) for i in median_spacing],
+ 'original_median_shape_after_transp': [int(round(i)) for i in median_shape],
+ 'image_reader_writer': self.determine_reader_writer().__name__,
+ 'transpose_forward': [int(i) for i in transpose_forward],
+ 'transpose_backward': [int(i) for i in transpose_backward],
+ 'configurations': {'2d': plan_2d},
+ 'experiment_planner_used': self.__class__.__name__,
+ 'label_manager': 'LabelManager',
+ 'foreground_intensity_properties_per_channel': self.dataset_fingerprint[
+ 'foreground_intensity_properties_per_channel']
+ }
+
+ if plan_3d_lowres is not None:
+ plans['configurations']['3d_lowres'] = plan_3d_lowres
+ if plan_3d_fullres is not None:
+ plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres'
+ print('3D lowres U-Net configuration:')
+ print(plan_3d_lowres)
+ print()
+ if plan_3d_fullres is not None:
+ plans['configurations']['3d_fullres'] = plan_3d_fullres
+ print('3D fullres U-Net configuration:')
+ print(plan_3d_fullres)
+ print()
+ if plan_3d_lowres is not None:
+ plans['configurations']['3d_cascade_fullres'] = {
+ 'inherits_from': '3d_fullres',
+ 'previous_stage': '3d_lowres'
+ }
+
+ self.plans = plans
+ self.save_plans(plans)
+ return plans
+
+ def save_plans(self, plans):
+ recursive_fix_for_json_export(plans)
+
+ plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')
+
+ # we don't want to overwrite potentially existing custom configurations every time this is executed. So let's
+ # read the plans file if it already exists and keep any non-default configurations
+ if isfile(plans_file):
+ old_plans = load_json(plans_file)
+ old_configurations = old_plans['configurations']
+ for c in plans['configurations'].keys():
+ if c in old_configurations.keys():
+ del (old_configurations[c])
+ plans['configurations'].update(old_configurations)
+
+ maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name))
+ save_json(plans, plans_file, sort_keys=False)
+ print('Plans were saved to %s' % join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json'))
+
+ def generate_data_identifier(self, configuration_name: str) -> str:
+ """
+        configurations are unique within each plans file but different plans files can have configurations with the
+        same name. In order to distinguish the associated data we need a data identifier that reflects not just the
+        config but also the plans it originates from
+ """
+ return self.plans_identifier + '_' + configuration_name
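+
+    # e.g. with the default plans_identifier 'nnUNetPlans' and configuration '3d_fullres' this returns
+    # 'nnUNetPlans_3d_fullres'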
+
+ def load_plans(self, fname: str):
+ self.plans = load_json(fname)
+
+
+if __name__ == '__main__':
+ ExperimentPlanner(2, 8).plan_experiment()
diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py
new file mode 100644
index 0000000..8b75b46
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py
@@ -0,0 +1,105 @@
+from copy import deepcopy
+import numpy as np
+
+
+def get_shape_must_be_divisible_by(net_numpool_per_axis):
+ return 2 ** np.array(net_numpool_per_axis)
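+    # e.g. net_numpool_per_axis = [5, 5, 3] -> array([32, 32, 8]): each axis must be divisible by 2**num_pool
+    # along that axis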
+
+
+def pad_shape(shape, must_be_divisible_by):
+ """
+ pads shape so that it is divisible by must_be_divisible_by
+ :param shape:
+ :param must_be_divisible_by:
+ :return:
+ """
+ if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
+ must_be_divisible_by = [must_be_divisible_by] * len(shape)
+ else:
+ assert len(must_be_divisible_by) == len(shape)
+
+ new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]
+
+ for i in range(len(shape)):
+ if shape[i] % must_be_divisible_by[i] == 0:
+ new_shp[i] -= must_be_divisible_by[i]
+ new_shp = np.array(new_shp).astype(int)
+ return new_shp
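+    # Worked example: pad_shape([40, 223, 256], [8, 8, 8]) -> array([40, 224, 256]); axes that are already
+    # divisible are left unchanged, 223 is padded up to the next multiple of 8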
+
+
+def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
+ """
+ this is the same as get_pool_and_conv_props_v2 from old nnunet
+
+ :param spacing:
+ :param patch_size:
+ :param min_feature_map_size: min edge length of feature maps in bottleneck
+ :param max_numpool:
+ :return:
+ """
+ # todo review this code
+ dim = len(spacing)
+
+ current_spacing = deepcopy(list(spacing))
+ current_size = deepcopy(list(patch_size))
+
+ pool_op_kernel_sizes = [[1] * len(spacing)]
+ conv_kernel_sizes = []
+
+ num_pool_per_axis = [0] * dim
+ kernel_size = [1] * dim
+
+ while True:
+ # exclude axes that we cannot pool further because of min_feature_map_size constraint
+ valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
+ if len(valid_axes_for_pool) < 1:
+ break
+
+ spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
+
+        # find axes whose spacing is within a factor of 2 of the smallest spacing
+ min_spacing_of_valid = min(spacings_of_axes)
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
+
+ # max_numpool constraint
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
+
+ if len(valid_axes_for_pool) == 1:
+ if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
+ pass
+ else:
+ break
+ if len(valid_axes_for_pool) < 1:
+ break
+
+ # now we need to find kernel sizes
+ # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
+ # factor 2 of min_spacing. Once they are 3 they remain 3
+        for d in range(dim):
+            if kernel_size[d] == 3:
+                continue
+            else:
+                # index current_spacing by axis: spacings_of_axes only covers the axes still eligible for pooling
+                if current_spacing[d] / min(current_spacing) < 2:
+                    kernel_size[d] = 3
+
+ other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
+
+ pool_kernel_sizes = [0] * dim
+ for v in valid_axes_for_pool:
+ pool_kernel_sizes[v] = 2
+ num_pool_per_axis[v] += 1
+ current_spacing[v] *= 2
+ current_size[v] = np.ceil(current_size[v] / 2)
+ for nv in other_axes:
+ pool_kernel_sizes[nv] = 1
+
+ pool_op_kernel_sizes.append(pool_kernel_sizes)
+ conv_kernel_sizes.append(deepcopy(kernel_size))
+ #print(conv_kernel_sizes)
+
+ must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
+ patch_size = pad_shape(patch_size, must_be_divisible_by)
+
+ # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
+ conv_kernel_sizes.append([3]*dim)
+ return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
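+
+
+# Illustrative usage (spacing and patch size are assumed values):
+#   get_pool_and_conv_props([3.0, 0.5, 0.5], [40, 224, 224], 4, 999999)
+# pools the in-plane axes more often than the low-resolution axis 0, keeps 1x3x3 convolutions for axis 0 until its
+# (virtually pooled) spacing is within a factor of 2 of the others, and returns a patch size padded per axis to a
+# multiple of 2**num_pool.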
diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md b/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md
new file mode 100644
index 0000000..e2e4e18
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md
@@ -0,0 +1,38 @@
+What experiment planners need to do (these are notes to myself from rewriting nnU-Net; they are provided as is,
+without further explanation, and also include new features):
+- (done) preprocessor name should be configurable via cli
+- (done) gpu memory target should be configurable via cli
+- (done) plans name should be configurable via cli
+- (done) data name should be specified in plans (plans specify the data they want to use, this will allow us to manually
+ edit plans files without having to copy the data folders)
+- plans must contain:
+ - (done) transpose forward/backward
+ - (done) preprocessor name (can differ for each config)
+ - (done) spacing
+ - (done) normalization scheme
+ - (done) target spacing
+ - (done) conv and pool op kernel sizes
+ - (done) base num features for architecture
+ - (done) data identifier
+ - num conv per stage?
+ - (done) use mask for norm
+ - [NO. Handled by LabelManager & dataset.json] num segmentation outputs
+ - [NO. Handled by LabelManager & dataset.json] ignore class
+ - [NO. Handled by LabelManager & dataset.json] list of regions or classes
+ - [NO. Handled by LabelManager & dataset.json] regions class order, if applicable
+ - (done) resampling function to be used
+ - (done) the image reader writer class that should be used
+
+
+dataset.json
+mandatory:
+- numTraining
+- labels (value 'ignore' has special meaning. Cannot have more than one ignore_label)
+- channel_names (formerly 'modalities')
+- file_ending
+
+optional:
+- overwrite_image_reader_writer (if absent, auto)
+- regions
+- region_class_order
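+
+A minimal illustrative dataset.json sketch (all values below are made up, not taken from this repository):
+
+    {
+        "channel_names": {"0": "CT", "1": "MR"},
+        "labels": {"background": 0, "organ_1": 1},
+        "numTraining": 42,
+        "file_ending": ".nii.gz"
+    }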
diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py
new file mode 100644
index 0000000..52ca938
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py
@@ -0,0 +1,54 @@
+from typing import Union, List, Tuple
+
+from torch import nn
+
+from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
+from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
+
+
+class ResEncUNetPlanner(ExperimentPlanner):
+ def __init__(self, dataset_name_or_id: Union[str, int],
+ gpu_memory_target_in_gb: float = 8,
+ preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans',
+ overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None,
+ suppress_transpose: bool = False):
+ super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name,
+ overwrite_target_spacing, suppress_transpose)
+
+ self.UNet_base_num_features = 32
+ self.UNet_class = ResidualEncoderUNet
+ # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as
+ # much as possible
+ self.UNet_reference_val_3d = 680000000
+ self.UNet_reference_val_2d = 135000000
+ self.UNet_reference_com_nfeatures = 32
+ self.UNet_reference_val_corresp_GB = 8
+ self.UNet_reference_val_corresp_bs_2d = 12
+ self.UNet_reference_val_corresp_bs_3d = 2
+ self.UNet_featuremap_min_edge_length = 4
+ self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
+ self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+ self.UNet_min_batch_size = 2
+ self.UNet_max_features_2d = 512
+ self.UNet_max_features_3d = 320
+
+
+if __name__ == '__main__':
+ # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively
+ net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320),
+ conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2),
+ n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3,
+ n_conv_per_stage_decoder=(1, 1, 1, 1, 1),
+ conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
+ print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. The value you see above was finetuned
+ # from this one to match the regular nnunetplans more closely
+
+ net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512),
+ conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2),
+ n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3,
+ n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1),
+ conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True)
+ print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792
+
diff --git a/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py
new file mode 100644
index 0000000..eb94840
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py
@@ -0,0 +1,138 @@
+import shutil
+from typing import List, Type, Optional, Tuple, Union
+
+import nnunetv2
+from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, subfiles, load_json
+
+from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor
+from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
+from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity
+from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
+from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name, maybe_convert_to_dataset_name
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
+
+
+def extract_fingerprint_dataset(dataset_id: int,
+ fingerprint_extractor_class: Type[
+ DatasetFingerprintExtractor] = DatasetFingerprintExtractor,
+ num_processes: int = default_num_processes, check_dataset_integrity: bool = False,
+ clean: bool = True, verbose: bool = True):
+ """
+    Returns the fingerprint as a dictionary (in addition to saving it)
+ """
+ dataset_name = convert_id_to_dataset_name(dataset_id)
+ print(dataset_name)
+
+ if check_dataset_integrity:
+ verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes)
+
+ fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose)
+ return fpe.run(overwrite_existing=clean)
+
+
+def extract_fingerprints(dataset_ids: List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor',
+ num_processes: int = default_num_processes, check_dataset_integrity: bool = False,
+ clean: bool = True, verbose: bool = True):
+ """
+    If clean=False and a fingerprint already exists, extraction is skipped for that dataset. This is a switch for use
+    with nnUNetv2_plan_and_preprocess so that fingerprint extraction is not rerun every time.
+ """
+ fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
+ fingerprint_extractor_class_name,
+ current_module="nnunetv2.experiment_planning")
+ for d in dataset_ids:
+ extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean,
+ verbose)
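+
+    # Hedged usage sketch (the dataset id is an assumption):
+    #   extract_fingerprints([2], num_processes=8, check_dataset_integrity=True)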
+
+
+def plan_experiment_dataset(dataset_id: int,
+ experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner,
+ gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor',
+ overwrite_target_spacing: Optional[Tuple[float, ...]] = None,
+ overwrite_plans_name: Optional[str] = None) -> dict:
+ """
+    overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!
+ """
+ kwargs = {}
+ if overwrite_plans_name is not None:
+ kwargs['plans_name'] = overwrite_plans_name
+ return experiment_planner_class(dataset_id,
+ gpu_memory_target_in_gb=gpu_memory_target_in_gb,
+ preprocessor_name=preprocess_class_name,
+ overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if
+ overwrite_target_spacing is not None else overwrite_target_spacing,
+ suppress_transpose=False, # might expose this later,
+ **kwargs
+ ).plan_experiment()
+
+
+def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner',
+ gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor',
+ overwrite_target_spacing: Optional[Tuple[float, ...]] = None,
+ overwrite_plans_name: Optional[str] = None):
+ """
+    overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade_fullres!
+ """
+ experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
+ experiment_planner_class_name,
+ current_module="nnunetv2.experiment_planning")
+ for d in dataset_ids:
+ plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name,
+ overwrite_target_spacing, overwrite_plans_name)
+
+
+def preprocess_dataset(dataset_id: int,
+ plans_identifier: str = 'nnUNetPlans',
+ configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),
+ num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),
+ verbose: bool = False) -> None:
+    if isinstance(num_processes, int):
+        num_processes = [num_processes]
+    else:
+        num_processes = list(num_processes)
+ if len(num_processes) == 1:
+ num_processes = num_processes * len(configurations)
+ if len(num_processes) != len(configurations):
+ raise RuntimeError(
+ f'The list provided with num_processes must either have len 1 or as many elements as there are '
+ f'configurations (see --help). Number of configurations: {len(configurations)}, length '
+ f'of num_processes: '
+ f'{len(num_processes)}')
+
+ dataset_name = convert_id_to_dataset_name(dataset_id)
+ print(f'Preprocessing dataset {dataset_name}')
+ plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')
+ plans_manager = PlansManager(plans_file)
+ for n, c in zip(num_processes, configurations):
+ print(f'Configuration: {c}...')
+ if c not in plans_manager.available_configurations:
+ print(
+ f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of "
+ f"dataset {dataset_name}. Skipping.")
+ continue
+ configuration_manager = plans_manager.get_configuration(c)
+ preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+ preprocessor.run(dataset_id, c, plans_identifier, num_processes=n)
+
+ # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no
+ # longer there (useful for compute cluster where only the preprocessed data is available)
+ from distutils.file_util import copy_file
+ maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'))
+ dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json'))
+ dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)
+ # only copy files that are newer than the ones already present
+ for k in dataset:
+ copy_file(dataset[k]['label'],
+ join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending']),
+ update=True)
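+
+    # Hedged usage sketch (dataset id and configuration are assumptions):
+    #   preprocess_dataset(2, 'nnUNetPlans', configurations=('3d_fullres',), num_processes=(4,))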
+
+
+
+def preprocess(dataset_ids: List[int],
+ plans_identifier: str = 'nnUNetPlans',
+ configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'),
+ num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8),
+ verbose: bool = False):
+ for d in dataset_ids:
+ preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose)
diff --git a/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py
new file mode 100644
index 0000000..a600653
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py
@@ -0,0 +1,201 @@
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess
+
+
+def extract_fingerprint_entry():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', nargs='+', type=int,
+                        help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction "
+                             "for these datasets. Can of course also be just one dataset")
+ parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',
+ help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '
+ '\'DatasetFingerprintExtractor\'.')
+ parser.add_argument('-np', type=int, default=default_num_processes, required=False,
+ help=f'[OPTIONAL] Number of processes used for fingerprint extraction. '
+ f'Default: {default_num_processes}')
+ parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
+ help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for "
+ "each dataset!")
+ parser.add_argument("--clean", required=False, default=False, action="store_true",
+ help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '
+ 'fingerprint already exists, the fingerprint extractor will not run.')
+ parser.add_argument('--verbose', required=False, action='store_true',
+                        help='Set this to print a lot of stuff. Useful for debugging. Will disable the progress bar! '
+ 'Recommended for cluster environments')
+ args, unrecognized_args = parser.parse_known_args()
+ extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose)
+
+
+def plan_experiment_entry():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', nargs='+', type=int,
+                        help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run experiment planning "
+                             "for these datasets. Can of course also be just one dataset")
+ parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,
+ help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '
+ '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. '
+ 'It\'s an all in one solution now. Wuch. Such amazing.')
+ parser.add_argument('-gpu_memory_target', default=8, type=float, required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will '
+ 'affect patch and batch size and will '
+ 'definitely affect your models performance! Only use this if you really know what you '
+ 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
+ parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '
+ 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your '
+ 'models performance! Only use this if you really know what you '
+ 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
+ parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '
+ 'configurations. Default: None [no changes]. Changing this will affect image size and '
+ 'potentially patch and batch '
+ 'size. This will definitely affect your models performance! Only use this if you really '
+ 'know what you are doing and NEVER use this without running the default nnU-Net first '
+ '(as a baseline). Changing the target spacing for the other configurations is currently '
+ 'not implemented. New target spacing must be a list of three numbers!')
+ parser.add_argument('-overwrite_plans_name', default=None, required=False,
+ help='[OPTIONAL] DANGER ZONE! If you used -gpu_memory_target, -preprocessor_name or '
+ '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '
+ 'differently named plans file such that the nnunet default plans are not '
+ 'overwritten. You will then need to specify your custom plans file with -p whenever '
+ 'running other nnunet commands (training, inference etc)')
+ args, unrecognized_args = parser.parse_known_args()
+ plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing,
+ args.overwrite_plans_name)
+
+
+def preprocess_entry():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', nargs='+', type=int,
+                        help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run preprocessing "
+                             "for these datasets. Can of course also be just one dataset")
+ parser.add_argument('-plans_name', default='nnUNetPlans', required=False,
+ help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated')
+ parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',
+                        help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '
+                             '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '
+                             'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')
+ parser.add_argument('-np', type=int, nargs='+', default=[8, 4, 8], required=False,
+ help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then "
+ "this number of processes is used for all configurations specified with -c. If it's a "
+ "list of numbers this list must have as many elements as there are configurations. We "
+                             "then iterate over zip(configs, num_processes) to determine the number of processes "
+                             "used for each configuration. More processes are always faster (up to the number of "
+                             "threads your PC can support, so 8 for a 4-core CPU with hyperthreading). If you don't "
+                             "know what that is then don't touch it, or at least don't increase it! DANGER: More "
+                             "often than not the number of processes that can be used is limited by the amount of "
+                             "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND "
+                             "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 "
+ "for 3d_fullres, 8 for 3d_lowres and 4 for everything else")
+ parser.add_argument('--verbose', required=False, action='store_true',
+                        help='Set this to print a lot of stuff. Useful for debugging. Will disable the progress bar! '
+ 'Recommended for cluster environments')
+ args, unrecognized_args = parser.parse_known_args()
+    if args.np is None:
+        default_np = {
+            '2d': 8,
+            '3d_lowres': 8,
+            '3d_fullres': 4
+        }
+        np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]
+ else:
+ np = args.np
+ preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose)
+
+
+def plan_and_preprocess_entry():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', nargs='+', type=int,
+ help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment "
+ "planning and preprocessing for these datasets. Can of course also be just one dataset")
+ parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor',
+ help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is '
+ '\'DatasetFingerprintExtractor\'.')
+ parser.add_argument('-npfp', type=int, default=8, required=False,
+ help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8')
+ parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
+ help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for "
+ "each dataset!")
+ parser.add_argument('--no_pp', default=False, action='store_true', required=False,
+ help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no '
+                             'preprocessing). Useful for debugging.')
+ parser.add_argument("--clean", required=False, default=False, action="store_true",
+ help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a '
+ 'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU '
+ 'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!')
+ parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False,
+ help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is '
+ '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. '
+ 'It\'s an all in one solution now. Wuch. Such amazing.')
+ parser.add_argument('-gpu_memory_target', default=8, type=int, required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will '
+ 'affect patch and batch size and will '
+ 'definitely affect your models performance! Only use this if you really know what you '
+ 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
+ parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in '
+ 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your '
+ 'models performance! Only use this if you really know what you '
+ 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).')
+ parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False,
+ help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres '
+ 'configurations. Default: None [no changes]. Changing this will affect image size and '
+ 'potentially patch and batch '
+ 'size. This will definitely affect your models performance! Only use this if you really '
+ 'know what you are doing and NEVER use this without running the default nnU-Net first '
+ '(as a baseline). Changing the target spacing for the other configurations is currently '
+ 'not implemented. New target spacing must be a list of three numbers!')
+ parser.add_argument('-overwrite_plans_name', default='nnUNetPlans', required=False,
+                        help='[OPTIONAL] Use a custom plans identifier. If you used -gpu_memory_target, '
+ '-preprocessor_name or '
+ '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a '
+ 'differently named plans file such that the nnunet default plans are not '
+ 'overwritten. You will then need to specify your custom plans file with -p whenever '
+ 'running other nnunet commands (training, inference etc)')
+ parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+',
+                        help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3d_fullres '
+                             '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data '
+                             'from 3d_fullres. Configurations that do not exist for some dataset will be skipped.')
+ parser.add_argument('-np', type=int, nargs='+', default=None, required=False,
+ help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then "
+ "this number of processes is used for all configurations specified with -c. If it's a "
+ "list of numbers this list must have as many elements as there are configurations. We "
+                             "then iterate over zip(configs, num_processes) to determine the number of processes "
+                             "used for each configuration. More processes are always faster (up to the number of "
+                             "threads your PC can support, so 8 for a 4-core CPU with hyperthreading). If you don't "
+                             "know what that is then don't touch it, or at least don't increase it! DANGER: More "
+                             "often than not the number of processes that can be used is limited by the amount of "
+                             "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND "
+                             "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH! Default: 8 processes for 2d, 4 "
+ "for 3d_fullres, 8 for 3d_lowres and 4 for everything else")
+ parser.add_argument('--verbose', required=False, action='store_true',
+                        help='Set this to print a lot of stuff. Useful for debugging. Will disable the progress bar! '
+ 'Recommended for cluster environments')
+ args = parser.parse_args()
+
+ # fingerprint extraction
+ print("Fingerprint extraction...")
+ extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose)
+
+ # experiment planning
+ print('Experiment planning...')
+ plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name)
+
+ # manage default np
+ if args.np is None:
+ default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8}
+ np = [default_np[c] if c in default_np.keys() else 4 for c in args.c]
+ else:
+ np = args.np
+ # preprocessing
+ if not args.no_pp:
+ print('Preprocessing...')
+ preprocess(args.d, args.overwrite_plans_name, args.c, np, args.verbose)
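+
+    # Hedged CLI sketch (console-script name as registered by nnU-Net v2; the dataset id is an assumption):
+    #   nnUNetv2_plan_and_preprocess -d 2 --verify_dataset_integrity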
+
+
+if __name__ == '__main__':
+ plan_and_preprocess_entry()
diff --git a/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py
new file mode 100644
index 0000000..d78b689
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py
@@ -0,0 +1,79 @@
+import argparse
+from typing import Union
+
+from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, subfiles, save_json
+
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
+from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name
+from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
+
+
+def move_plans_between_datasets(
+ source_dataset_name_or_id: Union[int, str],
+ target_dataset_name_or_id: Union[int, str],
+ source_plans_identifier: str,
+ target_plans_identifier: str = None):
+ source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id)
+ target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id)
+
+ if target_plans_identifier is None:
+ target_plans_identifier = source_plans_identifier
+
+ source_folder = join(nnUNet_preprocessed, source_dataset_name)
+ assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \
+ f"Run nnUNetv2_plan_and_preprocess for source dataset first!"
+
+ source_plans_file = join(source_folder, source_plans_identifier + '.json')
+ assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! " \
+ f"Expected file: {source_plans_file}"
+
+ source_plans = load_json(source_plans_file)
+ source_plans['dataset_name'] = target_dataset_name
+
+ # we need to change data_identifier to use target_plans_identifier
+ if target_plans_identifier != source_plans_identifier:
+ for c in source_plans['configurations'].keys():
+ old_identifier = source_plans['configurations'][c]["data_identifier"]
+ if old_identifier.startswith(source_plans_identifier):
+ new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):]
+ else:
+ new_identifier = target_plans_identifier + '_' + old_identifier
+ source_plans['configurations'][c]["data_identifier"] = new_identifier
+
+ # we need to change the reader writer class!
+ target_raw_data_dir = join(nnUNet_raw, target_dataset_name)
+ target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json'))
+
+ # we may need to change the reader/writer
+ # pick any file from the source dataset
+ dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json)
+    example_image = dataset[next(iter(dataset.keys()))]['images'][0]
+ rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True,
+ verbose=False)
+
+ source_plans["image_reader_writer"] = rw.__name__
+
+ save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'),
+ sort_keys=False)
+
+
+def entry_point_move_plans_between_datasets():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-s', type=str, required=True,
+ help='Source dataset name or id')
+ parser.add_argument('-t', type=str, required=True,
+ help='Target dataset name or id')
+ parser.add_argument('-sp', type=str, required=True,
+ help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the '
+ 'identifier would be nnUNetPlans')
+ parser.add_argument('-tp', type=str, required=False, default=None,
+ help='Target plans identifier. Default is None meaning the source plans identifier will '
+ 'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier '
+ 'such as nnUNetPlans!!!')
+ args = parser.parse_args()
+ move_plans_between_datasets(args.s, args.t, args.sp, args.tp)
+
+
+if __name__ == '__main__':
+ move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2')
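For orientation, a hedged usage sketch of the pretraining workflow this file enables. The dataset IDs and plan identifiers below are placeholders, and the follow-up preprocessing command is the usual nnU-Net step rather than something this module performs:

    from nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets import \
        move_plans_between_datasets

    # Copy the plans of the pretraining source (Dataset002) into the fine-tuning target (Dataset004),
    # renaming them so they do not overwrite the target's own default nnUNetPlans.
    move_plans_between_datasets(
        source_dataset_name_or_id=2,
        target_dataset_name_or_id=4,
        source_plans_identifier='nnUNetPlans',
        target_plans_identifier='nnUNetPlansFrom2',
    )

    # The target dataset can then be preprocessed with the transferred plans, e.g. (placeholder command):
    #   nnUNetv2_preprocess -d 4 -plans_name nnUNetPlansFrom2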
diff --git a/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py b/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py
new file mode 100644
index 0000000..502611c
--- /dev/null
+++ b/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py
@@ -0,0 +1,234 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+import re
+from multiprocessing import Pool
+from typing import Type
+
+import numpy as np
+import pandas as pd
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.label_handling.label_handling import LabelManager
+from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
+ get_filenames_of_train_images_and_targets
+
+
+def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool:
+ rw = readerclass()
+ seg, properties = rw.read_seg(label_file)
+ found_labels = np.sort(pd.unique(seg.ravel())) # np.unique(seg)
+ unexpected_labels = [i for i in found_labels if i not in expected_labels]
+    if len(found_labels) == 1 and found_labels[0] == 0:
+ print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, '
+ 'up to you.' % label_file)
+ if len(unexpected_labels) > 0:
+ print("Error: Unexpected labels found in file %s.\nExpected: %s\nFound: %s" % (label_file, expected_labels,
+ found_labels))
+ return False
+ return True
+
+
+def check_cases(image_files: List[str], label_file: str, expected_num_channels: int,
+ readerclass: Type[BaseReaderWriter]) -> bool:
+ rw = readerclass()
+ ret = True
+
+ images, properties_image = rw.read_images(image_files)
+ segmentation, properties_seg = rw.read_seg(label_file)
+
+ # check for nans
+ if np.any(np.isnan(images)):
+ print(f'Images contain NaN pixel values. You need to fix that by '
+ f'replacing NaN values with something that makes sense for your images!\nImages:\n{image_files}')
+ ret = False
+ if np.any(np.isnan(segmentation)):
+ print(f'Segmentation contains NaN pixel values. You need to fix that.\nSegmentation:\n{label_file}')
+ ret = False
+
+ # check shapes
+ shape_image = images.shape[1:]
+ shape_seg = segmentation.shape[1:]
+ if not all([i == j for i, j in zip(shape_image, shape_seg)]):
+ print('Error: Shape mismatch between segmentation and corresponding images. \nShape images: %s. '
+ '\nShape seg: %s. \nImage files: %s. \nSeg file: %s\n' %
+ (shape_image, shape_seg, image_files, label_file))
+ ret = False
+
+ # check spacings
+ spacing_images = properties_image['spacing']
+ spacing_seg = properties_seg['spacing']
+ if not np.allclose(spacing_seg, spacing_images):
+        print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. '
+              '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' %
+              (spacing_images, spacing_seg, image_files, label_file))
+ ret = False
+
+ # check modalities
+ if not len(images) == expected_num_channels:
+ print('Error: Unexpected number of modalities. \nExpected: %d. \nGot: %d. \nImages: %s\n'
+ % (expected_num_channels, len(images), image_files))
+ ret = False
+
+ # nibabel checks
+ if 'nibabel_stuff' in properties_image.keys():
+ # this image was read with NibabelIO
+ affine_image = properties_image['nibabel_stuff']['original_affine']
+ affine_seg = properties_seg['nibabel_stuff']['original_affine']
+ if not np.allclose(affine_image, affine_seg):
+ print('WARNING: Affine is not the same for image and seg! \nAffine image: %s \nAffine seg: %s\n'
+ 'Image files: %s. \nSeg file: %s.\nThis can be a problem but doesn\'t have to be. Please run '
+ 'nnUNet_plot_dataset_pngs to verify if everything is OK!\n'
+ % (affine_image, affine_seg, image_files, label_file))
+
+ # sitk checks
+ if 'sitk_stuff' in properties_image.keys():
+ # this image was read with SimpleITKIO
+ # spacing has already been checked, only check direction and origin
+ origin_image = properties_image['sitk_stuff']['origin']
+ origin_seg = properties_seg['sitk_stuff']['origin']
+ if not np.allclose(origin_image, origin_seg):
+ print('Warning: Origin mismatch between segmentation and corresponding images. \nOrigin images: %s. '
+ '\nOrigin seg: %s. \nImage files: %s. \nSeg file: %s\n' %
+ (origin_image, origin_seg, image_files, label_file))
+ direction_image = properties_image['sitk_stuff']['direction']
+ direction_seg = properties_seg['sitk_stuff']['direction']
+ if not np.allclose(direction_image, direction_seg):
+ print('Warning: Direction mismatch between segmentation and corresponding images. \nDirection images: %s. '
+ '\nDirection seg: %s. \nImage files: %s. \nSeg file: %s\n' %
+ (direction_image, direction_seg, image_files, label_file))
+
+ return ret
+
+
+def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None:
+ """
+ folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json
+ checks if the expected number of training cases and labels are present
+ for each case, if possible, checks whether the pixel grids are aligned
+ checks whether the labels really only contain values they should
+ :param folder:
+ :return:
+ """
+ assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder
+ dataset_json = load_json(join(folder, "dataset.json"))
+
+ if not 'dataset' in dataset_json.keys():
+ assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder
+ assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder
+
+ # make sure all required keys are there
+ dataset_keys = list(dataset_json.keys())
+ required_keys = ['labels', "channel_names", "numTraining", "file_ending"]
+ assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' \
+ '\n\nRequired: \n%s\n\nPresent: \n%s\n\nMissing: ' \
+ '\n%s\n\nUnused by nnU-Net:\n%s' % \
+ (str(required_keys),
+ str(dataset_keys),
+ str([i for i in required_keys if i not in dataset_keys]),
+ str([i for i in dataset_keys if i not in required_keys]))
+
+ expected_num_training = dataset_json['numTraining']
+ num_modalities = len(dataset_json['channel_names'].keys()
+ if 'channel_names' in dataset_json.keys()
+ else dataset_json['modality'].keys())
+ file_ending = dataset_json['file_ending']
+
+ dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)
+
+ # check if the right number of training cases is present
+ assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \
+ '(%d). Found %d instead.\nExamples: %s' % \
+ (expected_num_training, len(dataset),
+ list(dataset.keys())[:5])
+
+ # check if corresponding labels are present
+ if 'dataset' in dataset_json.keys():
+ # just check if everything is there
+ ok = True
+ missing_images = []
+ missing_labels = []
+ for k in dataset:
+ for i in dataset[k]['images']:
+ if not isfile(i):
+ missing_images.append(i)
+ ok = False
+ if not isfile(dataset[k]['label']):
+ missing_labels.append(dataset[k]['label'])
+ ok = False
+ if not ok:
+ raise FileNotFoundError(f"Some expeted files were missing. Make sure you are properly referencing them "
+ f"in the dataset.json. Or use imagesTr & labelsTr folders!\nMissing images:"
+ f"\n{missing_images}\n\nMissing labels:\n{missing_labels}")
+ else:
+        # old code path that uses the imagesTr and labelsTr folders
+ labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False)
+ label_identifiers = [i[:-len(file_ending)] for i in labelfiles]
+ labels_present = [i in label_identifiers for i in dataset.keys()]
+ missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]]
+ assert all(labels_present), 'not all training cases have a label file in labelsTr. Fix that. Missing: %s' % missing
+
+ labelfiles = [v['label'] for v in dataset.values()]
+ image_files = [v['images'] for v in dataset.values()]
+
+ # no plans exist yet, so we can't use PlansManager and gotta roll with the default. It's unlikely to cause
+ # problems anyway
+ label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order'))
+ expected_labels = label_manager.all_labels
+ if label_manager.has_ignore_label:
+ expected_labels.append(label_manager.ignore_label)
+ labels_valid_consecutive = np.ediff1d(expected_labels) == 1
+ assert all(
+ labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'
+
+ # determine reader/writer class
+    reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, dataset[next(iter(dataset.keys()))]['images'][0])
+
+ # check whether only the desired labels are present
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ result = p.starmap(
+ verify_labels,
+ zip([join(folder, 'labelsTr', i) for i in labelfiles], [reader_writer_class] * len(labelfiles),
+ [expected_labels] * len(labelfiles))
+ )
+ if not all(result):
+ raise RuntimeError(
+ 'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).')
+
+ # check whether shapes and spacings match between images and labels
+ result = p.starmap(
+ check_cases,
+ zip(image_files, labelfiles, [num_modalities] * expected_num_training,
+ [reader_writer_class] * expected_num_training)
+ )
+ if not all(result):
+ raise RuntimeError(
+ 'Some images have errors. Please check text output above to see which one(s) and what\'s going on.')
+
+ # check for nans
+ # check all same orientation nibabel
+ print('\n####################')
+ print('verify_dataset_integrity Done. \nIf you didn\'t see any error messages then your dataset is most likely OK!')
+ print('####################\n')
+
+
+if __name__ == "__main__":
+ # investigate geometry issues
+ example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0')
+ num_processes = 6
+ verify_dataset_integrity(example_folder, num_processes)
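A minimal, hedged example of invoking the check directly on a raw dataset; the dataset name is a placeholder and the function raises if anything is inconsistent:

    from batchgenerators.utilities.file_and_folder_operations import join
    from nnunetv2.paths import nnUNet_raw
    from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity

    # Checks dataset.json, presence of images/labels, label values, and image/label geometry.
    verify_dataset_integrity(join(nnUNet_raw, 'Dataset001_Example'), num_processes=4)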
diff --git a/nnUNet/nnunetv2/imageio/__init__.py b/nnUNet/nnunetv2/imageio/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/imageio/base_reader_writer.py b/nnUNet/nnunetv2/imageio/base_reader_writer.py
new file mode 100644
index 0000000..d71226f
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/base_reader_writer.py
@@ -0,0 +1,113 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Tuple, Union, List
+import numpy as np
+
+
+class BaseReaderWriter(ABC):
+ @staticmethod
+ def _check_all_same(input_list):
+ # compare all entries to the first
+ for i in input_list[1:]:
+ if not len(i) == len(input_list[0]):
+ return False
+ all_same = all(i[j] == input_list[0][j] for j in range(len(i)))
+ if not all_same:
+ return False
+ return True
+
+ @staticmethod
+ def _check_all_same_array(input_list):
+ # compare all entries to the first
+ for i in input_list[1:]:
+ if not all([a == b for a, b in zip(i.shape, input_list[0].shape)]):
+ return False
+ all_same = np.allclose(i, input_list[0])
+ if not all_same:
+ return False
+ return True
+
+ @abstractmethod
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ """
+ Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. The 4d array must have the
+ modalities (or color channels, or however you would like to call them) in its first axis, followed by the
+ spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)).
+ Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for
+ example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg
+ for exporting the predicted segmentations, so make sure you have everything you need in there!
+
+ IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray.
+ Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and
+ preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So
+ if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y
+ and c the spacing of z.
+
+ In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be
+ (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! Example: shape=(3, 1, 224, 224),
+ spacing=(999, 1, 1)
+
+ For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!)
+
+ :param image_fnames:
+ :return:
+ 1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are
+ the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image).
+            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
+ is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
+ a=999 (largest spacing value! Make it larger than b and c)
+
+ """
+ pass
+
+ @abstractmethod
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ """
+        Same requirements as BaseReaderWriter.read_images. Returned segmentations must have shape 1,x,y,z. Multiple
+ segmentations are not (yet?) allowed
+
+        If images and segmentations can be read the same way you can just `return self.read_images((seg_fname, ))`
+ :param seg_fname:
+ :return:
+ 1) a np.ndarray of shape (1, x, y, z) where x, y, z are
+ the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation).
+            2) a dictionary with metadata. This can be anything. BUT it HAS to include a {'spacing': (a, b, c)} where a
+ is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set
+ a=999 (largest spacing value! Make it larger than b and c)
+ """
+ pass
+
+ @abstractmethod
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ """
+ Export the predicted segmentation to the desired file format. The given seg array will have the same shape and
+ orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-)
+
+ properties is the same dictionary you created during read_images/read_seg so you can use the information here
+ to restore metadata
+
+ IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape
+ 1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation
+ to 2d via seg = seg[0])!
+
+ :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, y, z)!
+ :param output_fname:
+ :param properties: the dictionary that you created in read_images (the ones this segmentation is based on).
+ Use this to restore metadata
+ :return:
+ """
+ pass
\ No newline at end of file
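To make the read_images/read_seg/write_seg contract above concrete, here is a minimal hypothetical adapter. It is not part of nnU-Net; it assumes .npy volumes already stored as (x, y, z) arrays and falls back to isotropic spacing since .npy files carry no geometry:

    from typing import List, Tuple, Union
    import numpy as np
    from nnunetv2.imageio.base_reader_writer import BaseReaderWriter


    class NpyIO(BaseReaderWriter):
        supported_file_endings = ['.npy']

        def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
            # stack the channels into the required (c, x, y, z) layout
            stacked = np.stack([np.load(f) for f in image_fnames]).astype(np.float32)
            # no geometry available, so fall back to isotropic spacing
            return stacked, {'spacing': (1, 1, 1)}

        def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
            return self.read_images((seg_fname,))

        def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
            np.save(output_fname, seg.astype(np.uint8))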
diff --git a/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py b/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py
new file mode 100644
index 0000000..6dd7718
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py
@@ -0,0 +1,73 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Union, List
+import numpy as np
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from skimage import io
+
+
+class NaturalImage2DIO(BaseReaderWriter):
+ """
+ ONLY SUPPORTS 2D IMAGES!!!
+ """
+
+ # there are surely more we could add here. Everything that can be read by skimage.io should be supported
+ supported_file_endings = [
+ '.png',
+ # '.jpg',
+ # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps!
+ '.bmp',
+ '.tif'
+ ]
+
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ images = []
+ for f in image_fnames:
+ npy_img = io.imread(f)
+ if len(npy_img.shape) == 3:
+ # rgb image, last dimension should be the color channel and the size of that channel should be 3
+ # (or 4 if we have alpha)
+ assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \
+ "dimension must have shape 3 or 4 " \
+ f"(RGB or RGBA). Image shape here is {npy_img.shape}"
+ # move RGB(A) to front, add additional dim so that we have shape (1, c, X, Y), where c is either 3 or 4
+ images.append(npy_img.transpose((2, 0, 1))[:, None])
+ elif len(npy_img.shape) == 2:
+ # grayscale image
+ images.append(npy_img[None, None])
+
+ if not self._check_all_same([i.shape for i in images]):
+ print('ERROR! Not all input images have the same shape!')
+ print('Shapes:')
+ print([i.shape for i in images])
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+ return np.vstack(images).astype(np.float32), {'spacing': (999, 1, 1)}
+
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ return self.read_images((seg_fname, ))
+
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ io.imsave(output_fname, seg[0].astype(np.uint8), check_contrast=False)
+
+
+if __name__ == '__main__':
+ images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',)
+ segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png'
+ imgio = NaturalImage2DIO()
+ img, props = imgio.read_images(images)
+ seg, segprops = imgio.read_seg(segmentation)
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py b/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py
new file mode 100644
index 0000000..e4fa3f5
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py
@@ -0,0 +1,204 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Union, List
+import numpy as np
+from nibabel import io_orientation
+
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+import nibabel
+
+
+class NibabelIO(BaseReaderWriter):
+ """
+ Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be
+ consistent. This is of course considered properly in segmentation export as well.
+
+ IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg!
+ """
+ supported_file_endings = [
+ '.nii.gz',
+ '.nrrd',
+ '.mha'
+ ]
+
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ images = []
+ original_affines = []
+
+ spacings_for_nnunet = []
+ for f in image_fnames:
+ nib_image = nibabel.load(f)
+ assert len(nib_image.shape) == 3, 'only 3d images are supported by NibabelIO'
+ original_affine = nib_image.affine
+
+ original_affines.append(original_affine)
+
+ # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)
+ spacings_for_nnunet.append(
+ [float(i) for i in nib_image.header.get_zooms()[::-1]]
+ )
+
+            # transpose image to be consistent with the way SimpleITK reads images. Yeah. Annoying.
+ images.append(nib_image.get_fdata().transpose((2, 1, 0))[None])
+
+ if not self._check_all_same([i.shape for i in images]):
+ print('ERROR! Not all input images have the same shape!')
+ print('Shapes:')
+ print([i.shape for i in images])
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+ if not self._check_all_same_array(original_affines):
+ print('WARNING! Not all input images have the same original_affines!')
+ print('Affines:')
+ print(original_affines)
+ print('Image files:')
+ print(image_fnames)
+ print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
+ 'that segmentations and data overlap.')
+ if not self._check_all_same(spacings_for_nnunet):
+ print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '
+ 'having the same affine')
+ print('spacings_for_nnunet:')
+ print(spacings_for_nnunet)
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+
+ stacked_images = np.vstack(images)
+ dict = {
+ 'nibabel_stuff': {
+ 'original_affine': original_affines[0],
+ },
+ 'spacing': spacings_for_nnunet[0]
+ }
+ return stacked_images.astype(np.float32), dict
+
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ return self.read_images((seg_fname, ))
+
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ # revert transpose
+ seg = seg.transpose((2, 1, 0)).astype(np.uint8)
+ seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine'])
+ nibabel.save(seg_nib, output_fname)
+
+
+class NibabelIOWithReorient(BaseReaderWriter):
+ """
+ Reorients images to RAS
+
+ Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be
+ consistent. This is of course considered properly in segmentation export as well.
+
+ IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg!
+ """
+ supported_file_endings = [
+ '.nii.gz',
+ '.nrrd',
+ '.mha'
+ ]
+
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ images = []
+ original_affines = []
+ reoriented_affines = []
+
+ spacings_for_nnunet = []
+ for f in image_fnames:
+ nib_image = nibabel.load(f)
+ assert len(nib_image.shape) == 3, 'only 3d images are supported by NibabelIO'
+ original_affine = nib_image.affine
+ reoriented_image = nib_image.as_reoriented(io_orientation(original_affine))
+ reoriented_affine = reoriented_image.affine
+
+ original_affines.append(original_affine)
+ reoriented_affines.append(reoriented_affine)
+
+ # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...)
+ spacings_for_nnunet.append(
+ [float(i) for i in reoriented_image.header.get_zooms()[::-1]]
+ )
+
+            # transpose image to be consistent with the way SimpleITK reads images. Yeah. Annoying.
+ images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None])
+
+ if not self._check_all_same([i.shape for i in images]):
+ print('ERROR! Not all input images have the same shape!')
+ print('Shapes:')
+ print([i.shape for i in images])
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+ if not self._check_all_same_array(reoriented_affines):
+ print('WARNING! Not all input images have the same reoriented_affines!')
+ print('Affines:')
+ print(reoriented_affines)
+ print('Image files:')
+ print(image_fnames)
+ print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
+ 'that segmentations and data overlap.')
+ if not self._check_all_same(spacings_for_nnunet):
+ print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not '
+ 'having the same affine')
+ print('spacings_for_nnunet:')
+ print(spacings_for_nnunet)
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+
+ stacked_images = np.vstack(images)
+ dict = {
+ 'nibabel_stuff': {
+ 'original_affine': original_affines[0],
+ 'reoriented_affine': reoriented_affines[0],
+ },
+ 'spacing': spacings_for_nnunet[0]
+ }
+ return stacked_images.astype(np.float32), dict
+
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ return self.read_images((seg_fname, ))
+
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ # revert transpose
+ seg = seg.transpose((2, 1, 0)).astype(np.uint8)
+
+ seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine'])
+ seg_nib_reoriented = seg_nib.as_reoriented(io_orientation(properties['nibabel_stuff']['original_affine']))
+ assert np.allclose(properties['nibabel_stuff']['original_affine'], seg_nib_reoriented.affine), \
+ 'restored affine does not match original affine'
+ nibabel.save(seg_nib_reoriented, output_fname)
+
+
+if __name__ == '__main__':
+ img_file = 'patient028_frame01_0000.nii.gz'
+ seg_file = 'patient028_frame01.nii.gz'
+
+ nibio = NibabelIO()
+ images, dct = nibio.read_images([img_file])
+ seg, dctseg = nibio.read_seg(seg_file)
+
+ nibio_r = NibabelIOWithReorient()
+ images_r, dct_r = nibio_r.read_images([img_file])
+ seg_r, dctseg_r = nibio_r.read_seg(seg_file)
+
+ nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg)
+ nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r)
+
+ s_orig = nibabel.load(seg_file).get_fdata()
+ s_nibio = nibabel.load('/home/isensee/seg_nibio.nii.gz').get_fdata()
+ s_nibio_r = nibabel.load('/home/isensee/seg_nibio_r.nii.gz').get_fdata()
diff --git a/nnUNet/nnunetv2/imageio/reader_writer_registry.py b/nnUNet/nnunetv2/imageio/reader_writer_registry.py
new file mode 100644
index 0000000..bdbee5d
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/reader_writer_registry.py
@@ -0,0 +1,79 @@
+import traceback
+from typing import Type
+
+from batchgenerators.utilities.file_and_folder_operations import join
+
+import nnunetv2
+from nnunetv2.imageio.natural_image_reager_writer import NaturalImage2DIO
+from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient
+from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+from nnunetv2.imageio.tif_reader_writer import Tiff3DIO
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+
+LIST_OF_IO_CLASSES = [
+ NaturalImage2DIO,
+ SimpleITKIO,
+ Tiff3DIO,
+ NibabelIO,
+ NibabelIOWithReorient
+]
+
+
+def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None,
+ allow_nonmatching_filename: bool = False, verbose: bool = True
+ ) -> Type[BaseReaderWriter]:
+ if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \
+ dataset_json_content['overwrite_image_reader_writer'] != 'None':
+ ioclass_name = dataset_json_content['overwrite_image_reader_writer']
+ # trying to find that class in the nnunetv2.imageio module
+ try:
+ ret = recursive_find_reader_writer_by_name(ioclass_name)
+ if verbose: print('Using %s reader/writer' % ret)
+ return ret
+ except RuntimeError:
+ if verbose: print('Warning: Unable to find ioclass specified in dataset.json: %s' % ioclass_name)
+ if verbose: print('Trying to automatically determine desired class')
+ return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file,
+ allow_nonmatching_filename, verbose)
+
+
+def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False,
+ verbose: bool = True):
+ for rw in LIST_OF_IO_CLASSES:
+ if file_ending.lower() in rw.supported_file_endings:
+ if example_file is not None:
+                # if an example file is provided, check whether we can actually read it. If not, move on to the next reader
+ try:
+ tmp = rw()
+ _ = tmp.read_images((example_file,))
+ if verbose: print('Using %s as reader/writer' % rw)
+ return rw
+ except:
+ if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
+ traceback.print_exc()
+ pass
+ else:
+ if verbose: print('Using %s as reader/writer' % rw)
+ return rw
+ else:
+ if allow_nonmatching_filename and example_file is not None:
+ try:
+ tmp = rw()
+ _ = tmp.read_images((example_file,))
+ if verbose: print('Using %s as reader/writer' % rw)
+ return rw
+ except:
+ if verbose: print(f'Failed to open file {example_file} with reader {rw}:')
+ if verbose: traceback.print_exc()
+ pass
+ raise RuntimeError("Unable to determine a reader for file ending %s and file %s (file None means no file provided)." % (file_ending, example_file))
+
+
+def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]:
+ ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio')
+ if ret is None:
+ raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the "
+ "nnunetv2.imageio module." % rw_class_name)
+ else:
+ return ret
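A hedged sketch of how the resolution order plays out (the dataset.json content below is fabricated for illustration): an explicit overwrite_image_reader_writer entry wins if the named class can be found, otherwise the file ending decides.

    from nnunetv2.imageio.reader_writer_registry import (
        determine_reader_writer_from_dataset_json,
        determine_reader_writer_from_file_ending,
    )

    # explicit override in dataset.json wins (if the class can be found) ...
    rw = determine_reader_writer_from_dataset_json(
        {'overwrite_image_reader_writer': 'NibabelIOWithReorient', 'file_ending': '.nii.gz'})
    print(rw.__name__)  # NibabelIOWithReorient

    # ... otherwise the file ending decides
    rw = determine_reader_writer_from_file_ending('.nii.gz', example_file=None)
    print(rw.__name__)  # SimpleITKIO (first matching entry in LIST_OF_IO_CLASSES)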
diff --git a/nnUNet/nnunetv2/imageio/readme.md b/nnUNet/nnunetv2/imageio/readme.md
new file mode 100644
index 0000000..7819425
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/readme.md
@@ -0,0 +1,7 @@
+- Derive your adapter from `BaseReaderWriter`.
+- Reimplement all abstract methods.
+- Make sure to support 2d and 3d input images (or raise an error).
+- Place it in this folder or nnU-Net won't find it!
+- Add it to `LIST_OF_IO_CLASSES` in `reader_writer_registry.py`.
+
+Bam, you're done!
\ No newline at end of file
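As a hedged follow-up to the steps in this readme, using the hypothetical NpyIO adapter sketched after base_reader_writer.py and a placeholder dataset folder: once the module sits in nnunetv2/imageio and is listed in LIST_OF_IO_CLASSES, you can also name it explicitly in dataset.json so automatic detection is bypassed:

    from batchgenerators.utilities.file_and_folder_operations import join, load_json, save_json
    from nnunetv2.paths import nnUNet_raw

    dataset_folder = join(nnUNet_raw, 'Dataset001_Example')   # placeholder dataset
    dataset_json = load_json(join(dataset_folder, 'dataset.json'))
    dataset_json['overwrite_image_reader_writer'] = 'NpyIO'   # class name, not module path
    save_json(dataset_json, join(dataset_folder, 'dataset.json'), sort_keys=False)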
diff --git a/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py b/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py
new file mode 100644
index 0000000..2b9b168
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py
@@ -0,0 +1,129 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple, Union, List
+import numpy as np
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+import SimpleITK as sitk
+
+
+class SimpleITKIO(BaseReaderWriter):
+ supported_file_endings = [
+ '.nii.gz',
+ '.nrrd',
+ '.mha'
+ ]
+
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ images = []
+ spacings = []
+ origins = []
+ directions = []
+
+ spacings_for_nnunet = []
+ for f in image_fnames:
+ itk_image = sitk.ReadImage(f)
+ spacings.append(itk_image.GetSpacing())
+ origins.append(itk_image.GetOrigin())
+ directions.append(itk_image.GetDirection())
+ npy_image = sitk.GetArrayFromImage(itk_image)
+ if len(npy_image.shape) == 2:
+ # 2d
+ npy_image = npy_image[None, None]
+ max_spacing = max(spacings[-1])
+ spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1]))
+ elif len(npy_image.shape) == 3:
+ # 3d, as in original nnunet
+ npy_image = npy_image[None]
+ spacings_for_nnunet.append(list(spacings[-1])[::-1])
+ elif len(npy_image.shape) == 4:
+ # 4d, multiple modalities in one file
+ spacings_for_nnunet.append(list(spacings[-1])[::-1][1:])
+ pass
+ else:
+ raise RuntimeError("Unexpected number of dimensions: %d in file %s" % (len(npy_image.shape), f))
+
+ images.append(npy_image)
+ spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1]))
+
+ if not self._check_all_same([i.shape for i in images]):
+ print('ERROR! Not all input images have the same shape!')
+ print('Shapes:')
+ print([i.shape for i in images])
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+ if not self._check_all_same(spacings):
+ print('ERROR! Not all input images have the same spacing!')
+ print('Spacings:')
+ print(spacings)
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+ if not self._check_all_same(origins):
+ print('WARNING! Not all input images have the same origin!')
+ print('Origins:')
+ print(origins)
+ print('Image files:')
+ print(image_fnames)
+ print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
+ 'that segmentations and data overlap.')
+ if not self._check_all_same(directions):
+ print('WARNING! Not all input images have the same direction!')
+ print('Directions:')
+ print(directions)
+ print('Image files:')
+ print(image_fnames)
+ print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify '
+ 'that segmentations and data overlap.')
+ if not self._check_all_same(spacings_for_nnunet):
+            print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a '
+                  'bug. Please report!)')
+ print('spacings_for_nnunet:')
+ print(spacings_for_nnunet)
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+
+ stacked_images = np.vstack(images)
+ dict = {
+ 'sitk_stuff': {
+ # this saves the sitk geometry information. This part is NOT used by nnU-Net!
+ 'spacing': spacings[0],
+ 'origin': origins[0],
+ 'direction': directions[0]
+ },
+        # the spacing is inverted with [::-1] because sitk returns spacing in the opposite order to the array axes:
+        # GetArrayFromImage yields arrays as z,y,x (what nnU-Net calls x,y,z here) while GetSpacing returns x,y,z.
+ 'spacing': spacings_for_nnunet[0]
+ }
+ return stacked_images.astype(np.float32), dict
+
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ return self.read_images((seg_fname, ))
+
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ assert len(seg.shape) == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y'
+ output_dimension = len(properties['sitk_stuff']['spacing'])
+ assert 1 < output_dimension < 4
+ if output_dimension == 2:
+ seg = seg[0]
+
+ itk_image = sitk.GetImageFromArray(seg.astype(np.uint8))
+ itk_image.SetSpacing(properties['sitk_stuff']['spacing'])
+ itk_image.SetOrigin(properties['sitk_stuff']['origin'])
+ itk_image.SetDirection(properties['sitk_stuff']['direction'])
+
+ sitk.WriteImage(itk_image, output_fname)
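A hedged usage sketch with placeholder file names, showing how the 'sitk_stuff' captured by read_images travels back into write_seg so the exported segmentation keeps the original geometry (the thresholded "segmentation" is just a stand-in):

    import numpy as np
    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO

    io = SimpleITKIO()
    image, props = io.read_images(['case_0000.nii.gz'])        # stacked channels, nnU-Net axis convention
    print(props['spacing'])                                     # reversed sitk spacing, matches the array axes
    dummy_seg = (image[0] > image[0].mean()).astype(np.uint8)   # same grid as the image
    io.write_seg(dummy_seg, 'case_0000_seg.nii.gz', props)      # geometry restored from props['sitk_stuff']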
diff --git a/nnUNet/nnunetv2/imageio/tif_reader_writer.py b/nnUNet/nnunetv2/imageio/tif_reader_writer.py
new file mode 100644
index 0000000..0aa5ff3
--- /dev/null
+++ b/nnUNet/nnunetv2/imageio/tif_reader_writer.py
@@ -0,0 +1,100 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+from typing import Tuple, Union, List
+import numpy as np
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+import tifffile
+from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join
+
+
+class Tiff3DIO(BaseReaderWriter):
+ """
+ reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)!
+
+ If you have 2D tiffs, use NaturalImage2DIO
+
+ Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end
+    with .json and omit the channel identifier. So, for example, the auxiliary file for image image1_0000.tif is
+    expected to be image1.json!
+ """
+ supported_file_endings = [
+ '.tif',
+ '.tiff',
+ ]
+
+ def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]:
+ # figure out file ending used here
+ ending = '.' + image_fnames[0].split('.')[-1]
+ assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
+ ending_length = len(ending)
+ truncate_length = ending_length + 5 # 5 comes from len(_0000)
+
+ images = []
+ for f in image_fnames:
+ image = tifffile.imread(f)
+ if len(image.shape) != 3:
+ raise RuntimeError("Only 3D images are supported! File: %s" % f)
+ images.append(image[None])
+
+ # see if aux file can be found
+ expected_aux_file = image_fnames[0][:-truncate_length] + '.json'
+ if isfile(expected_aux_file):
+ spacing = load_json(expected_aux_file)['spacing']
+ assert len(spacing) == 3, 'spacing must have 3 entries, one for each dimension of the image. File: %s' % expected_aux_file
+ else:
+ print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).')
+ spacing = (1, 1, 1)
+
+ if not self._check_all_same([i.shape for i in images]):
+ print('ERROR! Not all input images have the same shape!')
+ print('Shapes:')
+ print([i.shape for i in images])
+ print('Image files:')
+ print(image_fnames)
+ raise RuntimeError()
+
+ return np.vstack(images).astype(np.float32), {'spacing': spacing}
+
+ def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None:
+ # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha
+ tifffile.imwrite(output_fname, data=seg.astype(np.uint8), compression='zlib')
+ file = os.path.basename(output_fname)
+ out_dir = os.path.dirname(output_fname)
+ ending = file.split('.')[-1]
+ save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json'))
+
+ def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]:
+ # figure out file ending used here
+ ending = '.' + seg_fname.split('.')[-1]
+ assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}'
+ ending_length = len(ending)
+
+ seg = tifffile.imread(seg_fname)
+ if len(seg.shape) != 3:
+ raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}")
+ seg = seg[None]
+
+ # see if aux file can be found
+ expected_aux_file = seg_fname[:-ending_length] + '.json'
+ if isfile(expected_aux_file):
+ spacing = load_json(expected_aux_file)['spacing']
+ assert len(spacing) == 3, 'spacing must have 3 entries, one for each dimension of the image. File: %s' % expected_aux_file
+ assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}"
+ else:
+ print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).')
+ spacing = (1, 1, 1)
+
+ return seg.astype(np.float32), {'spacing': spacing}
\ No newline at end of file
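A hedged sketch of the auxiliary spacing file convention described in the class docstring. File names and spacing values are placeholders; the channel identifier _0000 is dropped for the sidecar:

    from batchgenerators.utilities.file_and_folder_operations import save_json
    from nnunetv2.imageio.tif_reader_writer import Tiff3DIO

    # sidecar written next to volume1_0000.tif; one spacing entry per array axis of the 3D stack
    save_json({'spacing': [5.0, 0.5, 0.5]}, 'volume1.json')

    img, props = Tiff3DIO().read_images(['volume1_0000.tif'])
    print(img.shape, props['spacing'])   # (1, *spatial) array and [5.0, 0.5, 0.5]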
diff --git a/nnUNet/nnunetv2/inference/__init__.py b/nnUNet/nnunetv2/inference/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/inference/data_iterators.py b/nnUNet/nnunetv2/inference/data_iterators.py
new file mode 100644
index 0000000..37d7f07
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/data_iterators.py
@@ -0,0 +1,318 @@
+import multiprocessing
+import queue
+from torch.multiprocessing import Event, Process, Queue, Manager
+
+from time import sleep
+from typing import Union, List
+
+import numpy as np
+import torch
+from batchgenerators.dataloading.data_loader import DataLoader
+
+from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
+from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+
+
+def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
+ list_of_segs_from_prev_stage_files: Union[None, List[str]],
+ output_filenames_truncated: Union[None, List[str]],
+ plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ target_queue: Queue,
+ done_event: Event,
+ abort_event: Event,
+ verbose: bool = False):
+ try:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+ for idx in range(len(list_of_lists)):
+ data, seg, data_properites = preprocessor.run_case(list_of_lists[idx],
+ list_of_segs_from_prev_stage_files[
+ idx] if list_of_segs_from_prev_stage_files is not None else None,
+ plans_manager,
+ configuration_manager,
+ dataset_json)
+ if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None:
+ seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+ data = np.vstack((data, seg_onehot))
+
+ data = torch.from_numpy(data).contiguous().float()
+
+ item = {'data': data, 'data_properites': data_properites,
+ 'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None}
+ success = False
+ while not success:
+ try:
+ if abort_event.is_set():
+ return
+ target_queue.put(item, timeout=0.01)
+ success = True
+ except queue.Full:
+ pass
+ done_event.set()
+ except Exception as e:
+ abort_event.set()
+ raise e
+
+
+def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
+ list_of_segs_from_prev_stage_files: Union[None, List[str]],
+ output_filenames_truncated: Union[None, List[str]],
+ plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ num_processes: int,
+ pin_memory: bool = False,
+ verbose: bool = False):
+ context = multiprocessing.get_context('spawn')
+ manager = Manager()
+ num_processes = min(len(list_of_lists), num_processes)
+ assert num_processes >= 1
+ processes = []
+ done_events = []
+ target_queues = []
+ abort_event = manager.Event()
+ for i in range(num_processes):
+ event = manager.Event()
+        queue = manager.Queue(maxsize=1)
+ pr = context.Process(target=preprocess_fromfiles_save_to_queue,
+ args=(
+ list_of_lists[i::num_processes],
+ list_of_segs_from_prev_stage_files[
+ i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
+ output_filenames_truncated[
+ i::num_processes] if output_filenames_truncated is not None else None,
+ plans_manager,
+ dataset_json,
+ configuration_manager,
+ queue,
+ event,
+ abort_event,
+ verbose
+ ), daemon=True)
+ pr.start()
+ target_queues.append(queue)
+ done_events.append(event)
+ processes.append(pr)
+
+ worker_ctr = 0
+ # print(f"Type: {type(target_queues[worker_ctr])}")
+ while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+ # print(type(target_queues[worker_ctr]))
+ if not target_queues[worker_ctr].empty():
+ item = target_queues[worker_ctr].get()
+ worker_ctr = (worker_ctr + 1) % num_processes
+ else:
+ all_ok = all(
+ [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+ if not all_ok:
+ raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+ 'none then your RAM was full and the worker was killed by the OS. Use fewer '
+ 'workers or get more RAM in that case!')
+ sleep(0.01)
+ continue
+ if pin_memory:
+ [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+ yield item
+ [p.join() for p in processes]
+
+class PreprocessAdapter(DataLoader):
+ def __init__(self, list_of_lists: List[List[str]],
+ list_of_segs_from_prev_stage_files: Union[None, List[str]],
+ preprocessor: DefaultPreprocessor,
+ output_filenames_truncated: Union[None, List[str]],
+ plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ num_threads_in_multithreaded: int = 1):
+ self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \
+ preprocessor, plans_manager, configuration_manager, dataset_json
+
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+
+ if list_of_segs_from_prev_stage_files is None:
+ list_of_segs_from_prev_stage_files = [None] * len(list_of_lists)
+ if output_filenames_truncated is None:
+ output_filenames_truncated = [None] * len(list_of_lists)
+
+ super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)),
+ 1, num_threads_in_multithreaded,
+ seed_for_shuffle=1, return_incomplete=True,
+ shuffle=False, infinite=False, sampling_probabilities=None)
+
+ self.indices = list(range(len(list_of_lists)))
+
+ def generate_train_batch(self):
+ idx = self.get_indices()[0]
+ files = self._data[idx][0]
+ seg_prev_stage = self._data[idx][1]
+ ofile = self._data[idx][2]
+ # if we have a segmentation from the previous stage we have to process it together with the images so that we
+ # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
+ # preprocessing and then there might be misalignments
+ data, seg, data_properites = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager,
+ self.configuration_manager,
+ self.dataset_json)
+ if seg_prev_stage is not None:
+ seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)
+ data = np.vstack((data, seg_onehot))
+
+ data = torch.from_numpy(data)
+
+ return {'data': data, 'data_properites': data_properites, 'ofile': ofile}
+
+
+class PreprocessAdapterFromNpy(DataLoader):
+ def __init__(self, list_of_images: List[np.ndarray],
+ list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
+ list_of_image_properties: List[dict],
+ truncated_ofnames: Union[List[str], None],
+ plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager,
+ num_threads_in_multithreaded: int = 1, verbose: bool = False):
+ preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+ self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \
+ preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames
+
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+
+ if list_of_segs_from_prev_stage is None:
+ list_of_segs_from_prev_stage = [None] * len(list_of_images)
+ if truncated_ofnames is None:
+ truncated_ofnames = [None] * len(list_of_images)
+
+ super().__init__(
+ list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)),
+ 1, num_threads_in_multithreaded,
+ seed_for_shuffle=1, return_incomplete=True,
+ shuffle=False, infinite=False, sampling_probabilities=None)
+
+ self.indices = list(range(len(list_of_images)))
+
+ def generate_train_batch(self):
+ idx = self.get_indices()[0]
+ image = self._data[idx][0]
+ seg_prev_stage = self._data[idx][1]
+ props = self._data[idx][2]
+ ofname = self._data[idx][3]
+ # if we have a segmentation from the previous stage we have to process it together with the images so that we
+ # can crop it appropriately (if needed). Otherwise it would just be resized to the shape of the data after
+ # preprocessing and then there might be misalignments
+ data, seg = self.preprocessor.run_case_npy(image, seg_prev_stage, props,
+ self.plans_manager,
+ self.configuration_manager,
+ self.dataset_json)
+ if seg_prev_stage is not None:
+ seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype)
+ data = np.vstack((data, seg_onehot))
+
+ data = torch.from_numpy(data)
+
+ return {'data': data, 'data_properites': props, 'ofile': ofname}
+
+
+def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
+ list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
+ list_of_image_properties: List[dict],
+ truncated_ofnames: Union[List[str], None],
+ plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ target_queue: Queue,
+ done_event: Event,
+ abort_event: Event,
+ verbose: bool = False):
+ try:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+ for idx in range(len(list_of_images)):
+ data, seg = preprocessor.run_case_npy(list_of_images[idx],
+ list_of_segs_from_prev_stage[
+ idx] if list_of_segs_from_prev_stage is not None else None,
+ list_of_image_properties[idx],
+ plans_manager,
+ configuration_manager,
+ dataset_json)
+ if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None:
+ seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+ data = np.vstack((data, seg_onehot))
+
+ data = torch.from_numpy(data).contiguous().float()
+
+ item = {'data': data, 'data_properites': list_of_image_properties[idx],
+ 'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None}
+ success = False
+ while not success:
+ try:
+ if abort_event.is_set():
+ return
+ target_queue.put(item, timeout=0.01)
+ success = True
+ except queue.Full:
+ pass
+ done_event.set()
+ except Exception as e:
+ abort_event.set()
+ raise e
+
+
+def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
+ list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
+ list_of_image_properties: List[dict],
+ truncated_ofnames: Union[List[str], None],
+ plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ num_processes: int,
+ pin_memory: bool = False,
+ verbose: bool = False):
+ context = multiprocessing.get_context('spawn')
+ manager = Manager()
+ num_processes = min(len(list_of_images), num_processes)
+ assert num_processes >= 1
+ target_queues = []
+ processes = []
+ done_events = []
+ abort_event = manager.Event()
+ for i in range(num_processes):
+ event = manager.Event()
+ queue = manager.Queue(maxsize=1)
+ pr = context.Process(target=preprocess_fromnpy_save_to_queue,
+ args=(
+ list_of_images[i::num_processes],
+ list_of_segs_from_prev_stage[
+ i::num_processes] if list_of_segs_from_prev_stage is not None else None,
+ list_of_image_properties[i::num_processes],
+ truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
+ plans_manager,
+ dataset_json,
+ configuration_manager,
+ queue,
+ event,
+ abort_event,
+ verbose
+ ), daemon=True)
+ pr.start()
+ done_events.append(event)
+ processes.append(pr)
+ target_queues.append(queue)
+
+ worker_ctr = 0
+ while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+ if not target_queues[worker_ctr].empty():
+ item = target_queues[worker_ctr].get()
+ worker_ctr = (worker_ctr + 1) % num_processes
+ else:
+ all_ok = all(
+ [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+ if not all_ok:
+ raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+ 'none then your RAM was full and the worker was killed by the OS. Use fewer '
+ 'workers or get more RAM in that case!')
+ sleep(0.01)
+ continue
+ if pin_memory:
+ [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+ yield item
+ [p.join() for p in processes]
diff --git a/nnUNet/nnunetv2/inference/examples.py b/nnUNet/nnunetv2/inference/examples.py
new file mode 100644
index 0000000..8e8f264
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/examples.py
@@ -0,0 +1,102 @@
+if __name__ == '__main__':
+ from nnunetv2.paths import nnUNet_results, nnUNet_raw
+ import torch
+ from batchgenerators.utilities.file_and_folder_operations import join
+ from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+ from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+
+ # nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction
+
+ # instantiate the nnUNetPredictor
+ predictor = nnUNetPredictor(
+ tile_step_size=0.5,
+ use_gaussian=True,
+ use_mirroring=True,
+ perform_everything_on_gpu=True,
+ device=torch.device('cuda', 0),
+ verbose=False,
+ verbose_preprocessing=False,
+ allow_tqdm=True
+ )
+ # initializes the network architecture, loads the checkpoint
+ predictor.initialize_from_trained_model_folder(
+ join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
+ use_folds=(0,),
+ checkpoint_name='checkpoint_final.pth',
+ )
+ # variant 1: give input and output folders
+ predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
+ save_probabilities=False, overwrite=False,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+
+ # variant 2, use list of files as inputs. Note how we use nested lists!!!
+ indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
+ outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')
+ predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
+ [join(indir, 'liver_142_0000.nii.gz')]],
+ [join(outdir, 'liver_152.nii.gz'),
+ join(outdir, 'liver_142.nii.gz')],
+ save_probabilities=False, overwrite=True,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+
+ # variant 2.5, returns segmentations
+ indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
+ predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
+ [join(indir, 'liver_142_0000.nii.gz')]],
+ None,
+ save_probabilities=True, overwrite=True,
+ num_processes_preprocessing=2,
+ num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1,
+ part_id=0)
+
+ # predict several npy images
+ from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+ img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+ img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+ img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+ # we do not set output files so that the segmentations will be returned. You can of course also specify output
+    # files instead (there is no return value in that case)
+ ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
+ None,
+ [props, props2, props3, props4],
+ None, 2, save_probabilities=False,
+ num_processes_segmentation_export=2)
+
+ # predict a single numpy array
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+ ret = predictor.predict_single_npy_array(img, props, None, None, True)
+
+ # custom iterator
+
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+ img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+ img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+ img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+
+
+ # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys!
+ # If 'ofile' is None, the result will be returned instead of written to a file
+ # the iterator is responsible for performing the correct preprocessing!
+ # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
+ # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
+ # (they both use predictor.predict_from_data_iterator) for inspiration!
+ def my_iterator(list_of_input_arrs, list_of_input_props):
+ preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
+ for a, p in zip(list_of_input_arrs, list_of_input_props):
+ data, seg = preprocessor.run_case_npy(a,
+ None,
+ p,
+ predictor.plans_manager,
+ predictor.configuration_manager,
+ predictor.dataset_json)
+ yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properites': p, 'ofile': None}
+
+
+ ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),
+ save_probabilities=False, num_processes_segmentation_export=3)
diff --git a/nnUNet/nnunetv2/inference/export_prediction.py b/nnUNet/nnunetv2/inference/export_prediction.py
new file mode 100644
index 0000000..418a73d
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/export_prediction.py
@@ -0,0 +1,168 @@
+import os
+from copy import deepcopy
+from typing import Union, List
+
+from skimage.transform import resize
+import numpy as np
+import torch
+from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
+from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle
+
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.utilities.label_handling.label_handling import LabelManager
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+
+
+def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
+ plans_manager: PlansManager,
+ configuration_manager: ConfigurationManager,
+ label_manager: LabelManager,
+ properties_dict: dict,
+ return_probabilities: bool = False,
+ num_threads_torch: int = default_num_processes):
+    # Changing the torch thread count here makes the Docker container hang (cause unknown), so it is left disabled:
+ #old_threads = torch.get_num_threads()
+ #torch.set_num_threads(num_threads_torch)
+
+ # resample to original shape
+ current_spacing = configuration_manager.spacing if \
+ len(configuration_manager.spacing) == \
+ len(properties_dict['shape_after_cropping_and_before_resampling']) else \
+ [properties_dict['spacing'][0], *configuration_manager.spacing]
+ predicted_logits = predicted_logits.cpu()
+
+    # The idea here: the output tensors have 31 channels (30 classes + background), which, combined with
+    # the huge X and Y dimensions caused by the small median spacing, makes for huge tensors that barely
+    # fit in memory.
+    # So instead of following the usual 'nonlin -> argmax -> resampling' scheme (which would allocate
+    # large intermediate numpy arrays along the way), we do 'argmax -> resampling'; the nonlinearity is
+    # only needed if you want pseudo-probabilities.
+    # This divides the memory footprint by 62: we go from 31 channels to 1 and from float16 to uint8.
+    # Results should be "fairly" similar.
+ # segmentation = predicted_logits.squeeze() # Old hack for memory reduction
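+    # Rough arithmetic behind the factor of 62 (illustrative): per voxel, float16 logits with 31 channels
+    # take 31 * 2 = 62 bytes, whereas the uint8 argmax takes a single byte.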
+ segmentation = torch.argmax(predicted_logits, 0, keepdim=False)
+ # If you get to this point, memory should not be a problem going forward
+ segmentation = segmentation.numpy().astype(np.uint8)
+
+ # Copied the resample_data_or_seg function
+ segmentation = resize(segmentation, properties_dict['shape_after_cropping_and_before_resampling'], 0,
+ mode="edge", clip=True, anti_aliasing=False)
+
+ #segmentation = configuration_manager.resampling_fn_seg(predicted_logits,
+ # properties_dict['shape_after_cropping_and_before_resampling'],
+ # current_spacing,
+ # properties_dict['spacing'])
+    # return value of resampling_fn_probabilities can be ndarray or Tensor but that doesn't matter because
+    # apply_inference_nonlin will convert to torch
+ #predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits_r)
+ #segmentation = predicted_logits_r # label_manager.convert_probabilities_to_segmentation(predicted_probabilities)
+
+ # segmentation may be torch.Tensor but we continue with numpy
+ if isinstance(segmentation, torch.Tensor):
+ segmentation = segmentation.cpu().numpy()
+
+ # put segmentation in bbox (revert cropping)
+ segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'],
+ dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
+ slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
+ segmentation_reverted_cropping[slicer] = segmentation
+ del segmentation
+
+ # revert transpose
+ segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)
+    if return_probabilities:  # We don't here. Never set this to True: predicted_probabilities is no longer computed above, so this branch would fail
+ # revert cropping
+ predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities,
+ properties_dict[
+ 'bbox_used_for_cropping'],
+ properties_dict[
+ 'shape_before_cropping'])
+ predicted_probabilities = predicted_probabilities.cpu().numpy()
+ # revert transpose
+ predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in
+ plans_manager.transpose_backward])
+ # Same reason as above, it makes the docker hang
+ # torch.set_num_threads(old_threads)
+ return segmentation_reverted_cropping, predicted_probabilities
+ else:
+ # torch.set_num_threads(old_threads)
+ return segmentation_reverted_cropping
+
+
+def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
+ configuration_manager: ConfigurationManager,
+ plans_manager: PlansManager,
+ dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
+ save_probabilities: bool = False):
+ # if isinstance(predicted_array_or_file, str):
+ # tmp = deepcopy(predicted_array_or_file)
+ # if predicted_array_or_file.endswith('.npy'):
+ # predicted_array_or_file = np.load(predicted_array_or_file)
+ # elif predicted_array_or_file.endswith('.npz'):
+ # predicted_array_or_file = np.load(predicted_array_or_file)['softmax']
+ # os.remove(tmp)
+
+ if isinstance(dataset_json_dict_or_file, str):
+ dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
+
+ label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
+ print("starting conversion", flush=True)
+ ret = convert_predicted_logits_to_segmentation_with_correct_shape(
+ predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict,
+ return_probabilities=save_probabilities
+ )
+ del predicted_array_or_file
+
+ # save
+ if save_probabilities:
+ segmentation_final, probabilities_final = ret
+ np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final)
+ save_pickle(properties_dict, output_file_truncated + '.pkl')
+ del probabilities_final, ret
+ else:
+ segmentation_final = ret
+ del ret
+
+ rw = plans_manager.image_reader_writer_class()
+ print("saving", flush=True)
+ rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
+ properties_dict)
+
+
+def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
+ plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
+ dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \
+ -> None:
+ # # needed for cascade
+ # if isinstance(predicted, str):
+ # assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \
+ # "isfile(segmentation_softmax) must be True"
+ # del_file = deepcopy(predicted)
+ # predicted = np.load(predicted)
+ # os.remove(del_file)
+ old_threads = torch.get_num_threads()
+ torch.set_num_threads(num_threads_torch)
+
+ if isinstance(dataset_json_dict_or_file, str):
+ dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)
+
+ # resample to original shape
+ current_spacing = configuration_manager.spacing if \
+ len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \
+ [properties_dict['spacing'][0], *configuration_manager.spacing]
+ target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \
+ len(properties_dict['shape_after_cropping_and_before_resampling']) else \
+ [properties_dict['spacing'][0], *configuration_manager.spacing]
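+    # Note: current_spacing and target_spacing are computed identically here, so the resampling call
+    # below is effectively shape-driven only.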
+ predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted,
+ target_shape,
+ current_spacing,
+ target_spacing)
+
+ # create segmentation (argmax, regions, etc)
+ label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
+ segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file)
+ # segmentation may be torch.Tensor but we continue with numpy
+ if isinstance(segmentation, torch.Tensor):
+ segmentation = segmentation.cpu().numpy()
+ np.savez_compressed(output_file, seg=segmentation.astype(np.uint8))
+ torch.set_num_threads(old_threads)
diff --git a/nnUNet/nnunetv2/inference/predict_from_raw_data.py b/nnUNet/nnunetv2/inference/predict_from_raw_data.py
new file mode 100644
index 0000000..299e61a
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/predict_from_raw_data.py
@@ -0,0 +1,938 @@
+import inspect
+import multiprocessing
+import os
+import traceback
+from copy import deepcopy
+from time import sleep
+from typing import Tuple, Union, List, Optional
+
+import numpy as np
+import torch
+from acvl_utils.cropping_and_padding.padding import pad_nd_image
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
+ save_json
+from torch import nn
+from torch._dynamo import OptimizedModule
+from torch.nn.parallel import DistributedDataParallel
+from tqdm import tqdm
+
+import nnunetv2
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \
+ preprocessing_iterator_fromnpy
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, \
+ convert_predicted_logits_to_segmentation_with_correct_shape
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \
+ compute_steps_for_sliding_window
+from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.json_export import recursive_fix_for_json_export
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
+
+
+class nnUNetPredictor(object):
+ def __init__(self,
+ tile_step_size: float = 0.5,
+ use_gaussian: bool = True,
+ use_mirroring: bool = True,
+ perform_everything_on_gpu: bool = True,
+ device: torch.device = torch.device('cuda'),
+ verbose: bool = False,
+ verbose_preprocessing: bool = False,
+ allow_tqdm: bool = True):
+ self.verbose = verbose
+ self.verbose_preprocessing = verbose_preprocessing
+ self.allow_tqdm = allow_tqdm
+
+ self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \
+ self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None
+
+ self.tile_step_size = tile_step_size
+ self.use_gaussian = use_gaussian
+ self.use_mirroring = use_mirroring
+ if device.type == 'cuda':
+ # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES!
+            # forcing index 0 here is not wanted: it would break DDP inference, so the device is used as passed in
+ pass
+ if device.type != 'cuda':
+            print('perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False')
+ perform_everything_on_gpu = False
+ self.device = device
+ self.perform_everything_on_gpu = perform_everything_on_gpu
+
+ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
+ use_folds: Union[Tuple[Union[int, str]], None],
+ checkpoint_name: str = 'checkpoint_final.pth'):
+ """
+ This is used when making predictions with a trained model
+ """
+ if use_folds is None:
+ use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)
+
+ dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
+ plans = load_json(join(model_training_output_dir, 'plans.json'))
+ plans_manager = PlansManager(plans)
+
+ if isinstance(use_folds, str):
+ use_folds = [use_folds]
+
+ parameters = []
+        for i, f in enumerate(use_folds):
+            f = int(f) if f != 'all' else f
+            # keep the path to the checkpoint (instead of the loaded weights) to keep memory usage low
+            checkpoint_file = join(model_training_output_dir, f'fold_{f}', checkpoint_name)
+            checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
+
+            if i == 0:
+                trainer_name = checkpoint['trainer_name']
+                configuration_name = checkpoint['init_args']['configuration']
+                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
+                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None
+            del checkpoint
+            # parameters.append(checkpoint['network_weights'])
+            parameters.append(checkpoint_file)
+
+ configuration_manager = plans_manager.get_configuration(configuration_name)
+ # restore network
+ num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+ trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+ trainer_name, 'nnunetv2.training.nnUNetTrainer')
+ network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, enable_deep_supervision=False)
+ self.plans_manager = plans_manager
+ self.configuration_manager = configuration_manager
+ self.list_of_parameters = parameters
+ self.network = network
+ self.dataset_json = dataset_json
+ self.trainer_name = trainer_name
+ self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+ if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
+ and not isinstance(self.network, OptimizedModule):
+ print('compiling network')
+ self.network = torch.compile(self.network)
+
+ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
+ configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
+ dataset_json: dict, trainer_name: str,
+ inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]):
+ """
+ This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation
+ """
+ self.plans_manager = plans_manager
+ self.configuration_manager = configuration_manager
+ self.list_of_parameters = parameters
+ self.network = network
+ self.dataset_json = dataset_json
+ self.trainer_name = trainer_name
+ self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+ allow_compile = True
+ allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't'))
+ allow_compile = allow_compile and not isinstance(self.network, OptimizedModule)
+ if isinstance(self.network, DistributedDataParallel):
+ allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)
+ if allow_compile:
+ print('compiling network')
+ self.network = torch.compile(self.network)
+
+ @staticmethod
+ def auto_detect_available_folds(model_training_output_dir, checkpoint_name):
+ print('use_folds is None, attempting to auto detect available folds')
+ fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False)
+ fold_folders = [i for i in fold_folders if i != 'fold_all']
+ fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))]
+ use_folds = [int(i.split('_')[-1]) for i in fold_folders]
+ print(f'found the following folds: {use_folds}')
+ return use_folds
+
+ def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]],
+ output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]],
+ folder_with_segs_from_prev_stage: str = None,
+ overwrite: bool = True,
+ part_id: int = 0,
+ num_parts: int = 1,
+ save_probabilities: bool = False):
+ if isinstance(list_of_lists_or_source_folder, str):
+ list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder,
+ self.dataset_json['file_ending'])
+ print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder')
+ list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]
+ caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in
+ list_of_lists_or_source_folder]
+ print(
+ f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)')
+ print(f'There are {len(caseids)} cases that I would like to predict')
+
+ if isinstance(output_folder_or_list_of_truncated_output_files, str):
+ output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids]
+ else:
+ output_filename_truncated = output_folder_or_list_of_truncated_output_files
+
+ seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if
+ folder_with_segs_from_prev_stage is not None else None for i in caseids]
+        # remove already predicted files from the lists
+ if not overwrite and output_filename_truncated is not None:
+ tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated]
+ if save_probabilities:
+ tmp2 = [isfile(i + '.npz') for i in output_filename_truncated]
+ tmp = [i and j for i, j in zip(tmp, tmp2)]
+ not_existing_indices = [i for i, j in enumerate(tmp) if not j]
+
+ output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices]
+ list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices]
+ seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices]
+ print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. '
+ f'That\'s {len(not_existing_indices)} cases.')
+ return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files
+
+ def predict_from_files(self,
+ list_of_lists_or_source_folder: Union[str, List[List[str]]],
+ output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
+ save_probabilities: bool = False,
+ overwrite: bool = True,
+ num_processes_preprocessing: int = default_num_processes,
+ num_processes_segmentation_export: int = default_num_processes,
+ folder_with_segs_from_prev_stage: str = None,
+ num_parts: int = 1,
+ part_id: int = 0):
+ """
+ This is nnU-Net's default function for making predictions. It works best for batch predictions
+ (predicting many images at once).
+ """
+ if isinstance(output_folder_or_list_of_truncated_output_files, str):
+ output_folder = output_folder_or_list_of_truncated_output_files
+ elif isinstance(output_folder_or_list_of_truncated_output_files, list):
+ output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
+ else:
+ output_folder = None
+
+ ########################
+        # let's store the input arguments so that it's clear what was used to generate the prediction
+ if output_folder is not None:
+ my_init_kwargs = {}
+ for k in inspect.signature(self.predict_from_files).parameters.keys():
+ my_init_kwargs[k] = locals()[k]
+            my_init_kwargs = deepcopy(
+                my_init_kwargs)  # let's not unintentionally change anything in-place
+ recursive_fix_for_json_export(my_init_kwargs)
+ maybe_mkdir_p(output_folder)
+ save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))
+
+ # we need these two if we want to do things with the predictions like for example apply postprocessing
+ save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
+ save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
+ #######################
+
+ # check if we need a prediction from the previous stage
+ if self.configuration_manager.previous_stage_name is not None:
+ assert folder_with_segs_from_prev_stage is not None, \
+ f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
+ f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
+ f' they are located via folder_with_segs_from_prev_stage'
+
+ # sort out input and output filenames
+ list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
+ self._manage_input_and_output_lists(list_of_lists_or_source_folder,
+ output_folder_or_list_of_truncated_output_files,
+ folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
+ save_probabilities)
+ if len(list_of_lists_or_source_folder) == 0:
+ return
+
+ data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
+ seg_from_prev_stage_files,
+ output_filename_truncated,
+ num_processes_preprocessing)
+
+ return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
+
+ def _internal_get_data_iterator_from_lists_of_filenames(self,
+ input_list_of_lists: List[List[str]],
+ seg_from_prev_stage_files: Union[List[str], None],
+ output_filenames_truncated: Union[List[str], None],
+ num_processes: int):
+ return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,
+ output_filenames_truncated, self.plans_manager, self.dataset_json,
+ self.configuration_manager, num_processes, self.device.type == 'cuda',
+ self.verbose_preprocessing)
+ # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
+ # # hijack batchgenerators, yo
+ # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This
+ # # way we don't have to reinvent the wheel here.
+ # num_processes = max(1, min(num_processes, len(input_list_of_lists)))
+ # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor,
+ # output_filenames_truncated, self.plans_manager, self.dataset_json,
+ # self.configuration_manager, num_processes)
+ # if num_processes == 0:
+ # mta = SingleThreadedAugmenter(ppa, None)
+ # else:
+ # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory)
+ # return mta
+
+ def get_data_iterator_from_raw_npy_data(self,
+ image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
+ np.ndarray,
+ List[
+ np.ndarray]],
+ properties_or_list_of_properties: Union[dict, List[dict]],
+ truncated_ofname: Union[str, List[str], None],
+ num_processes: int = 3):
+
+ list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \
+ image_or_list_of_images
+
+ if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray):
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage = [
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage]
+
+ if isinstance(truncated_ofname, str):
+ truncated_ofname = [truncated_ofname]
+
+ if isinstance(properties_or_list_of_properties, dict):
+ properties_or_list_of_properties = [properties_or_list_of_properties]
+
+ num_processes = min(num_processes, len(list_of_images))
+ pp = preprocessing_iterator_fromnpy(
+ list_of_images,
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage,
+ properties_or_list_of_properties,
+ truncated_ofname,
+ self.plans_manager,
+ self.dataset_json,
+ self.configuration_manager,
+ num_processes,
+ self.device.type == 'cuda',
+ self.verbose_preprocessing
+ )
+
+ return pp
+
+ def predict_from_list_of_npy_arrays(self,
+ image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
+ np.ndarray,
+ List[
+ np.ndarray]],
+ properties_or_list_of_properties: Union[dict, List[dict]],
+ truncated_ofname: Union[str, List[str], None],
+ num_processes: int = 3,
+ save_probabilities: bool = False,
+ num_processes_segmentation_export: int = default_num_processes):
+ iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images,
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage,
+ properties_or_list_of_properties,
+ truncated_ofname,
+ num_processes)
+ return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export)
+
+ def predict_from_data_iterator(self,
+ data_iterator,
+ save_probabilities: bool = False,
+ num_processes_segmentation_export: int = default_num_processes):
+ """
+ each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys!
+ If 'ofile' is None, the result will be returned instead of written to a file
+ """
+ with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
+ worker_list = [i for i in export_pool._pool]
+ r = []
+ for preprocessed in data_iterator:
+ data = preprocessed['data']
+ if isinstance(data, str):
+ delfile = data
+ data = torch.from_numpy(np.load(data))
+ os.remove(delfile)
+
+ ofile = preprocessed['ofile']
+ if ofile is not None:
+ print(f'\nPredicting {os.path.basename(ofile)}:')
+ else:
+ print(f'\nPredicting image of shape {data.shape}:')
+
+ print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}')
+
+ properties = preprocessed['data_properites']
+
+                # let's not get into a runaway situation where the GPU predicts so fast that the disk gets
+                # swamped with npy files
+ proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+ while not proceed:
+ # print('sleeping')
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+
+ prediction = self.predict_logits_from_preprocessed_data(data).cpu()
+
+ if ofile is not None:
+ # this needs to go into background processes
+ # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager,
+ # dataset_json, ofile, save_probabilities)
+ print('sending off prediction to background worker for resampling and export')
+ r.append(
+ export_pool.starmap_async(
+ export_prediction_from_logits,
+ ((prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, ofile, save_probabilities),)
+ )
+ )
+ else:
+ # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager,
+ # configuration_manager, label_manager,
+ # properties,
+ # save_probabilities)
+ print('sending off prediction to background worker for resampling')
+ r.append(
+ export_pool.starmap_async(
+ convert_predicted_logits_to_segmentation_with_correct_shape, (
+ (prediction, self.plans_manager,
+ self.configuration_manager, self.label_manager,
+ properties,
+ save_probabilities),)
+ )
+ )
+ if ofile is not None:
+ print(f'done with {os.path.basename(ofile)}')
+ else:
+ print(f'\nDone with image of shape {data.shape}:')
+ ret = [i.get()[0] for i in r]
+
+ if isinstance(data_iterator, MultiThreadedAugmenter):
+ data_iterator._finish()
+
+ # clear lru cache
+ compute_gaussian.cache_clear()
+ # clear device cache
+ empty_cache(self.device)
+ return ret
+
+ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
+ segmentation_previous_stage: np.ndarray = None,
+ output_file_truncated: str = None,
+ save_or_return_probabilities: bool = False):
+ """
+ image_properties must only have a 'spacing' key!
+ """
+ ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
+ [output_file_truncated],
+ self.plans_manager, self.dataset_json, self.configuration_manager,
+ num_threads_in_multithreaded=1, verbose=self.verbose)
+ if self.verbose:
+ print('preprocessing')
+ dct = next(ppa)
+
+ if self.verbose:
+ print('predicting')
+ predicted_logits = self.predict_logits_from_preprocessed_data(dct['data'])
+ print(predicted_logits.dtype)
+ if self.verbose:
+ print('resampling to original shape')
+ if output_file_truncated is not None:
+ print("Starting export", flush = True)
+ export_prediction_from_logits(predicted_logits, dct['data_properites'], self.configuration_manager,
+ self.plans_manager, self.dataset_json, output_file_truncated,
+ save_or_return_probabilities)
+ else:
+ del input_image
+ ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
+ self.configuration_manager,
+ self.label_manager,
+ dct['data_properites'],
+ return_probabilities=
+ save_or_return_probabilities)
+ if save_or_return_probabilities:
+ return ret[0], ret[1]
+ else:
+ return ret
+
+ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
+ """
+ IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
+ TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
+
+ RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
+ SEE convert_predicted_logits_to_segmentation_with_correct_shape
+ """
+ # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as
+ # default and not have the entire program crash in case of GPU out of memory. Neat. That should make
+ # things a lot faster for some datasets.
+ original_perform_everything_on_gpu = self.perform_everything_on_gpu
+ with torch.no_grad():
+ prediction = None
+ if self.perform_everything_on_gpu:
+ try:
+ for params in self.list_of_parameters:
+
+ # messing with state dict names...
+ if not isinstance(self.network, OptimizedModule):
+ self.network.load_state_dict(torch.load(params,
+ map_location=torch.device('cpu'))["network_weights"])
+ else:
+ self.network._orig_mod.load_state_dict(torch.load(params,
+ map_location=torch.device('cpu'))["network_weights"])
+
+ if prediction is None:
+ #prediction = torch.argmax(self.predict_sliding_window_return_logits(data), 0,
+ # keepdim=True).type(torch.uint8)
+ prediction = self.predict_sliding_window_return_logits(data)
+ else:
+ #prediction = torch.cat([prediction,
+ # torch.argmax(self.predict_sliding_window_return_logits(data), 0,
+ # keepdim=True).type(torch.uint8)])
+ prediction += self.predict_sliding_window_return_logits(data)
+ if len(self.list_of_parameters) > 1:
+ # prediction /= len(self.list_of_parameters)
+ prediction = torch.mode(prediction, dim=0, keepdim=True).values
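+                    # Note (assumption about intent): this takes the mode over dim 0, i.e. the class
+                    # dimension of the summed logits, which is not a majority vote across folds. The
+                    # commented-out torch.cat/argmax variant above would be needed for a real majority
+                    # vote. With a single fold this branch is never reached.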
+
+ except RuntimeError:
+ print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. '
+ 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...')
+ print('Error:')
+ traceback.print_exc()
+ prediction = None
+ self.perform_everything_on_gpu = False
+
+ if prediction is None:
+ for params in self.list_of_parameters:
+ # messing with state dict names...
+ if not isinstance(self.network, OptimizedModule):
+ self.network.load_state_dict(torch.load(params,
+ map_location=torch.device('cpu'))["network_weights"],
+ strict=False) # bc of old experiments..
+ else:
+ self.network._orig_mod.load_state_dict(torch.load(params,
+ map_location=torch.device('cpu'))["network_weights"],
+ strict=False) # bc of old experiments..
+
+ if prediction is None:
+ prediction = self.predict_sliding_window_return_logits(data)
+ # prediction = torch.argmax(self.predict_sliding_window_return_logits(data),0,keepdim=True).type(torch.uint8)
+ else:
+ prediction += self.predict_sliding_window_return_logits(data)
+ #prediction = torch.cat([prediction,
+ # torch.argmax(self.predict_sliding_window_return_logits(data), 0,
+ # keepdim=True).type(torch.uint8)])
+ if len(self.list_of_parameters) > 1:
+ prediction /= len(self.list_of_parameters)
+ # prediction = torch.mode(prediction, dim=0, keepdim=True).values
+ self.perform_everything_on_gpu = original_perform_everything_on_gpu
+ return prediction
+
+ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
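+        # Worked example (illustrative): for a 3d configuration with image_size (128, 128, 128),
+        # patch_size (64, 64, 64) and tile_step_size 0.5, compute_steps_for_sliding_window yields
+        # steps [0, 32, 64] per axis, i.e. 3 * 3 * 3 = 27 overlapping tiles in total.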
+ slicers = []
+ if len(self.configuration_manager.patch_size) < len(image_size):
+ assert len(self.configuration_manager.patch_size) == len(
+ image_size) - 1, 'if tile_size has less entries than image_size, ' \
+ 'len(tile_size) ' \
+ 'must be one shorter than len(image_size) ' \
+ '(only dimension ' \
+ 'discrepancy of 1 allowed).'
+ steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size,
+ self.tile_step_size)
+ if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'
+ f' {image_size}, tile_size {self.configuration_manager.patch_size}, '
+ f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
+ for d in range(image_size[0]):
+ for sx in steps[0]:
+ for sy in steps[1]:
+ slicers.append(
+ tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
+ zip((sx, sy), self.configuration_manager.patch_size)]]))
+ else:
+ steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,
+ self.tile_step_size)
+ if self.verbose: print(
+ f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, '
+ f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
+ for sx in steps[0]:
+ for sy in steps[1]:
+ for sz in steps[2]:
+ slicers.append(
+ tuple([slice(None), *[slice(si, si + ti) for si, ti in
+ zip((sx, sy, sz), self.configuration_manager.patch_size)]]))
+ return slicers
+
+ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
+ mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
+ prediction = self.network(x)
+
+ if mirror_axes is not None:
+ # check for invalid numbers in mirror_axes
+ # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
+ assert max(mirror_axes) <= len(x.shape) - 3, 'mirror_axes does not match the dimension of the input!'
+
+            # The combined (multi-axis) flips were removed here because they ended up blurring the
+            # left/right dichotomy
+            # num_predictions = 2 ** len(mirror_axes)
+            num_predictions = len(mirror_axes)
+ if 0 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,))
+ if 1 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,))
+ if 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,))
+ # if 0 in mirror_axes and 1 in mirror_axes:
+ # prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3))
+ # if 0 in mirror_axes and 2 in mirror_axes:
+ # prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4))
+ # if 1 in mirror_axes and 2 in mirror_axes:
+ # prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4))
+ # if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4))
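+            # Note: the sum above contains len(mirror_axes) + 2 forward passes (identity, one per mirrored
+            # axis, plus the unconditional three-axis flip), so dividing by num_predictions does not give a
+            # true mean. This is harmless for the later argmax (uniform scaling) but skews pseudo-probabilities.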
+            prediction /= num_predictions
+ return prediction
+
+ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
+ -> Union[np.ndarray, torch.Tensor]:
+ assert isinstance(input_image, torch.Tensor)
+ self.network = self.network.to(self.device)
+ self.network.eval()
+
+ empty_cache(self.device)
+
+        # Autocast has some sharp edges:
+        # If the device_type is 'cpu' it is very slow on some CPUs (no automatic bfloat16 support detection)
+        # and needs to be disabled.
+        # If the device_type is 'mps' it complains that mps is not implemented, even if enabled=False
+        # is set (which is why we don't rely on enabled=False).
+        # So autocast is only active if we have a cuda device.
+ with torch.no_grad():
+ with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ assert len(input_image.shape) == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
+
+ if self.verbose: print(f'Input shape: {input_image.shape}')
+ if self.verbose: print("step_size:", self.tile_step_size)
+ if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
+
+ # if input_image is smaller than tile_size we need to pad it to tile_size.
+ data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
+ 'constant', {'value': 0}, True,
+ None)
+
+ slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+ # preallocate results and num_predictions
+ results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu')
+ if self.verbose: print('preallocating arrays')
+ try:
+ data = data.to(self.device)
+ predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+ dtype=torch.half,
+ device=results_device)
+ n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
+ device=results_device)
+ if self.use_gaussian:
+ gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+ value_scaling_factor=1000,
+ device=results_device)
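+                    # The Gaussian weights voxels near the tile centre more strongly than voxels near the
+                    # tile border when overlapping predictions are accumulated and later divided by
+                    # n_predictions.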
+ except RuntimeError:
+ # sometimes the stuff is too large for GPUs. In that case fall back to CPU
+ results_device = torch.device('cpu')
+ data = data.to(results_device)
+ predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+ dtype=torch.half,
+ device=results_device)
+ n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
+ device=results_device)
+ if self.use_gaussian:
+ gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+ value_scaling_factor=1000,
+ device=results_device)
+ finally:
+ empty_cache(self.device)
+
+ if self.verbose: print('running prediction')
+ for sl in tqdm(slicers, disable=not self.allow_tqdm):
+ workon = data[sl][None]
+ workon = workon.to(self.device, non_blocking=False)
+
+ prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
+
+ predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
+ n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+
+ predicted_logits /= n_predictions
+ empty_cache(self.device)
+ return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+
+
+def predict_entry_point_modelfolder():
+ import argparse
+ parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
+ 'you want to manually specify a folder containing a trained nnU-Net '
+ 'model. This is useful when the nnunet environment variables '
+ '(nnUNet_results) are not set.')
+ parser.add_argument('-i', type=str, required=True,
+ help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
+ 'File endings must be the same as the training dataset!')
+ parser.add_argument('-o', type=str, required=True,
+ help='Output folder. If it does not exist it will be created. Predicted segmentations will '
+ 'have the same name as their source images.')
+ parser.add_argument('-m', type=str, required=True,
+ help='Folder in which the trained model is. Must have subfolders fold_X for the different '
+ 'folds you trained')
+ parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
+ help='Specify the folds of the trained model that should be used for prediction. '
+ 'Default: (0, 1, 2, 3, 4)')
+ parser.add_argument('-step_size', type=float, required=False, default=0.5,
+ help='Step size for sliding window prediction. The larger it is the faster but less accurate '
+ 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
+ parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
+ help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
+ 'but less accurate inference. Not recommended.')
+ parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
+ "to be a good listener/reader.")
+ parser.add_argument('--save_probabilities', action='store_true',
+ help='Set this to export predicted class "probabilities". Required if you want to ensemble '
+ 'multiple configurations.')
+ parser.add_argument('--continue_prediction', '--c', action='store_true',
+ help='Continue an aborted previous prediction (will not overwrite existing files)')
+ parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
+ help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
+ parser.add_argument('-npp', type=int, required=False, default=3,
+ help='Number of processes used for preprocessing. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-nps', type=int, required=False, default=3,
+ help='Number of processes used for segmentation export. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
+ help='Folder containing the predictions of the previous stage. Required for cascaded models.')
+ parser.add_argument('-device', type=str, default='cuda', required=False,
+ help="Use this to set the device the inference should run with. Available options are 'cuda' "
+ "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
+ "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+
+ print(
+ "\n#######################################################################\nPlease cite the following paper "
+ "when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n#######################################################################\n")
+
+ args = parser.parse_args()
+ args.f = [i if i == 'all' else int(i) for i in args.f]
+
+ if not isdir(args.o):
+ maybe_mkdir_p(args.o)
+
+ assert args.device in ['cpu', 'cuda',
+ 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+ if args.device == 'cpu':
+ # let's allow torch to use hella threads
+ import multiprocessing
+ torch.set_num_threads(multiprocessing.cpu_count())
+ device = torch.device('cpu')
+ elif args.device == 'cuda':
+ # multithreading in torch doesn't help nnU-Net if run on GPU
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ device = torch.device('cuda')
+ else:
+ device = torch.device('mps')
+
+ predictor = nnUNetPredictor(tile_step_size=args.step_size,
+ use_gaussian=True,
+ use_mirroring=not args.disable_tta,
+ perform_everything_on_gpu=True,
+ device=device,
+ verbose=args.verbose)
+ predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)
+ predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
+ overwrite=not args.continue_prediction,
+ num_processes_preprocessing=args.npp,
+ num_processes_segmentation_export=args.nps,
+ folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ num_parts=1, part_id=0)
+
+
+def predict_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
+ 'you want to manually specify a folder containing a trained nnU-Net '
+ 'model. This is useful when the nnunet environment variables '
+ '(nnUNet_results) are not set.')
+ parser.add_argument('-i', type=str, required=True,
+ help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
+ 'File endings must be the same as the training dataset!')
+ parser.add_argument('-o', type=str, required=True,
+ help='Output folder. If it does not exist it will be created. Predicted segmentations will '
+ 'have the same name as their source images.')
+ parser.add_argument('-d', type=str, required=True,
+ help='Dataset with which you would like to predict. You can specify either dataset name or id')
+ parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
+ help='Plans identifier. Specify the plans in which the desired configuration is located. '
+ 'Default: nnUNetPlans')
+ parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
+ help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
+ parser.add_argument('-c', type=str, required=True,
+ help='nnU-Net configuration that should be used for prediction. Config must be located '
+ 'in the plans specified with -p')
+ parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
+ help='Specify the folds of the trained model that should be used for prediction. '
+ 'Default: (0, 1, 2, 3, 4)')
+ parser.add_argument('-step_size', type=float, required=False, default=0.5,
+ help='Step size for sliding window prediction. The larger it is the faster but less accurate '
+ 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
+ parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
+ help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
+ 'but less accurate inference. Not recommended.')
+ parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
+ "to be a good listener/reader.")
+ parser.add_argument('--save_probabilities', action='store_true',
+ help='Set this to export predicted class "probabilities". Required if you want to ensemble '
+ 'multiple configurations.')
+ parser.add_argument('--continue_prediction', action='store_true',
+ help='Continue an aborted previous prediction (will not overwrite existing files)')
+ parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
+ help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
+ parser.add_argument('-npp', type=int, required=False, default=3,
+ help='Number of processes used for preprocessing. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-nps', type=int, required=False, default=3,
+ help='Number of processes used for segmentation export. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
+ help='Folder containing the predictions of the previous stage. Required for cascaded models.')
+ parser.add_argument('-num_parts', type=int, required=False, default=1,
+                        help='Number of separate nnUNetv2_predict calls that you will be making. Default: 1 (= this one '
+ 'call predicts everything)')
+ parser.add_argument('-part_id', type=int, required=False, default=0,
+                        help='If multiple nnUNetv2_predict calls exist, which one is this? IDs start with 0 and end with '
+                             'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
+                             '5 and use -part_id 0, 1, 2, 3 and 4. Note: you are responsible '
+ 'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
+ parser.add_argument('-device', type=str, default='cuda', required=False,
+ help="Use this to set the device the inference should run with. Available options are 'cuda' "
+ "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
+ "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+
+ print(
+ "\n#######################################################################\nPlease cite the following paper "
+ "when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n#######################################################################\n")
+
+ args = parser.parse_args()
+ args.f = [i if i == 'all' else int(i) for i in args.f]
+
+ model_folder = get_output_folder(args.d, args.tr, args.p, args.c)
+
+ if not isdir(args.o):
+ maybe_mkdir_p(args.o)
+
+    # slightly passive-aggressive, haha
+ assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.'
+
+ assert args.device in ['cpu', 'cuda',
+ 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+ if args.device == 'cpu':
+ # let's allow torch to use hella threads
+ import multiprocessing
+ torch.set_num_threads(multiprocessing.cpu_count())
+ device = torch.device('cpu')
+ elif args.device == 'cuda':
+ # multithreading in torch doesn't help nnU-Net if run on GPU
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ device = torch.device('cuda')
+ else:
+ device = torch.device('mps')
+
+ predictor = nnUNetPredictor(tile_step_size=args.step_size,
+ use_gaussian=True,
+ use_mirroring=not args.disable_tta,
+ perform_everything_on_gpu=True,
+ device=device,
+ verbose=args.verbose,
+ verbose_preprocessing=False)
+ predictor.initialize_from_trained_model_folder(
+ model_folder,
+ args.f,
+ checkpoint_name=args.chk
+ )
+ predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
+ overwrite=not args.continue_prediction,
+ num_processes_preprocessing=args.npp,
+ num_processes_segmentation_export=args.nps,
+ folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ num_parts=args.num_parts,
+ part_id=args.part_id)
+ # r = predict_from_raw_data(args.i,
+ # args.o,
+ # model_folder,
+ # args.f,
+ # args.step_size,
+ # use_gaussian=True,
+ # use_mirroring=not args.disable_tta,
+ # perform_everything_on_gpu=True,
+ # verbose=args.verbose,
+ # save_probabilities=args.save_probabilities,
+ # overwrite=not args.continue_prediction,
+ # checkpoint_name=args.chk,
+ # num_processes_preprocessing=args.npp,
+ # num_processes_segmentation_export=args.nps,
+ # folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ # num_parts=args.num_parts,
+ # part_id=args.part_id,
+ # device=device)
+
+
+if __name__ == '__main__':
+ # predict a bunch of files
+ from nnunetv2.paths import nnUNet_results, nnUNet_raw
+ predictor = nnUNetPredictor(
+ tile_step_size=0.5,
+ use_gaussian=True,
+ use_mirroring=True,
+ perform_everything_on_gpu=True,
+ device=torch.device('cuda', 0),
+ verbose=False,
+ verbose_preprocessing=False,
+ allow_tqdm=True
+ )
+ predictor.initialize_from_trained_model_folder(
+ join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
+ use_folds=(0, ),
+ checkpoint_name='checkpoint_final.pth',
+ )
+ predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
+ save_probabilities=False, overwrite=False,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+
+ # predict a numpy array
+ from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
+ ret = predictor.predict_single_npy_array(img, props, None, None, False)
+
+ iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)
+ ret = predictor.predict_from_data_iterator(iterator, False, 1)
+
+
+ # predictor = nnUNetPredictor(
+ # tile_step_size=0.5,
+ # use_gaussian=True,
+ # use_mirroring=True,
+ # perform_everything_on_gpu=True,
+ # device=torch.device('cuda', 0),
+ # verbose=False,
+ # allow_tqdm=True
+ # )
+ # predictor.initialize_from_trained_model_folder(
+ # join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'),
+ # use_folds=(0,),
+ # checkpoint_name='checkpoint_final.pth',
+ # )
+ # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ # join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'),
+ # save_probabilities=False, overwrite=False,
+ # num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres',
+ # num_parts=1, part_id=0)
+
diff --git a/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save b/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save
new file mode 100644
index 0000000..a2329a2
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save
@@ -0,0 +1,924 @@
+import inspect
+import multiprocessing
+import os
+import traceback
+from copy import deepcopy
+from time import sleep
+from typing import Tuple, Union, List, Optional
+
+import numpy as np
+import torch
+from acvl_utils.cropping_and_padding.padding import pad_nd_image
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \
+ save_json
+from torch import nn
+from torch._dynamo import OptimizedModule
+from torch.nn.parallel import DistributedDataParallel
+from tqdm import tqdm
+
+import nnunetv2
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \
+ preprocessing_iterator_fromnpy
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, \
+ convert_predicted_logits_to_segmentation_with_correct_shape
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \
+ compute_steps_for_sliding_window
+from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.json_export import recursive_fix_for_json_export
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder
+
+
+class nnUNetPredictor(object):
+ def __init__(self,
+ tile_step_size: float = 0.5,
+ use_gaussian: bool = True,
+ use_mirroring: bool = True,
+ perform_everything_on_gpu: bool = True,
+ device: torch.device = torch.device('cuda'),
+ verbose: bool = False,
+ verbose_preprocessing: bool = False,
+ allow_tqdm: bool = True):
+ self.verbose = verbose
+ self.verbose_preprocessing = verbose_preprocessing
+ self.allow_tqdm = allow_tqdm
+
+ self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \
+ self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None
+
+ self.tile_step_size = tile_step_size
+ self.use_gaussian = use_gaussian
+ self.use_mirroring = use_mirroring
+ if device.type == 'cuda':
+ # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES!
+ # why would I ever want to do that. Stupid dobby. This kills DDP inference...
+ pass
+ if device.type != 'cuda':
+ print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False')
+ perform_everything_on_gpu = False
+ self.device = device
+ self.perform_everything_on_gpu = perform_everything_on_gpu
+
+ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
+ use_folds: Union[Tuple[Union[int, str]], None],
+ checkpoint_name: str = 'checkpoint_final.pth'):
+ """
+ This is used when making predictions with a trained model
+ """
+ if use_folds is None:
+ use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name)
+
+ dataset_json = load_json(join(model_training_output_dir, 'dataset.json'))
+ plans = load_json(join(model_training_output_dir, 'plans.json'))
+ plans_manager = PlansManager(plans)
+
+ if isinstance(use_folds, str):
+ use_folds = [use_folds]
+
+ parameters = []
+ for i, f in enumerate(use_folds):
+ f = int(f) if f != 'all' else f
+ checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name),
+ map_location=torch.device('cpu'))
+ if i == 0:
+ trainer_name = checkpoint['trainer_name']
+ configuration_name = checkpoint['init_args']['configuration']
+ inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
+ 'inference_allowed_mirroring_axes' in checkpoint.keys() else None
+
+ parameters.append(checkpoint['network_weights'])
+
+ configuration_manager = plans_manager.get_configuration(configuration_name)
+ # restore network
+ num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+ trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+ trainer_name, 'nnunetv2.training.nnUNetTrainer')
+ network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, enable_deep_supervision=False)
+ self.plans_manager = plans_manager
+ self.configuration_manager = configuration_manager
+ self.list_of_parameters = parameters
+ self.network = network
+ self.dataset_json = dataset_json
+ self.trainer_name = trainer_name
+ self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+ if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
+ and not isinstance(self.network, OptimizedModule):
+ print('compiling network')
+ self.network = torch.compile(self.network)
+
+ def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
+ configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
+ dataset_json: dict, trainer_name: str,
+ inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]):
+ """
+ This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation
+ """
+ self.plans_manager = plans_manager
+ self.configuration_manager = configuration_manager
+ self.list_of_parameters = parameters
+ self.network = network
+ self.dataset_json = dataset_json
+ self.trainer_name = trainer_name
+ self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+ self.label_manager = plans_manager.get_label_manager(dataset_json)
+ allow_compile = True
+ allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't'))
+ allow_compile = allow_compile and not isinstance(self.network, OptimizedModule)
+ if isinstance(self.network, DistributedDataParallel):
+ allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule)
+ if allow_compile:
+ print('compiling network')
+ self.network = torch.compile(self.network)
+
+ @staticmethod
+ def auto_detect_available_folds(model_training_output_dir, checkpoint_name):
+ print('use_folds is None, attempting to auto detect available folds')
+ fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False)
+ fold_folders = [i for i in fold_folders if i != 'fold_all']
+ fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))]
+ use_folds = [int(i.split('_')[-1]) for i in fold_folders]
+ print(f'found the following folds: {use_folds}')
+ return use_folds
+
+ def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]],
+ output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]],
+ folder_with_segs_from_prev_stage: str = None,
+ overwrite: bool = True,
+ part_id: int = 0,
+ num_parts: int = 1,
+ save_probabilities: bool = False):
+ if isinstance(list_of_lists_or_source_folder, str):
+ list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder,
+ self.dataset_json['file_ending'])
+ print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder')
+ list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts]
+ caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in
+ list_of_lists_or_source_folder]
+ print(
+ f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)')
+ print(f'There are {len(caseids)} cases that I would like to predict')
+
+ if isinstance(output_folder_or_list_of_truncated_output_files, str):
+ output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids]
+ else:
+ output_filename_truncated = output_folder_or_list_of_truncated_output_files
+
+ seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if
+ folder_with_segs_from_prev_stage is not None else None for i in caseids]
+        # remove already predicted files from the lists
+ if not overwrite and output_filename_truncated is not None:
+ tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated]
+ if save_probabilities:
+ tmp2 = [isfile(i + '.npz') for i in output_filename_truncated]
+ tmp = [i and j for i, j in zip(tmp, tmp2)]
+ not_existing_indices = [i for i, j in enumerate(tmp) if not j]
+
+ output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices]
+ list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices]
+ seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices]
+ print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. '
+ f'That\'s {len(not_existing_indices)} cases.')
+ return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files
+
+ def predict_from_files(self,
+ list_of_lists_or_source_folder: Union[str, List[List[str]]],
+ output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
+ save_probabilities: bool = False,
+ overwrite: bool = True,
+ num_processes_preprocessing: int = default_num_processes,
+ num_processes_segmentation_export: int = default_num_processes,
+ folder_with_segs_from_prev_stage: str = None,
+ num_parts: int = 1,
+ part_id: int = 0):
+ """
+ This is nnU-Net's default function for making predictions. It works best for batch predictions
+ (predicting many images at once).
+ """
+ if isinstance(output_folder_or_list_of_truncated_output_files, str):
+ output_folder = output_folder_or_list_of_truncated_output_files
+ elif isinstance(output_folder_or_list_of_truncated_output_files, list):
+ output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
+ else:
+ output_folder = None
+
+ ########################
+        # let's store the input arguments so that it's clear what was used to generate the prediction
+ if output_folder is not None:
+ my_init_kwargs = {}
+ for k in inspect.signature(self.predict_from_files).parameters.keys():
+ my_init_kwargs[k] = locals()[k]
+ my_init_kwargs = deepcopy(
+                my_init_kwargs)  # let's not unintentionally change anything in-place. Take this as a safety precaution
+ recursive_fix_for_json_export(my_init_kwargs)
+ maybe_mkdir_p(output_folder)
+ save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))
+
+ # we need these two if we want to do things with the predictions like for example apply postprocessing
+ save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
+ save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
+ #######################
+
+ # check if we need a prediction from the previous stage
+ if self.configuration_manager.previous_stage_name is not None:
+ assert folder_with_segs_from_prev_stage is not None, \
+ f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
+ f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
+ f' they are located via folder_with_segs_from_prev_stage'
+
+ # sort out input and output filenames
+ list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
+ self._manage_input_and_output_lists(list_of_lists_or_source_folder,
+ output_folder_or_list_of_truncated_output_files,
+ folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
+ save_probabilities)
+ if len(list_of_lists_or_source_folder) == 0:
+ return
+
+ data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
+ seg_from_prev_stage_files,
+ output_filename_truncated,
+ num_processes_preprocessing)
+
+ return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export)
+
+ def _internal_get_data_iterator_from_lists_of_filenames(self,
+ input_list_of_lists: List[List[str]],
+ seg_from_prev_stage_files: Union[List[str], None],
+ output_filenames_truncated: Union[List[str], None],
+ num_processes: int):
+ return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files,
+ output_filenames_truncated, self.plans_manager, self.dataset_json,
+ self.configuration_manager, num_processes, self.device.type == 'cuda',
+ self.verbose_preprocessing)
+ # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing)
+ # # hijack batchgenerators, yo
+ # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This
+ # # way we don't have to reinvent the wheel here.
+ # num_processes = max(1, min(num_processes, len(input_list_of_lists)))
+ # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor,
+ # output_filenames_truncated, self.plans_manager, self.dataset_json,
+ # self.configuration_manager, num_processes)
+ # if num_processes == 0:
+ # mta = SingleThreadedAugmenter(ppa, None)
+ # else:
+ # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory)
+ # return mta
+
+ def get_data_iterator_from_raw_npy_data(self,
+ image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
+ np.ndarray,
+ List[
+ np.ndarray]],
+ properties_or_list_of_properties: Union[dict, List[dict]],
+ truncated_ofname: Union[str, List[str], None],
+ num_processes: int = 3):
+
+ list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \
+ image_or_list_of_images
+
+ if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray):
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage = [
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage]
+
+ if isinstance(truncated_ofname, str):
+ truncated_ofname = [truncated_ofname]
+
+ if isinstance(properties_or_list_of_properties, dict):
+ properties_or_list_of_properties = [properties_or_list_of_properties]
+
+ num_processes = min(num_processes, len(list_of_images))
+ pp = preprocessing_iterator_fromnpy(
+ list_of_images,
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage,
+ properties_or_list_of_properties,
+ truncated_ofname,
+ self.plans_manager,
+ self.dataset_json,
+ self.configuration_manager,
+ num_processes,
+ self.device.type == 'cuda',
+ self.verbose_preprocessing
+ )
+
+ return pp
+
+ def predict_from_list_of_npy_arrays(self,
+ image_or_list_of_images: Union[np.ndarray, List[np.ndarray]],
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None,
+ np.ndarray,
+ List[
+ np.ndarray]],
+ properties_or_list_of_properties: Union[dict, List[dict]],
+ truncated_ofname: Union[str, List[str], None],
+ num_processes: int = 3,
+ save_probabilities: bool = False,
+ num_processes_segmentation_export: int = default_num_processes):
+ iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images,
+ segs_from_prev_stage_or_list_of_segs_from_prev_stage,
+ properties_or_list_of_properties,
+ truncated_ofname,
+ num_processes)
+ return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export)
+
+ def predict_from_data_iterator(self,
+ data_iterator,
+ save_probabilities: bool = False,
+ num_processes_segmentation_export: int = default_num_processes):
+ """
+ each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys!
+ If 'ofile' is None, the result will be returned instead of written to a file
+ """
+ with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
+ worker_list = [i for i in export_pool._pool]
+ r = []
+ for preprocessed in data_iterator:
+ data = preprocessed['data']
+ if isinstance(data, str):
+ delfile = data
+ data = torch.from_numpy(np.load(data))
+ os.remove(delfile)
+
+ ofile = preprocessed['ofile']
+ if ofile is not None:
+ print(f'\nPredicting {os.path.basename(ofile)}:')
+ else:
+ print(f'\nPredicting image of shape {data.shape}:')
+
+ print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}')
+
+ properties = preprocessed['data_properites']
+
+                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with
+ # npy files
+ proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+ while not proceed:
+ # print('sleeping')
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+
+ prediction = self.predict_logits_from_preprocessed_data(data).cpu()
+
+ if ofile is not None:
+ # this needs to go into background processes
+ # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager,
+ # dataset_json, ofile, save_probabilities)
+ print('sending off prediction to background worker for resampling and export')
+ r.append(
+ export_pool.starmap_async(
+ export_prediction_from_logits,
+ ((prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, ofile, save_probabilities),)
+ )
+ )
+ else:
+ # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager,
+ # configuration_manager, label_manager,
+ # properties,
+ # save_probabilities)
+ print('sending off prediction to background worker for resampling')
+ r.append(
+ export_pool.starmap_async(
+ convert_predicted_logits_to_segmentation_with_correct_shape, (
+ (prediction, self.plans_manager,
+ self.configuration_manager, self.label_manager,
+ properties,
+ save_probabilities),)
+ )
+ )
+ if ofile is not None:
+ print(f'done with {os.path.basename(ofile)}')
+ else:
+ print(f'\nDone with image of shape {data.shape}:')
+ ret = [i.get()[0] for i in r]
+
+ if isinstance(data_iterator, MultiThreadedAugmenter):
+ data_iterator._finish()
+
+ # clear lru cache
+ compute_gaussian.cache_clear()
+ # clear device cache
+ empty_cache(self.device)
+ return ret
+
+ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
+ segmentation_previous_stage: np.ndarray = None,
+ output_file_truncated: str = None,
+ save_or_return_probabilities: bool = False):
+ """
+ image_properties must only have a 'spacing' key!
+ """
+ ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties],
+ [output_file_truncated],
+ self.plans_manager, self.dataset_json, self.configuration_manager,
+ num_threads_in_multithreaded=1, verbose=self.verbose)
+ if self.verbose:
+ print('preprocessing')
+ dct = next(ppa)
+
+ if self.verbose:
+ print('predicting')
+ predicted_logits = self.predict_logits_from_preprocessed_data(dct['data'])
+ print(predicted_logits.dtype)
+ #print("resampling", flush = True)
+ if self.verbose:
+ print('resampling to original shape')
+ if output_file_truncated is not None:
+ print("Starting export", flush = True)
+ export_prediction_from_logits(predicted_logits, dct['data_properites'], self.configuration_manager,
+ self.plans_manager, self.dataset_json, output_file_truncated,
+ save_or_return_probabilities)
+ else:
+ del input_image
+ ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
+ self.configuration_manager,
+ self.label_manager,
+ dct['data_properites'],
+ return_probabilities=
+ save_or_return_probabilities)
+ #print("Done converting", flush=True)
+ if save_or_return_probabilities:
+ return ret[0], ret[1]
+ else:
+ return ret
+
+ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
+ """
+ IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
+ TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
+
+ RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
+ SEE convert_predicted_logits_to_segmentation_with_correct_shape
+ """
+ # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as
+ # default and not have the entire program crash in case of GPU out of memory. Neat. That should make
+ # things a lot faster for some datasets.
+ original_perform_everything_on_gpu = self.perform_everything_on_gpu
+ with torch.no_grad():
+ prediction = None
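+            # ensemble over all loaded checkpoints (one set of weights per requested fold): logits are summed and averaged below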
+ if self.perform_everything_on_gpu:
+ try:
+ for params in self.list_of_parameters:
+
+ # messing with state dict names...
+ if not isinstance(self.network, OptimizedModule):
+ self.network.load_state_dict(params)
+ else:
+ self.network._orig_mod.load_state_dict(params)
+
+ if prediction is None:
+ prediction = self.predict_sliding_window_return_logits(data)
+ else:
+ prediction += self.predict_sliding_window_return_logits(data)
+ # print("done sli win")
+ if len(self.list_of_parameters) > 1:
+ # print("dividing hold on")
+ prediction /= len(self.list_of_parameters)
+
+ except RuntimeError:
+ print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. '
+ 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...')
+ print('Error:')
+ traceback.print_exc()
+ prediction = None
+ self.perform_everything_on_gpu = False
+
+ if prediction is None:
+ for params in self.list_of_parameters:
+ # messing with state dict names...
+ if not isinstance(self.network, OptimizedModule):
+ self.network.load_state_dict(params)
+ else:
+ self.network._orig_mod.load_state_dict(params)
+
+ if prediction is None:
+ prediction = self.predict_sliding_window_return_logits(data)
+ else:
+ prediction += self.predict_sliding_window_return_logits(data)
+ if len(self.list_of_parameters) > 1:
+ prediction /= len(self.list_of_parameters)
+
+ # prediction = prediction.to('cpu')
+ self.perform_everything_on_gpu = original_perform_everything_on_gpu
+ return prediction
+
+ def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]):
+ slicers = []
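+        # two cases: the patch size has one dimension fewer than the image (2d patches applied slice by slice along
+        # the first axis) or the patch size matches the image dimensionality (regular 3d tiling)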
+ if len(self.configuration_manager.patch_size) < len(image_size):
+ assert len(self.configuration_manager.patch_size) == len(
+ image_size) - 1, 'if tile_size has less entries than image_size, ' \
+ 'len(tile_size) ' \
+ 'must be one shorter than len(image_size) ' \
+ '(only dimension ' \
+ 'discrepancy of 1 allowed).'
+ steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size,
+ self.tile_step_size)
+ if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is'
+ f' {image_size}, tile_size {self.configuration_manager.patch_size}, '
+ f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
+ for d in range(image_size[0]):
+ for sx in steps[0]:
+ for sy in steps[1]:
+ slicers.append(
+ tuple([slice(None), d, *[slice(si, si + ti) for si, ti in
+ zip((sx, sy), self.configuration_manager.patch_size)]]))
+ else:
+ steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size,
+ self.tile_step_size)
+ if self.verbose: print(
+ f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, '
+ f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}')
+ for sx in steps[0]:
+ for sy in steps[1]:
+ for sz in steps[2]:
+ slicers.append(
+ tuple([slice(None), *[slice(si, si + ti) for si, ti in
+ zip((sx, sy, sz), self.configuration_manager.patch_size)]]))
+ return slicers
+
+ def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor:
+ mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
+ prediction = self.network(x)
+
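+        # test time augmentation via mirroring: predictions for all requested flip combinations are added up and
+        # averaged at the end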
+ if mirror_axes is not None:
+ # check for invalid numbers in mirror_axes
+ # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
+ assert max(mirror_axes) <= len(x.shape) - 3, 'mirror_axes does not match the dimension of the input!'
+
+ num_predictons = 2 ** len(mirror_axes)
+ if 0 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,))
+ if 1 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,))
+ if 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,))
+ if 0 in mirror_axes and 1 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3))
+ if 0 in mirror_axes and 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4))
+ if 1 in mirror_axes and 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4))
+ if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes:
+ prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4))
+ prediction /= num_predictons
+ return prediction
+
+ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
+ -> Union[np.ndarray, torch.Tensor]:
+ assert isinstance(input_image, torch.Tensor)
+ self.network = self.network.to(self.device)
+ self.network.eval()
+
+ empty_cache(self.device)
+
+ # Autocast is a little bitch.
+ # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
+ # and needs to be disabled.
+ # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
+ # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
+ # So autocast will only be active if we have a cuda device.
+ with torch.no_grad():
+ with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ assert len(input_image.shape) == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'
+
+ if self.verbose: print(f'Input shape: {input_image.shape}')
+ if self.verbose: print("step_size:", self.tile_step_size)
+ if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)
+
+ # if input_image is smaller than tile_size we need to pad it to tile_size.
+ data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
+ 'constant', {'value': 0}, True,
+ None)
+
+ slicers = self._internal_get_sliding_window_slicers(data.shape[1:])
+
+ # preallocate results and num_predictions
+ results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu')
+ if self.verbose: print('preallocating arrays')
+ try:
+ data = data.to(self.device)
+ predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+ dtype=torch.half,
+ device=results_device)
+ n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
+ device=results_device)
+ if self.use_gaussian:
+ gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+ value_scaling_factor=1000,
+ device=results_device)
+ except RuntimeError:
+ # sometimes the stuff is too large for GPUs. In that case fall back to CPU
+ print("Switching to CPU", flush=True)
+ results_device = torch.device('cpu')
+ data = data.to(results_device)
+ predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
+ dtype=torch.half,
+ device=results_device)
+ n_predictions = torch.zeros(data.shape[1:], dtype=torch.half,
+ device=results_device)
+ if self.use_gaussian:
+ gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
+ value_scaling_factor=1000,
+ device=results_device)
+ finally:
+ empty_cache(self.device)
+
+ if self.verbose: print('running prediction')
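+                # accumulate (optionally Gaussian-weighted) tile predictions and the corresponding per-voxel weights;
+                # predicted_logits is divided by n_predictions after the loop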
+ for sl in tqdm(slicers, disable=not self.allow_tqdm):
+                    # print("go", flush=True)
+ workon = data[sl][None]
+ workon = workon.to(self.device, non_blocking=False)
+
+ prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
+
+ predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
+ n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+
+ predicted_logits /= n_predictions
+ empty_cache(self.device)
+ return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+
+
+def predict_entry_point_modelfolder():
+ import argparse
+ parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
+ 'you want to manually specify a folder containing a trained nnU-Net '
+ 'model. This is useful when the nnunet environment variables '
+ '(nnUNet_results) are not set.')
+ parser.add_argument('-i', type=str, required=True,
+ help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
+ 'File endings must be the same as the training dataset!')
+ parser.add_argument('-o', type=str, required=True,
+ help='Output folder. If it does not exist it will be created. Predicted segmentations will '
+ 'have the same name as their source images.')
+ parser.add_argument('-m', type=str, required=True,
+ help='Folder in which the trained model is. Must have subfolders fold_X for the different '
+ 'folds you trained')
+ parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
+ help='Specify the folds of the trained model that should be used for prediction. '
+ 'Default: (0, 1, 2, 3, 4)')
+ parser.add_argument('-step_size', type=float, required=False, default=0.5,
+ help='Step size for sliding window prediction. The larger it is the faster but less accurate '
+ 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
+ parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
+ help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
+ 'but less accurate inference. Not recommended.')
+ parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
+ "to be a good listener/reader.")
+ parser.add_argument('--save_probabilities', action='store_true',
+ help='Set this to export predicted class "probabilities". Required if you want to ensemble '
+ 'multiple configurations.')
+ parser.add_argument('--continue_prediction', '--c', action='store_true',
+ help='Continue an aborted previous prediction (will not overwrite existing files)')
+ parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
+ help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
+ parser.add_argument('-npp', type=int, required=False, default=3,
+ help='Number of processes used for preprocessing. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-nps', type=int, required=False, default=3,
+ help='Number of processes used for segmentation export. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
+ help='Folder containing the predictions of the previous stage. Required for cascaded models.')
+ parser.add_argument('-device', type=str, default='cuda', required=False,
+ help="Use this to set the device the inference should run with. Available options are 'cuda' "
+ "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
+ "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+
+ print(
+ "\n#######################################################################\nPlease cite the following paper "
+ "when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n#######################################################################\n")
+
+ args = parser.parse_args()
+ args.f = [i if i == 'all' else int(i) for i in args.f]
+
+ if not isdir(args.o):
+ maybe_mkdir_p(args.o)
+
+ assert args.device in ['cpu', 'cuda',
+ 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+ if args.device == 'cpu':
+ # let's allow torch to use hella threads
+ import multiprocessing
+ torch.set_num_threads(multiprocessing.cpu_count())
+ device = torch.device('cpu')
+ elif args.device == 'cuda':
+ # multithreading in torch doesn't help nnU-Net if run on GPU
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ device = torch.device('cuda')
+ else:
+ device = torch.device('mps')
+
+ predictor = nnUNetPredictor(tile_step_size=args.step_size,
+ use_gaussian=True,
+ use_mirroring=not args.disable_tta,
+ perform_everything_on_gpu=True,
+ device=device,
+ verbose=args.verbose)
+ predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk)
+ predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
+ overwrite=not args.continue_prediction,
+ num_processes_preprocessing=args.npp,
+ num_processes_segmentation_export=args.nps,
+ folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ num_parts=1, part_id=0)
+
+
+def predict_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
+                                                 'you want to predict with a model identified by its dataset, trainer, '
+                                                 'plans and configuration. It requires the nnUNet_results environment '
+                                                 'variable to be set.')
+ parser.add_argument('-i', type=str, required=True,
+ help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
+ 'File endings must be the same as the training dataset!')
+ parser.add_argument('-o', type=str, required=True,
+ help='Output folder. If it does not exist it will be created. Predicted segmentations will '
+ 'have the same name as their source images.')
+ parser.add_argument('-d', type=str, required=True,
+ help='Dataset with which you would like to predict. You can specify either dataset name or id')
+ parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
+ help='Plans identifier. Specify the plans in which the desired configuration is located. '
+ 'Default: nnUNetPlans')
+ parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
+ help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer')
+ parser.add_argument('-c', type=str, required=True,
+ help='nnU-Net configuration that should be used for prediction. Config must be located '
+ 'in the plans specified with -p')
+ parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
+ help='Specify the folds of the trained model that should be used for prediction. '
+ 'Default: (0, 1, 2, 3, 4)')
+ parser.add_argument('-step_size', type=float, required=False, default=0.5,
+ help='Step size for sliding window prediction. The larger it is the faster but less accurate '
+ 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
+ parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
+ help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
+ 'but less accurate inference. Not recommended.')
+ parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
+ "to be a good listener/reader.")
+ parser.add_argument('--save_probabilities', action='store_true',
+ help='Set this to export predicted class "probabilities". Required if you want to ensemble '
+ 'multiple configurations.')
+ parser.add_argument('--continue_prediction', action='store_true',
+ help='Continue an aborted previous prediction (will not overwrite existing files)')
+ parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
+ help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
+ parser.add_argument('-npp', type=int, required=False, default=3,
+ help='Number of processes used for preprocessing. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-nps', type=int, required=False, default=3,
+ help='Number of processes used for segmentation export. More is not always better. Beware of '
+ 'out-of-RAM issues. Default: 3')
+ parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
+ help='Folder containing the predictions of the previous stage. Required for cascaded models.')
+ parser.add_argument('-num_parts', type=int, required=False, default=1,
+                        help='Number of separate nnUNetv2_predict calls that you will be making. Default: 1 (= this one '
+ 'call predicts everything)')
+ parser.add_argument('-part_id', type=int, required=False, default=0,
+                        help='If multiple nnUNetv2_predict calls exist, which one is this? IDs start with 0 and end with '
+ 'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts '
+ '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible '
+ 'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)')
+ parser.add_argument('-device', type=str, default='cuda', required=False,
+ help="Use this to set the device the inference should run with. Available options are 'cuda' "
+ "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
+ "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!")
+
+ print(
+ "\n#######################################################################\nPlease cite the following paper "
+ "when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n#######################################################################\n")
+
+ args = parser.parse_args()
+ args.f = [i if i == 'all' else int(i) for i in args.f]
+
+ model_folder = get_output_folder(args.d, args.tr, args.p, args.c)
+
+ if not isdir(args.o):
+ maybe_mkdir_p(args.o)
+
+    # slightly passive aggressive haha
+ assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.'
+
+ assert args.device in ['cpu', 'cuda',
+ 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+ if args.device == 'cpu':
+ # let's allow torch to use hella threads
+ import multiprocessing
+ torch.set_num_threads(multiprocessing.cpu_count())
+ device = torch.device('cpu')
+ elif args.device == 'cuda':
+ # multithreading in torch doesn't help nnU-Net if run on GPU
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ device = torch.device('cuda')
+ else:
+ device = torch.device('mps')
+
+ predictor = nnUNetPredictor(tile_step_size=args.step_size,
+ use_gaussian=True,
+ use_mirroring=not args.disable_tta,
+ perform_everything_on_gpu=True,
+ device=device,
+ verbose=args.verbose,
+ verbose_preprocessing=False)
+ predictor.initialize_from_trained_model_folder(
+ model_folder,
+ args.f,
+ checkpoint_name=args.chk
+ )
+ predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities,
+ overwrite=not args.continue_prediction,
+ num_processes_preprocessing=args.npp,
+ num_processes_segmentation_export=args.nps,
+ folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ num_parts=args.num_parts,
+ part_id=args.part_id)
+ # r = predict_from_raw_data(args.i,
+ # args.o,
+ # model_folder,
+ # args.f,
+ # args.step_size,
+ # use_gaussian=True,
+ # use_mirroring=not args.disable_tta,
+ # perform_everything_on_gpu=True,
+ # verbose=args.verbose,
+ # save_probabilities=args.save_probabilities,
+ # overwrite=not args.continue_prediction,
+ # checkpoint_name=args.chk,
+ # num_processes_preprocessing=args.npp,
+ # num_processes_segmentation_export=args.nps,
+ # folder_with_segs_from_prev_stage=args.prev_stage_predictions,
+ # num_parts=args.num_parts,
+ # part_id=args.part_id,
+ # device=device)
+
+
+if __name__ == '__main__':
+ # predict a bunch of files
+ from nnunetv2.paths import nnUNet_results, nnUNet_raw
+ predictor = nnUNetPredictor(
+ tile_step_size=0.5,
+ use_gaussian=True,
+ use_mirroring=True,
+ perform_everything_on_gpu=True,
+ device=torch.device('cuda', 0),
+ verbose=False,
+ verbose_preprocessing=False,
+ allow_tqdm=True
+ )
+ predictor.initialize_from_trained_model_folder(
+ join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
+ use_folds=(0, ),
+ checkpoint_name='checkpoint_final.pth',
+ )
+ predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
+ save_probabilities=False, overwrite=False,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+
+ # predict a numpy array
+ from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
+ ret = predictor.predict_single_npy_array(img, props, None, None, False)
+
+ iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1)
+ ret = predictor.predict_from_data_iterator(iterator, False, 1)
+
+
+ # predictor = nnUNetPredictor(
+ # tile_step_size=0.5,
+ # use_gaussian=True,
+ # use_mirroring=True,
+ # perform_everything_on_gpu=True,
+ # device=torch.device('cuda', 0),
+ # verbose=False,
+ # allow_tqdm=True
+ # )
+ # predictor.initialize_from_trained_model_folder(
+ # join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'),
+ # use_folds=(0,),
+ # checkpoint_name='checkpoint_final.pth',
+ # )
+ # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ # join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'),
+ # save_probabilities=False, overwrite=False,
+ # num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres',
+ # num_parts=1, part_id=0)
+
diff --git a/nnUNet/nnunetv2/inference/readme.md b/nnUNet/nnunetv2/inference/readme.md
new file mode 100644
index 0000000..984b4a9
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/readme.md
@@ -0,0 +1,205 @@
+The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into
+your existing workflows.
+This readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn
+all the details!
+
+# Preface
+In terms of speed, the most efficient inference strategy is the one done by the nnU-Net defaults! Images are read on
+the fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and
+sends the prediction off to another set of background workers which will resize the resulting logits, convert
+them to a segmentation and export the segmentation.
+
+The reason the default setup is the best option is because
+
+1) loading and preprocessing as well as segmentation export are interleaved with the prediction. The main process can
+focus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing.
+This uses your resources as well as possible!
+2) only the images and segmentations that are currently needed are stored in RAM! Imagine predicting many images
+and having to store all of them plus the results in your system memory.
+
+# nnUNetPredictor
+The new nnUNetPredictor class encapsulates the inferencing code and makes it simple to switch between modes. Your
+code can hold a nnUNetPredictor instance and perform prediction on the fly. Previously this was not possible and each
+new prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal.
+
+The nnUNetPredictor must be initialized manually! You will want to use the
+`predictor.initialize_from_trained_model_folder` function for 99% of use cases!
+
+New feature: If you do not specify an output folder / output files then the predicted segmentations will be
+returned.
+
+
+## Recommended nnU-Net default: predict from source files
+
+tldr:
+- loads images on the fly
+- performs preprocessing in background workers
+- main process focuses only on making predictions
+- results are again given to background workers for resampling and (optional) export
+
+pros:
+- best suited for predicting a large number of images
+- nicer to your RAM
+
+cons:
+- not ideal when single images are to be predicted
+- requires images to be present as files
+
+Example:
+```python
+ from nnunetv2.paths import nnUNet_results, nnUNet_raw
+ import torch
+ from batchgenerators.utilities.file_and_folder_operations import join
+ from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+
+ # instantiate the nnUNetPredictor
+ predictor = nnUNetPredictor(
+ tile_step_size=0.5,
+ use_gaussian=True,
+ use_mirroring=True,
+ perform_everything_on_gpu=True,
+ device=torch.device('cuda', 0),
+ verbose=False,
+ verbose_preprocessing=False,
+ allow_tqdm=True
+ )
+ # initializes the network architecture, loads the checkpoint
+ predictor.initialize_from_trained_model_folder(
+ join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'),
+ use_folds=(0,),
+ checkpoint_name='checkpoint_final.pth',
+ )
+ # variant 1: give input and output folders
+ predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+ join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'),
+ save_probabilities=False, overwrite=False,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+```
+
+Instead of giving input and output folders you can also give concrete files. If you give concrete files, there is no
+need for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames!
+Remember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted
+and the inner list contains all the files belonging to that case. There is just one file for datasets with just one
+input modality (such as CT), but there may be more files for others (such as MRI, where there is sometimes T1, T2, FLAIR etc.).
+IMPORTANT: the order in which the files for each case are given must match the order of the channels as defined in the
+dataset.json!
+
+If you give files as input, you need to give individual output files as output!
+
+```python
+ # variant 2, use list of files as inputs. Note how we use nested lists!!!
+ indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
+ outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres')
+ predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
+ [join(indir, 'liver_142_0000.nii.gz')]],
+ [join(outdir, 'liver_152.nii.gz'),
+ join(outdir, 'liver_142.nii.gz')],
+ save_probabilities=False, overwrite=False,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+```
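+
+For cases with more than one input channel, each inner list simply contains all channel files of that case, in the
+order defined by the dataset.json. Below is a minimal sketch of that structure; the dataset and file names are made
+up and only illustrate the nesting:
+
+```python
+    # variant 2b: two input channels per case (illustrative sketch, file names are hypothetical)
+    indir = join(nnUNet_raw, 'Dataset999_Example/imagesTs')
+    outdir = join(nnUNet_raw, 'Dataset999_Example/imagesTs_pred')
+    predictor.predict_from_files([[join(indir, 'case_1_0000.nii.gz'), join(indir, 'case_1_0001.nii.gz')],
+                                  [join(indir, 'case_2_0000.nii.gz'), join(indir, 'case_2_0001.nii.gz')]],
+                                 [join(outdir, 'case_1.nii.gz'),
+                                  join(outdir, 'case_2.nii.gz')],
+                                 save_probabilities=False, overwrite=False,
+                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
+                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+```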
+
+Did you know? If you do not specify output files, the predicted segmentations will be returned:
+```python
+ # variant 2.5, returns segmentations
+ indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
+ predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
+ [join(indir, 'liver_142_0000.nii.gz')]],
+ None,
+ save_probabilities=False, overwrite=True,
+ num_processes_preprocessing=2, num_processes_segmentation_export=2,
+ folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+```
+
+## Prediction from npy arrays
+tldr:
+- you give images as a list of npy arrays
+- performs preprocessing in background workers
+- main process focuses only on making predictions
+- results are again given to background workers for resampling and (optional) export
+
+pros:
+- the correct variant for when you have images in RAM already
+- well suited for predicting multiple images
+
+cons:
+- uses more RAM than the default
+- unsuited for a large number of images, as all images must be held in RAM
+
+```python
+ from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+ img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+ img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+ img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+ # we do not set output files so that the segmentations will be returned. You can of course also specify output
+    # files instead (no return value in that case)
+ ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
+ None,
+ [props, props2, props3, props4],
+ None, 2, save_probabilities=False,
+ num_processes_segmentation_export=2)
+```
+
+## Predicting a single npy array
+
+tldr:
+- you give one image as npy array
+- everything is done in the main process: preprocessing, prediction, resampling, (export)
+- no interleaving, slowest variant!
+- ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON
+
+pros:
+- no messing with multiprocessing
+- no messing with data iterator blabla
+
+cons:
+- slow as heck, yo
+- never the right choice unless you can only give a single image at a time to nnU-Net
+
+```python
+ # predict a single numpy array
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
+ ret = predictor.predict_single_npy_array(img, props, None, None, False)
+```
+
+## Predicting with a custom data iterator
+tldr:
+- highly flexible
+- not for newbies
+
+pros:
+- you can do everything yourself
+- you have all the freedom you want
+- really fast if you remember to use multiprocessing in your iterator
+
+cons:
+- you need to do everything yourself
+- harder than you might think
+
+```python
+ img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+ img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+ img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+ img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+ # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys!
+ # If 'ofile' is None, the result will be returned instead of written to a file
+ # the iterator is responsible for performing the correct preprocessing!
+ # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread!
+ # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays
+ # (they both use predictor.predict_from_data_iterator) for inspiration!
+ def my_iterator(list_of_input_arrs, list_of_input_props):
+ preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose)
+ for a, p in zip(list_of_input_arrs, list_of_input_props):
+ data, seg = preprocessor.run_case_npy(a,
+ None,
+ p,
+ predictor.plans_manager,
+ predictor.configuration_manager,
+ predictor.dataset_json)
+ yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properites': p, 'ofile': None}
+ ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]),
+ save_probabilities=False, num_processes_segmentation_export=3)
+```
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/inference/sliding_window_prediction.py b/nnUNet/nnunetv2/inference/sliding_window_prediction.py
new file mode 100644
index 0000000..07316cf
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/sliding_window_prediction.py
@@ -0,0 +1,67 @@
+from functools import lru_cache
+
+import numpy as np
+import torch
+from typing import Union, Tuple, List
+from acvl_utils.cropping_and_padding.padding import pad_nd_image
+from scipy.ndimage import gaussian_filter
+
+
+@lru_cache(maxsize=2)
+def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8,
+ value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \
+ -> torch.Tensor:
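+    # build a Gaussian importance map centered in the tile: it down-weights voxels near the tile borders when
+    # overlapping sliding-window predictions are aggregated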
+ tmp = np.zeros(tile_size)
+ center_coords = [i // 2 for i in tile_size]
+ sigmas = [i * sigma_scale for i in tile_size]
+ tmp[tuple(center_coords)] = 1
+ gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
+
+ gaussian_importance_map = torch.from_numpy(gaussian_importance_map).type(dtype).to(device)
+
+ gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor
+ gaussian_importance_map = gaussian_importance_map.type(dtype)
+
+ # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
+ gaussian_importance_map[gaussian_importance_map == 0] = torch.min(
+ gaussian_importance_map[gaussian_importance_map != 0])
+
+ return gaussian_importance_map
+
+
+def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \
+ List[List[int]]:
+    assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
+ assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
+
+ # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
+ # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
+ target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]
+
+ num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]
+
+ steps = []
+ for dim in range(len(tile_size)):
+ # the highest step value for this dimension is
+ max_step_value = image_size[dim] - tile_size[dim]
+ if num_steps[dim] > 1:
+ actual_step_size = max_step_value / (num_steps[dim] - 1)
+ else:
+ actual_step_size = 99999999999 # does not matter because there is only one step at 0
+
+ steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
+
+ steps.append(steps_here)
+
+ return steps
+
+
+if __name__ == '__main__':
+ a = torch.rand((4, 2, 32, 23))
+ a_npy = a.numpy()
+
+ a_padded = pad_nd_image(a, new_shape=(48, 27))
+ a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27))
+ assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))])
+ assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))])
+ assert np.all(a_padded.numpy() == a_npy_padded)
diff --git a/nnUNet/nnunetv2/model_sharing/__init__.py b/nnUNet/nnunetv2/model_sharing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/model_sharing/entry_points.py b/nnUNet/nnunetv2/model_sharing/entry_points.py
new file mode 100644
index 0000000..1ab7c93
--- /dev/null
+++ b/nnUNet/nnunetv2/model_sharing/entry_points.py
@@ -0,0 +1,61 @@
+from nnunetv2.model_sharing.model_download import download_and_install_from_url
+from nnunetv2.model_sharing.model_export import export_pretrained_model
+from nnunetv2.model_sharing.model_import import install_model_from_zip_file
+
+
+def print_license_warning():
+ print('')
+ print('######################################################')
+ print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
+ print('######################################################')
+ print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
+ "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
+ "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
+ print('######################################################')
+ print('')
+
+
+def download_by_url():
+ import argparse
+ parser = argparse.ArgumentParser(
+ description="Use this to download pretrained models. This script is intended to download models via url only. "
+ "CAREFUL: This script will overwrite "
+ "existing models (if they share the same trainer class and plans as "
+                    "the pretrained model).")
+ parser.add_argument("url", type=str, help='URL of the pretrained model')
+ args = parser.parse_args()
+ url = args.url
+ download_and_install_from_url(url)
+
+
+def install_from_zip_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser(
+ description="Use this to install a zip file containing a pretrained model.")
+ parser.add_argument("zip", type=str, help='zip file')
+ args = parser.parse_args()
+ zip = args.zip
+ install_model_from_zip_file(zip)
+
+
+def export_pretrained_model_entry():
+ import argparse
+ parser = argparse.ArgumentParser(
+ description="Use this to export a trained model as a zip file.")
+ parser.add_argument('-d', type=str, required=True, help='Dataset name or id')
+ parser.add_argument('-o', type=str, required=True, help='Output file name')
+ parser.add_argument('-c', nargs='+', type=str, required=False,
+ default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'),
+ help="List of configuration names")
+ parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class')
+ parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier')
+ parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids')
+ parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ),
+                        help='List of checkpoint names to export. Default: checkpoint_final.pth')
+    parser.add_argument('--not_strict', action='store_true', default=False, required=False, help='Set this to allow missing folds and/or configurations')
+ parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well')
+ args = parser.parse_args()
+
+ export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
+ plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
+ export_crossval_predictions=args.exp_cv_preds)
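For reference, a minimal sketch of driving the export from Python rather than through the argparse wrapper above; the dataset id and output path are placeholders, not values from this repository:

    from nnunetv2.model_sharing.model_export import export_pretrained_model

    # hypothetical dataset id and output path, shown only to illustrate the call
    export_pretrained_model(
        dataset_name_or_id=1,
        output_file='/tmp/dataset001_export.zip',
        configurations=('3d_fullres',),   # export only the configuration that was trained
        folds=(0, 1, 2, 3, 4),
        strict=False,                     # tolerate missing folds/configurations
    )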
diff --git a/nnUNet/nnunetv2/model_sharing/model_download.py b/nnUNet/nnunetv2/model_sharing/model_download.py
new file mode 100644
index 0000000..d845ab5
--- /dev/null
+++ b/nnUNet/nnunetv2/model_sharing/model_download.py
@@ -0,0 +1,47 @@
+from typing import Optional
+
+import requests
+from batchgenerators.utilities.file_and_folder_operations import *
+from time import time
+from nnunetv2.model_sharing.model_import import install_model_from_zip_file
+from nnunetv2.paths import nnUNet_results
+from tqdm import tqdm
+
+
+def download_and_install_from_url(url):
+    assert nnUNet_results is not None, "Cannot install model because nnUNet_results is not " \
+                                       "set (nnUNet_results missing as environment variable, see " \
+                                       "installation instructions)"
+ print('Downloading pretrained model from url:', url)
+ import http.client
+ http.client.HTTPConnection._http_vsn = 10
+ http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'
+
+ import os
+ home = os.path.expanduser('~')
+ random_number = int(time() * 1e7)
+ tempfile = join(home, '.nnunetdownload_%s' % str(random_number))
+
+ try:
+ download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
+ print("Download finished. Extracting...")
+ install_model_from_zip_file(tempfile)
+ print("Done")
+ except Exception as e:
+ raise e
+ finally:
+ if isfile(tempfile):
+ os.remove(tempfile)
+
+
+def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str:
+ # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
+ # NOTE the stream=True parameter below
+ with requests.get(url, stream=True, timeout=100) as r:
+ r.raise_for_status()
+ with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f:
+ for chunk in r.iter_content(chunk_size=chunk_size):
+ f.write(chunk)
+ return local_filename
+
+
diff --git a/nnUNet/nnunetv2/model_sharing/model_export.py b/nnUNet/nnunetv2/model_sharing/model_export.py
new file mode 100644
index 0000000..2db8e24
--- /dev/null
+++ b/nnUNet/nnunetv2/model_sharing/model_export.py
@@ -0,0 +1,124 @@
+import zipfile
+
+from nnunetv2.utilities.file_path_utilities import *
+
+
+def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str,
+ configurations: Tuple[str] = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ trainer: str = 'nnUNetTrainer',
+ plans_identifier: str = 'nnUNetPlans',
+ folds: Tuple[int, ...] = (0, 1, 2, 3, 4),
+ strict: bool = True,
+ save_checkpoints: Tuple[str, ...] = ('checkpoint_final.pth',),
+ export_crossval_predictions: bool = False) -> None:
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+    with zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED) as zipf:
+ for c in configurations:
+ print(f"Configuration {c}")
+ trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c)
+
+ if not isdir(trainer_output_dir):
+ if strict:
+ raise RuntimeError(f"{dataset_name} is missing the trained model of configuration {c}")
+ else:
+ continue
+
+ expected_fold_folder = ["fold_%s" % i if i != 'all' else 'fold_all' for i in folds]
+ assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \
+ f"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}"
+
+ assert isfile(join(trainer_output_dir, "plans.json")), f"plans.json missing, {dataset_name} {c}"
+
+ for fold_folder in expected_fold_folder:
+ print(f"Exporting {fold_folder}")
+ # debug.json, does not exist yet
+ source_file = join(trainer_output_dir, fold_folder, "debug.json")
+ if isfile(source_file):
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+ # all requested checkpoints
+ for chk in save_checkpoints:
+ source_file = join(trainer_output_dir, fold_folder, chk)
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+ # progress.png
+ source_file = join(trainer_output_dir, fold_folder, "progress.png")
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+                # if it exists, network_architecture.pdf
+ source_file = join(trainer_output_dir, fold_folder, "network_architecture.pdf")
+ if isfile(source_file):
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+ # validation folder with all predicted segmentations etc
+ if export_crossval_predictions:
+ source_folder = join(trainer_output_dir, fold_folder, "validation")
+ files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')]
+ for f in files:
+ zipf.write(join(source_folder, f), os.path.relpath(join(source_folder, f), nnUNet_results))
+ # just the summary.json file from the validation
+ else:
+ source_file = join(trainer_output_dir, fold_folder, "validation", "summary.json")
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+ source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
+ if isdir(source_folder):
+ if export_crossval_predictions:
+ source_files = subfiles(source_folder, join=True)
+ else:
+ source_files = [
+ join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in
+ ['summary.json', 'postprocessing.pkl', 'postprocessing.json']
+ ]
+ for s in source_files:
+ if isfile(s):
+ zipf.write(s, os.path.relpath(s, nnUNet_results))
+ # plans
+ source_file = join(trainer_output_dir, "plans.json")
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+ # fingerprint
+ source_file = join(trainer_output_dir, "dataset_fingerprint.json")
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+ # dataset
+ source_file = join(trainer_output_dir, "dataset.json")
+ zipf.write(source_file, os.path.relpath(source_file, nnUNet_results))
+
+ ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles')
+
+ if not isdir(ensemble_dir):
+ print("No ensemble directory found for task", dataset_name_or_id)
+ return
+ subd = subdirs(ensemble_dir, join=False)
+ # figure out whether the models in the ensemble are all within the exported models here
+ for ens in subd:
+ identifiers, folds = convert_ensemble_folder_to_model_identifiers_and_folds(ens)
+ ok = True
+ for i in identifiers:
+ tr, pl, c = convert_identifier_to_trainer_plans_config(i)
+ if tr == trainer and pl == plans_identifier and c in configurations:
+ pass
+ else:
+ ok = False
+ if ok:
+ print(f'found matching ensemble: {ens}')
+ source_folder = join(ensemble_dir, ens)
+ if export_crossval_predictions:
+ source_files = subfiles(source_folder, join=True)
+ else:
+ source_files = [
+ join(source_folder, i) for i in
+ ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] if isfile(join(source_folder, i))
+ ]
+ for s in source_files:
+ zipf.write(s, os.path.relpath(s, nnUNet_results))
+ inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json')
+ if isfile(inference_information_file):
+ zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results))
+ inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt')
+ if isfile(inference_information_txt_file):
+ zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))
+ print('Done')
+
+
+if __name__ == '__main__':
+ export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, ))
diff --git a/nnUNet/nnunetv2/model_sharing/model_import.py b/nnUNet/nnunetv2/model_sharing/model_import.py
new file mode 100644
index 0000000..0356e90
--- /dev/null
+++ b/nnUNet/nnunetv2/model_sharing/model_import.py
@@ -0,0 +1,8 @@
+import zipfile
+
+from nnunetv2.paths import nnUNet_results
+
+
+def install_model_from_zip_file(zip_file: str):
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+ zip_ref.extractall(nnUNet_results)
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/paths.py b/nnUNet/nnunetv2/paths.py
new file mode 100644
index 0000000..f2220f9
--- /dev/null
+++ b/nnUNet/nnunetv2/paths.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+"""
+PLEASE READ documentation/setting_up_paths.md FOR INFORMATION ON HOW TO SET THIS UP
+"""
+
+nnUNet_raw = os.environ.get('nnUNet_raw')
+nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
+nnUNet_results = os.environ.get('nnUNet_results')
+
+if nnUNet_raw is None:
+ print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files "
+ "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like "
+ "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set "
+ "this up properly.")
+
+if nnUNet_preprocessed is None:
+ print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing "
+ "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how "
+ "to set this up.")
+
+if nnUNet_results is None:
+ print("nnUNet_results is not defined and nnU-Net cannot be used for training or "
+ "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information "
+ "on how to set this up.")
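These variables are read once at import time, so they must be present in the environment before anything from nnunetv2 is imported. A minimal sketch with placeholder folder locations:

    import os

    # placeholder locations; point these at the actual nnU-Net folders
    os.environ['nnUNet_raw'] = '/data/nnUNet_raw'
    os.environ['nnUNet_preprocessed'] = '/data/nnUNet_preprocessed'
    os.environ['nnUNet_results'] = '/data/nnUNet_results'

    from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed, nnUNet_results
    print(nnUNet_raw, nnUNet_preprocessed, nnUNet_results)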
diff --git a/nnUNet/nnunetv2/postprocessing/__init__.py b/nnUNet/nnunetv2/postprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/postprocessing/remove_connected_components.py b/nnUNet/nnunetv2/postprocessing/remove_connected_components.py
new file mode 100644
index 0000000..94724fc
--- /dev/null
+++ b/nnUNet/nnunetv2/postprocessing/remove_connected_components.py
@@ -0,0 +1,362 @@
+import argparse
+import multiprocessing
+import shutil
+from multiprocessing import Pool
+from typing import Union, Tuple, List, Callable
+
+import numpy as np
+from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component
+from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \
+ isdir, save_pickle, load_pickle, save_json
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results
+from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \
+ load_summary_json, label_or_region_to_key
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.paths import nnUNet_raw
+from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string
+from nnunetv2.utilities.json_export import recursive_fix_for_json_export
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+
+def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray,
+ labels_or_regions: Union[int, Tuple[int, ...],
+ List[Union[int, Tuple[int, ...]]]],
+ background_label: int = 0) -> np.ndarray:
+ mask = np.zeros_like(segmentation, dtype=bool)
+ if not isinstance(labels_or_regions, list):
+ labels_or_regions = [labels_or_regions]
+ for l_or_r in labels_or_regions:
+ mask |= region_or_label_to_mask(segmentation, l_or_r)
+ mask_keep = remove_all_but_largest_component(mask)
+ ret = np.copy(segmentation) # do not modify the input!
+ ret[mask & ~mask_keep] = background_label
+ return ret
+
+
+def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]):
+ for fn, kwargs in zip(pp_fns, pp_fn_kwargs):
+ segmentation = fn(segmentation, **kwargs)
+ return segmentation
+
+
+def load_postprocess_save(segmentation_file: str,
+ output_fname: str,
+ image_reader_writer: BaseReaderWriter,
+ pp_fns: List[Callable],
+ pp_fn_kwargs: List[dict]):
+ seg, props = image_reader_writer.read_seg(segmentation_file)
+ seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs)
+ image_reader_writer.write_seg(seg, output_fname, props)
+
+
+def determine_postprocessing(folder_predictions: str,
+ folder_ref: str,
+ plans_file_or_dict: Union[str, dict],
+ dataset_json_file_or_dict: Union[str, dict],
+ num_processes: int = default_num_processes,
+ keep_postprocessed_files: bool = True):
+ """
+ Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be
+ used with apply_postprocessing_to_folder.
+
+ Postprocessed files are saved in folder_predictions/postprocessed. Set
+    keep_postprocessed_files=False to delete these files after this function is done (temp files will be created
+ and deleted regardless).
+
+ If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder
+ """
+ output_folder = join(folder_predictions, 'postprocessed')
+
+ if plans_file_or_dict is None:
+ expected_plans_file = join(folder_predictions, 'plans.json')
+ if not isfile(expected_plans_file):
+            raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
+                               f"created while running nnUNetv2_predict.")
+ plans_file_or_dict = load_json(expected_plans_file)
+ plans_manager = PlansManager(plans_file_or_dict)
+
+ if dataset_json_file_or_dict is None:
+ expected_dataset_json_file = join(folder_predictions, 'dataset.json')
+ if not isfile(expected_dataset_json_file):
+ raise RuntimeError(
+                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
+                f"created while running nnUNetv2_predict.")
+ dataset_json_file_or_dict = load_json(expected_dataset_json_file)
+
+ if not isinstance(dataset_json_file_or_dict, dict):
+ dataset_json = load_json(dataset_json_file_or_dict)
+ else:
+ dataset_json = dataset_json_file_or_dict
+
+ rw = plans_manager.image_reader_writer_class()
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels
+
+ predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False)
+ ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False)
+ # we should print a warning if not all files from folder_ref are present in folder_predictions
+ if not all([i in predicted_files for i in ref_files]):
+ print(f'WARNING: Not all files in folder_ref were found in folder_predictions. Determining postprocessing '
+ f'should always be done on the entire dataset!')
+
+    # before we start we should evaluate the images in the source folder
+ if not isfile(join(folder_predictions, 'summary.json')):
+ compute_metrics_on_folder(folder_ref,
+ folder_predictions,
+ join(folder_predictions, 'summary.json'),
+ rw,
+ dataset_json['file_ending'],
+ labels_or_regions,
+ label_manager.ignore_label,
+ num_processes)
+
+ # we save the postprocessing functions in here
+ pp_fns = []
+ pp_fn_kwargs = []
+
+ # pool party!
+ with multiprocessing.get_context("spawn").Pool(num_processes) as pool:
+ # now let's see whether removing all but the largest foreground region improves the scores
+ output_here = join(output_folder, 'temp', 'keep_largest_fg')
+ maybe_mkdir_p(output_here)
+ pp_fn = remove_all_but_largest_component_from_segmentation
+ kwargs = {
+ 'labels_or_regions': label_manager.foreground_labels,
+ }
+
+ pool.starmap(
+ load_postprocess_save,
+ zip(
+ [join(folder_predictions, i) for i in predicted_files],
+ [join(output_here, i) for i in predicted_files],
+ [rw] * len(predicted_files),
+ [[pp_fn]] * len(predicted_files),
+ [[kwargs]] * len(predicted_files)
+ )
+ )
+ compute_metrics_on_folder(folder_ref,
+ output_here,
+ join(output_here, 'summary.json'),
+ rw,
+ dataset_json['file_ending'],
+ labels_or_regions,
+ label_manager.ignore_label,
+ num_processes)
+    # now we need to figure out if doing this improved the dice scores. We implement this defensively: if a single
+    # class got worse as a result we won't do this. We can change this in the future but right now I
+    # prefer to do it this way
+ baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))
+ pp_results = load_summary_json(join(output_here, 'summary.json'))
+ do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice']
+ if do_this:
+ for class_id in pp_results['mean'].keys():
+ if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']:
+ do_this = False
+ break
+ if do_this:
+ print(f'Results were improved by removing all but the largest foreground region. '
+ f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} '
+ f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}')
+ source = output_here
+ pp_fns.append(pp_fn)
+ pp_fn_kwargs.append(kwargs)
+ else:
+ print(f'Removing all but the largest foreground region did not improve results!')
+ source = folder_predictions
+
+    # in the old nnU-Net we could just apply all-but-largest-component removal to all classes at the same time and
+    # then evaluate for each class whether this improved results. This is no longer possible because we now support
+    # region-based predictions and regions can overlap, causing interactions.
+    # In principle the order in which the postprocessing is applied to the regions matters as well and should be
+    # investigated, but sticking to the order in which the regions are declared in dataset.json (think
+    # region_class_order) works well enough in practice.
+ if len(labels_or_regions) > 1:
+ for label_or_region in labels_or_regions:
+ pp_fn = remove_all_but_largest_component_from_segmentation
+ kwargs = {
+ 'labels_or_regions': label_or_region,
+ }
+
+ output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion')
+ maybe_mkdir_p(output_here)
+
+ pool.starmap(
+ load_postprocess_save,
+ zip(
+ [join(source, i) for i in predicted_files],
+ [join(output_here, i) for i in predicted_files],
+ [rw] * len(predicted_files),
+ [[pp_fn]] * len(predicted_files),
+ [[kwargs]] * len(predicted_files)
+ )
+ )
+ compute_metrics_on_folder(folder_ref,
+ output_here,
+ join(output_here, 'summary.json'),
+ rw,
+ dataset_json['file_ending'],
+ labels_or_regions,
+ label_manager.ignore_label,
+ num_processes)
+ baseline_results = load_summary_json(join(source, 'summary.json'))
+ pp_results = load_summary_json(join(output_here, 'summary.json'))
+ do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice']
+ if do_this:
+ print(f'Results were improved by removing all but the largest component for {label_or_region}. '
+ f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
+ f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
+ if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')):
+ shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'))
+ shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), )
+ source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')
+ pp_fns.append(pp_fn)
+ pp_fn_kwargs.append(kwargs)
+ else:
+ print(f'Removing all but the largest component for {label_or_region} did not improve results! '
+ f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} '
+ f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}')
+ [shutil.copy(join(source, i), join(output_folder, i)) for i in subfiles(source, join=False)]
+ save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl'))
+
+ baseline_results = load_summary_json(join(folder_predictions, 'summary.json'))
+ final_results = load_summary_json(join(output_folder, 'summary.json'))
+ tmp = {
+ 'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']},
+ 'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']},
+ 'postprocessing_fns': [i.__name__ for i in pp_fns],
+ 'postprocessing_kwargs': pp_fn_kwargs,
+ }
+    # JSON cannot handle tuples as dict keys, so convert them to string keys first.
+ tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in
+ tmp['input_folder']['mean'].keys()}
+ tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in
+ tmp['postprocessed']['mean'].keys()}
+    # numpy scalar types (e.g. int64) are not JSON serializable, so fix them before export.
+ recursive_fix_for_json_export(tmp)
+ save_json(tmp, join(folder_predictions, 'postprocessing.json'))
+
+ shutil.rmtree(join(output_folder, 'temp'))
+
+ if not keep_postprocessed_files:
+ shutil.rmtree(output_folder)
+ return pp_fns, pp_fn_kwargs
+
+
+def apply_postprocessing_to_folder(input_folder: str,
+ output_folder: str,
+ pp_fns: List[Callable],
+ pp_fn_kwargs: List[dict],
+ plans_file_or_dict: Union[str, dict] = None,
+ dataset_json_file_or_dict: Union[str, dict] = None,
+ num_processes=8) -> None:
+ """
+ If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder
+ """
+ if plans_file_or_dict is None:
+ expected_plans_file = join(input_folder, 'plans.json')
+ if not isfile(expected_plans_file):
+ raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been "
+                               f"created while running nnUNetv2_predict. If the folder you want to apply "
+                               f"postprocessing to was created from an ensemble, just specify one of the "
+ f"plans files of the ensemble members in plans_file_or_dict")
+ plans_file_or_dict = load_json(expected_plans_file)
+ plans_manager = PlansManager(plans_file_or_dict)
+
+ if dataset_json_file_or_dict is None:
+ expected_dataset_json_file = join(input_folder, 'dataset.json')
+ if not isfile(expected_dataset_json_file):
+ raise RuntimeError(
+                f"Expected dataset.json file missing: {expected_dataset_json_file}. The dataset.json should have been "
+                f"copied while running nnUNetv2_predict/nnUNetv2_ensemble.")
+ dataset_json_file_or_dict = load_json(expected_dataset_json_file)
+
+ if not isinstance(dataset_json_file_or_dict, dict):
+ dataset_json = load_json(dataset_json_file_or_dict)
+ else:
+ dataset_json = dataset_json_file_or_dict
+
+ rw = plans_manager.image_reader_writer_class()
+
+ maybe_mkdir_p(output_folder)
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False)
+
+ _ = p.starmap(load_postprocess_save,
+ zip(
+ [join(input_folder, i) for i in files],
+ [join(output_folder, i) for i in files],
+ [rw] * len(files),
+ [pp_fns] * len(files),
+ [pp_fn_kwargs] * len(files)
+ )
+ )
+
+
+def entry_point_determine_postprocessing_folder():
+ parser = argparse.ArgumentParser('Writes postprocessing.pkl and postprocessing.json in input_folder.')
+ parser.add_argument('-i', type=str, required=True, help='Input folder')
+ parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels')
+ parser.add_argument('-plans_json', type=str, required=False, default=None,
+ help="plans file to use. If not specified we will look for the plans.json file in the "
+ "input folder (input_folder/plans.json)")
+ parser.add_argument('-dataset_json', type=str, required=False, default=None,
+ help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
+ "input folder (input_folder/dataset.json)")
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f"number of processes to use. Default: {default_num_processes}")
+ parser.add_argument('--remove_postprocessed', action='store_true', required=False,
+                        help='set this if you don\'t want to keep the postprocessed files')
+
+ args = parser.parse_args()
+ determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np,
+ not args.remove_postprocessed)
+
+
+def entry_point_apply_postprocessing():
+    parser = argparse.ArgumentParser('Applies postprocessing specified in pp_pkl_file to input folder.')
+ parser.add_argument('-i', type=str, required=True, help='Input folder')
+ parser.add_argument('-o', type=str, required=True, help='Output folder')
+ parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file')
+ parser.add_argument('-np', type=int, required=False, default=default_num_processes,
+ help=f"number of processes to use. Default: {default_num_processes}")
+ parser.add_argument('-plans_json', type=str, required=False, default=None,
+ help="plans file to use. If not specified we will look for the plans.json file in the "
+ "input folder (input_folder/plans.json)")
+ parser.add_argument('-dataset_json', type=str, required=False, default=None,
+ help="dataset.json file to use. If not specified we will look for the dataset.json file in the "
+ "input folder (input_folder/dataset.json)")
+ args = parser.parse_args()
+ pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file)
+ apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np)
+
+
+if __name__ == '__main__':
+ trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres'
+ labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr')
+ plans_manager = PlansManager(join(trained_model_folder, 'plans.json'))
+ dataset_json = load_json(join(trained_model_folder, 'dataset.json'))
+ folds = (0, 1, 2, 3, 4)
+ label_manager = plans_manager.get_label_manager(dataset_json)
+
+ merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}')
+ accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False)
+
+ fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans,
+ dataset_json, 8, keep_postprocessed_files=True)
+ save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl'))
+ fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl'))
+
+ apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs,
+ plans_manager.plans, dataset_json,
+ 8)
+ compute_metrics_on_folder(labelstr,
+ merged_output_folder + '_pp',
+ join(merged_output_folder + '_pp', 'summary.json'),
+ plans_manager.image_reader_writer_class(),
+ dataset_json['file_ending'],
+ label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels,
+ label_manager.ignore_label,
+ 8)
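A hedged toy example of the two helpers defined at the top of this file, applied to an in-memory array; the label value and array shape are made up for illustration:

    import numpy as np
    from nnunetv2.postprocessing.remove_connected_components import (
        apply_postprocessing,
        remove_all_but_largest_component_from_segmentation,
    )

    seg = np.zeros((16, 16), dtype=np.int8)
    seg[1:3, 1:3] = 1     # small blob of label 1
    seg[6:12, 6:12] = 1   # larger blob of label 1

    pp_fns = [remove_all_but_largest_component_from_segmentation]
    pp_fn_kwargs = [{'labels_or_regions': 1}]
    seg_pp = apply_postprocessing(seg, pp_fns, pp_fn_kwargs)
    # only the largest connected component of label 1 survives; the small blob becomes background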
diff --git a/nnUNet/nnunetv2/preprocessing/__init__.py b/nnUNet/nnunetv2/preprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/preprocessing/cropping/__init__.py b/nnUNet/nnunetv2/preprocessing/cropping/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/preprocessing/cropping/cropping.py b/nnUNet/nnunetv2/preprocessing/cropping/cropping.py
new file mode 100644
index 0000000..33421c3
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/cropping/cropping.py
@@ -0,0 +1,51 @@
+import numpy as np
+
+
+# Hello! crop_to_nonzero is the function you are looking for. Ignore the rest.
+from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice
+
+
+def create_nonzero_mask(data):
+ """
+
+ :param data:
+ :return: the mask is True where the data is nonzero
+ """
+ from scipy.ndimage import binary_fill_holes
+ assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
+ nonzero_mask = np.ones(data.shape[1:], dtype=bool)
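+    # note: as written the mask starts out all True, so the per-channel OR below leaves every voxel selected and
+    # crop_to_nonzero keeps the full field of view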
+ for c in range(data.shape[0]):
+ this_mask = data[c] != 0
+ nonzero_mask = nonzero_mask | this_mask
+ nonzero_mask = binary_fill_holes(nonzero_mask)
+ return nonzero_mask
+
+
+def crop_to_nonzero(data, seg=None, nonzero_label=-1):
+ """
+
+ :param data:
+ :param seg:
+ :param nonzero_label: this will be written into the segmentation map
+ :return:
+ """
+ nonzero_mask = create_nonzero_mask(data)
+ bbox = get_bbox_from_mask(nonzero_mask)
+
+ slicer = bounding_box_to_slice(bbox)
+ data = data[tuple([slice(None), *slicer])]
+
+ if seg is not None:
+ seg = seg[tuple([slice(None), *slicer])]
+
+ nonzero_mask = nonzero_mask[slicer][None]
+ if seg is not None:
+ seg[(seg == 0) & (~nonzero_mask)] = nonzero_label
+ else:
+ nonzero_mask = nonzero_mask.astype(np.int8)
+ nonzero_mask[nonzero_mask == 0] = nonzero_label
+ nonzero_mask[nonzero_mask > 0] = 0
+ seg = nonzero_mask
+ return data, seg, bbox
+
+
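A small self-contained sketch of calling crop_to_nonzero on synthetic data (array contents are made up); when no segmentation is passed, the returned seg is derived from the nonzero mask:

    import numpy as np
    from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero

    data = np.zeros((1, 8, 8, 8), dtype=np.float32)
    data[0, 2:6, 2:6, 2:6] = 1.0   # nonzero cube in the centre of the single channel

    cropped, seg, bbox = crop_to_nonzero(data)   # seg=None here, so a seg is generated
    print(cropped.shape, bbox)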
diff --git a/nnUNet/nnunetv2/preprocessing/normalization/__init__.py b/nnUNet/nnunetv2/preprocessing/normalization/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py b/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py
new file mode 100644
index 0000000..3c90a91
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py
@@ -0,0 +1,95 @@
+from abc import ABC, abstractmethod
+from typing import Type
+
+import numpy as np
+from numpy import number
+
+
+class ImageNormalization(ABC):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None
+
+ def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None,
+ target_dtype: Type[number] = np.float32):
+ assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool)
+ self.use_mask_for_norm = use_mask_for_norm
+ assert isinstance(intensityproperties, dict)
+ self.intensityproperties = intensityproperties
+ self.target_dtype = target_dtype
+
+ @abstractmethod
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+ """
+ Image and seg must have the same shape. Seg is not always used
+ """
+ pass
+
+
+class ZScoreNormalization(ImageNormalization):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True
+
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+ """
+ here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by
+ default.
+ """
+ image = image.astype(self.target_dtype)
+ if self.use_mask_for_norm is not None and self.use_mask_for_norm:
+ # negative values in the segmentation encode the 'outside' region (think zero values around the brain as
+ # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image.
+ # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially
+ # reduced the image size.
+ mask = seg >= 0
+ mean = image[mask].mean()
+ std = image[mask].std()
+ image[mask] = (image[mask] - mean) / (max(std, 1e-8))
+ else:
+ mean = image.mean()
+ std = image.std()
+ image = (image - mean) / (max(std, 1e-8))
+ return image
+
+
+class CTNormalization(ImageNormalization):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
+
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+ assert self.intensityproperties is not None, "CTNormalization requires intensity properties"
+ image = image.astype(self.target_dtype)
+ mean_intensity = self.intensityproperties['mean']
+ std_intensity = self.intensityproperties['std']
+ lower_bound = self.intensityproperties['percentile_00_5']
+ upper_bound = self.intensityproperties['percentile_99_5']
+ image = np.clip(image, lower_bound, upper_bound)
+ image = (image - mean_intensity) / max(std_intensity, 1e-8)
+ return image
+
+
+class NoNormalization(ImageNormalization):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
+
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+ return image.astype(self.target_dtype)
+
+
+class RescaleTo01Normalization(ImageNormalization):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
+
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+ image = image.astype(self.target_dtype)
+ image = image - image.min()
+ image = image / np.clip(image.max(), a_min=1e-8, a_max=None)
+ return image
+
+
+class RGBTo01Normalization(ImageNormalization):
+ leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False
+
+ def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+        assert image.min() >= 0, "RGB images are uint8, but pixel values smaller than 0 were found. " \
+                                 "Your images do not seem to be RGB images"
+        assert image.max() <= 255, "RGB images are uint8, but pixel values greater than 255 were found. " \
+                                   "Your images do not seem to be RGB images"
+ image = image.astype(self.target_dtype)
+ image = image / 255.
+ return image
+
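A quick hedged sketch of running one of these schemes directly on a toy array. In the real pipeline the scheme, the mask flag and the intensity properties come from the plans; here they are made up:

    import numpy as np
    from nnunetv2.preprocessing.normalization.default_normalization_schemes import ZScoreNormalization

    image = (np.random.rand(4, 4, 4) * 100).astype(np.float32)
    seg = -np.ones_like(image, dtype=np.int8)   # -1 marks voxels outside the nonzero mask
    seg[1:3, 1:3, 1:3] = 0                      # region that should be normalized

    # intensityproperties must be a dict (z-scoring does not use it), hence the empty dict
    normalizer = ZScoreNormalization(use_mask_for_norm=True, intensityproperties={})
    normalized = normalizer.run(image, seg)
    # voxels with seg >= 0 now have roughly zero mean and unit variance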
diff --git a/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py b/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py
new file mode 100644
index 0000000..e821650
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py
@@ -0,0 +1,24 @@
+from typing import Type
+
+from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \
+ ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization
+
+channel_name_to_normalization_mapping = {
+ 'CT': CTNormalization,
+ 'noNorm': NoNormalization,
+ 'zscore': ZScoreNormalization,
+ 'rescale_0_1': RescaleTo01Normalization,
+ 'rgb_to_0_1': RGBTo01Normalization
+}
+
+
+def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]:
+ """
+ If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. If it is
+ not found, use the default (ZScoreNormalization)
+ """
+ norm_scheme = channel_name_to_normalization_mapping.get(channel_name)
+ if norm_scheme is None:
+ norm_scheme = ZScoreNormalization
+ # print('Using %s for image normalization' % norm_scheme.__name__)
+ return norm_scheme
diff --git a/nnUNet/nnunetv2/preprocessing/normalization/readme.md b/nnUNet/nnunetv2/preprocessing/normalization/readme.md
new file mode 100644
index 0000000..7b54396
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/normalization/readme.md
@@ -0,0 +1,5 @@
+The channel_names entry in dataset.json only determines the normalization scheme. So if you want to use something different,
+then you can just
+- create a new subclass of ImageNormalization
+- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping
+- run plan and preprocess again with your custom normalization scheme
\ No newline at end of file
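A hedged sketch of the recipe described in this readme, using a hypothetical 'dose' channel and an arbitrary fixed scale factor (both made up for illustration):

    import numpy as np
    from nnunetv2.preprocessing.normalization.default_normalization_schemes import ImageNormalization
    from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import \
        channel_name_to_normalization_mapping


    class FixedScaleNormalization(ImageNormalization):
        # divide by an assumed fixed maximum instead of per-image statistics
        leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False

        def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
            return image.astype(self.target_dtype) / 70.0   # illustrative scale only


    # the mapping the readme refers to; in a real setup this entry would be added to
    # map_channel_name_to_normalization.py, after which 'dose' can be used in the
    # channel_names of dataset.json and planning/preprocessing re-run
    channel_name_to_normalization_mapping['dose'] = FixedScaleNormalization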
diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/__init__.py b/nnUNet/nnunetv2/preprocessing/preprocessors/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
new file mode 100644
index 0000000..685ffcd
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py
@@ -0,0 +1,302 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+import shutil
+from time import sleep
+from typing import Union, Tuple
+
+import nnunetv2
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
+from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
+from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
+ create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets
+from tqdm import tqdm
+
+
+class DefaultPreprocessor(object):
+ def __init__(self, verbose: bool = True):
+ self.verbose = verbose
+ """
+ Everything we need is in the plans. Those are given when run() is called
+ """
+
+ def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,
+ plans_manager: PlansManager, configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ # let's not mess up the inputs!
+ data = np.copy(data)
+ if seg is not None:
+ seg = np.copy(seg)
+
+ has_seg = seg is not None
+
+ # apply transpose_forward, this also needs to be applied to the spacing!
+ data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
+ if seg is not None:
+ seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
+ original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward]
+
+ # crop, remember to store size before cropping!
+ shape_before_cropping = data.shape[1:]
+ properties['shape_before_cropping'] = shape_before_cropping
+ # this command will generate a segmentation. This is important because of the nonzero mask which we may need
+ data, seg, bbox = crop_to_nonzero(data, seg)
+ properties['bbox_used_for_cropping'] = bbox
+ # print(data.shape, seg.shape)
+ properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]
+
+ # resample
+ target_spacing = configuration_manager.spacing # this should already be transposed
+
+ if len(target_spacing) < len(data.shape[1:]):
+ # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d
+            # for the 2d configuration we do not change the spacing between slices
+ target_spacing = [original_spacing[0]] + target_spacing
+ new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)
+
+ # normalize
+ # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no
+ # longer fitting the images perfectly!
+ data = self._normalize(data, seg, configuration_manager,
+ plans_manager.foreground_intensity_properties_per_channel)
+
+ # print('current shape', data.shape[1:], 'current_spacing', original_spacing,
+ # '\ntarget shape', new_shape, 'target_spacing', target_spacing)
+ old_shape = data.shape[1:]
+ data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)
+        if self.verbose:
+            print(f"Actual new_shape: {data.shape}", flush=True)
+ if has_seg:
+ seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)
+ else:
+ seg = None
+ if self.verbose:
+ print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, '
+ f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}')
+
+ # if we have a segmentation, sample foreground locations for oversampling and add those to properties
+ if has_seg:
+ # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument
+            # with a LabelManager instance in this function because that's all it's used for. Dunno what's better.
+ # LabelManager is pretty light computation-wise.
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ collect_for_this = label_manager.foreground_regions if label_manager.has_regions \
+ else label_manager.foreground_labels
+
+ # when using the ignore label we want to sample only from annotated regions. Therefore we also need to
+ # collect samples uniformly from all classes (incl background)
+ if label_manager.has_ignore_label:
+ collect_for_this.append(label_manager.all_labels)
+
+ # no need to filter background in regions because it is already filtered in handle_labels
+ # print(all_labels, regions)
+ properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this,
+ verbose=self.verbose)
+ seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)
+
+        if seg is not None:
+            if np.max(seg) > 127:
+                seg = seg.astype(np.int16)
+            else:
+                seg = seg.astype(np.int8)
+ # data = data.astype(np.float16)
+ return data, seg
+
+ def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager,
+ configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ """
+ seg file can be none (test cases)
+
+ order of operations is: transpose -> crop -> resample
+ so when we export we need to run the following order: resample -> crop -> transpose (we could also run
+ transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner)
+ """
+ if isinstance(dataset_json, str):
+ dataset_json = load_json(dataset_json)
+
+ rw = plans_manager.image_reader_writer_class()
+
+ # load image(s)
+ data, data_properites = rw.read_images(image_files)
+
+ # if possible, load seg
+ if seg_file is not None:
+ seg, _ = rw.read_seg(seg_file)
+ else:
+ seg = None
+
+ data, seg = self.run_case_npy(data, seg, data_properites, plans_manager, configuration_manager,
+ dataset_json)
+ return data, seg, data_properites
+
+ def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str,
+ plans_manager: PlansManager, configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json)
+ # print('dtypes', data.dtype, seg.dtype)
+ np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg)
+ write_pickle(properties, output_filename_truncated + '.pkl')
+
+ @staticmethod
+ def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]],
+ seed: int = 1234, verbose: bool = False):
+ num_samples = 10000
+ min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too
+ # sparse
+ rndst = np.random.RandomState(seed)
+ class_locs = {}
+ for c in classes_or_regions:
+ k = c if not isinstance(c, list) else tuple(c)
+ if isinstance(c, (tuple, list)):
+ mask = seg == c[0]
+ for cc in c[1:]:
+ mask = mask | (seg == cc)
+ all_locs = np.argwhere(mask)
+ else:
+ all_locs = np.argwhere(seg == c)
+ if len(all_locs) == 0:
+ class_locs[k] = []
+ continue
+ target_num_samples = min(num_samples, len(all_locs))
+ target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))
+
+ selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
+ class_locs[k] = selected
+ if verbose:
+ print(c, target_num_samples)
+ return class_locs
+
+ def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager,
+ foreground_intensity_properties_per_channel: dict) -> np.ndarray:
+ for c in range(data.shape[0]):
+ scheme = configuration_manager.normalization_schemes[c]
+ normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"),
+ scheme,
+ 'nnunetv2.preprocessing.normalization')
+ if normalizer_class is None:
+ raise RuntimeError('Unable to locate class \'%s\' for normalization' % scheme)
+ normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c],
+ intensityproperties=foreground_intensity_properties_per_channel[str(c)])
+ data[c] = normalizer.run(data[c], seg[0])
+ return data
+
+ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str,
+ num_processes: int):
+ """
+ data identifier = configuration name in plans. EZ.
+ """
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+
+ assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw"
+
+ plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')
+ assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \
+ "first." % plans_file
+ plans = load_json(plans_file)
+ plans_manager = PlansManager(plans)
+ configuration_manager = plans_manager.get_configuration(configuration_name)
+
+ if self.verbose:
+ print(f'Preprocessing the following configuration: {configuration_name}')
+ if self.verbose:
+ print(configuration_manager)
+
+ dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json')
+ dataset_json = load_json(dataset_json_file)
+
+ output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier)
+
+ if isdir(output_directory):
+ shutil.rmtree(output_directory)
+
+ maybe_mkdir_p(output_directory)
+
+ dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)
+
+ # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames]
+ # output_filenames_truncated = [join(output_directory, i) for i in identifiers]
+
+ # multiprocessing magic.
+ r = []
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ for k in dataset.keys():
+ r.append(p.starmap_async(self.run_case_save,
+ ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
+ plans_manager, configuration_manager,
+ dataset_json),)))
+ remaining = list(range(len(dataset)))
+ # p is pretty nifti. If we kill workers they just respawn but don't do any work.
+ # So we need to store the original pool of workers.
+ workers = [j for j in p._pool]
+ with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar:
+ while len(remaining) > 0:
+ all_alive = all([j.is_alive() for j in workers])
+ if not all_alive:
+                        raise RuntimeError('A background worker died unexpectedly.\n'
+ 'One of your background processes is missing. This could be because of '
+ 'an error (look for an error message) or because it was killed '
+ 'by your OS due to running out of RAM. If you don\'t see '
+ 'an error message, out of RAM is likely the problem. In that case '
+ 'reducing the number of workers might help')
+ done = [i for i in remaining if r[i].ready()]
+ for _ in done:
+ pbar.update()
+ remaining = [i for i in remaining if i not in done]
+ sleep(0.1)
+
+ def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
+ configuration_manager: ConfigurationManager) -> np.ndarray:
+ # this function will be called at the end of self.run_case. Can be used to change the segmentation
+ # after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling
+ # and don't have to create a new dataset each time I modify my experiments
+ return seg
+
+
+def example_test_case_preprocessing():
+ # (paths to files may need adaptations)
+ plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json'
+ dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json'
+ input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz']
+
+ configuration = '3d_fullres'
+ pp = DefaultPreprocessor()
+
+ # _ because this position would be the segmentation if seg_file was not None (training case)
+ # even if you have the segmentation, don't put the file there! You should always evaluate in the original
+ # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as
+ # specified by plans)
+ plans_manager = PlansManager(plans_file)
+ data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager,
+ configuration_manager=plans_manager.get_configuration(configuration),
+ dataset_json=dataset_json_file)
+
+ # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO)
+ return data
+
+
+if __name__ == '__main__':
+ example_test_case_preprocessing()
+ # pp = DefaultPreprocessor()
+ # pp.run(2, '2d', 'nnUNetPlans', 8)
+
+ ###########################################################################################################
+ # how to process a test cases? This is an example:
+ # example_test_case_preprocessing()
diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save
new file mode 100644
index 0000000..ee6c9e6
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save
@@ -0,0 +1,301 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+import shutil
+from time import sleep
+from typing import Union, Tuple
+
+import nnunetv2
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
+from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
+from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
+ create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets
+from tqdm import tqdm
+
+
+class DefaultPreprocessor(object):
+ def __init__(self, verbose: bool = True):
+ self.verbose = verbose
+ """
+ Everything we need is in the plans. Those are given when run() is called
+ """
+
+ def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict,
+ plans_manager: PlansManager, configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ # let's not mess up the inputs!
+ data = np.copy(data)
+ if seg is not None:
+ seg = np.copy(seg)
+
+ has_seg = seg is not None
+
+ # apply transpose_forward, this also needs to be applied to the spacing!
+ data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
+ if seg is not None:
+ seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]])
+ original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward]
+
+ # crop, remember to store size before cropping!
+ shape_before_cropping = data.shape[1:]
+ properties['shape_before_cropping'] = shape_before_cropping
+ # this command will generate a segmentation. This is important because of the nonzero mask which we may need
+ data, seg, bbox = crop_to_nonzero(data, seg)
+ properties['bbox_used_for_cropping'] = bbox
+ # print(data.shape, seg.shape)
+ properties['shape_after_cropping_and_before_resampling'] = data.shape[1:]
+
+ # resample
+ target_spacing = configuration_manager.spacing # this should already be transposed
+
+ if len(target_spacing) < len(data.shape[1:]):
+ # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d
+ # in 3d we do not change the spacing between slices
+ target_spacing = [original_spacing[0]] + target_spacing
+ new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing)
+
+ # normalize
+ # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no
+ # longer fitting the images perfectly!
+ data = self._normalize(data, seg, configuration_manager,
+ plans_manager.foreground_intensity_properties_per_channel)
+
+ # print('current shape', data.shape[1:], 'current_spacing', original_spacing,
+ # '\ntarget shape', new_shape, 'target_spacing', target_spacing)
+ old_shape = data.shape[1:]
+ data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing)
+ print(f"Actual new_shape: {data.shape}", flush=True)
+ print(type(seg))
+ if has_seg:
+ seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing)
+ if self.verbose:
+ print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, '
+ f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}')
+
+ # if we have a segmentation, sample foreground locations for oversampling and add those to properties
+ if has_seg:
+ # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument
+ # with a LabelManager Instance in this function because that's all its used for. Dunno what's better.
+ # LabelManager is pretty light computation-wise.
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ collect_for_this = label_manager.foreground_regions if label_manager.has_regions \
+ else label_manager.foreground_labels
+
+ # when using the ignore label we want to sample only from annotated regions. Therefore we also need to
+ # collect samples uniformly from all classes (incl background)
+ if label_manager.has_ignore_label:
+ collect_for_this.append(label_manager.all_labels)
+
+ # no need to filter background in regions because it is already filtered in handle_labels
+ # print(all_labels, regions)
+ properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this,
+ verbose=self.verbose)
+ seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager)
+ if has_seg:
+ if np.max(seg) > 127:
+ seg = seg.astype(np.int16)
+ else:
+ seg = seg.astype(np.int8)
+        else:
+            seg = None
+ return data, seg
+
+ def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager,
+ configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ """
+ seg file can be none (test cases)
+
+ order of operations is: transpose -> crop -> resample
+ so when we export we need to run the following order: resample -> crop -> transpose (we could also run
+ transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner)
+ """
+ if isinstance(dataset_json, str):
+ dataset_json = load_json(dataset_json)
+
+ rw = plans_manager.image_reader_writer_class()
+
+ # load image(s)
+ data, data_properites = rw.read_images(image_files)
+
+ # if possible, load seg
+ if seg_file is not None:
+ seg, _ = rw.read_seg(seg_file)
+ else:
+ seg = None
+
+ data, seg = self.run_case_npy(data, seg, data_properites, plans_manager, configuration_manager,
+ dataset_json)
+ return data, seg, data_properites
+
+ def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str,
+ plans_manager: PlansManager, configuration_manager: ConfigurationManager,
+ dataset_json: Union[dict, str]):
+ data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json)
+ # print('dtypes', data.dtype, seg.dtype)
+ np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg)
+ write_pickle(properties, output_filename_truncated + '.pkl')
+
+ @staticmethod
+ def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]],
+ seed: int = 1234, verbose: bool = False):
+ num_samples = 10000
+ min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too
+ # sparse
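+        # e.g. a class with 2,000,000 foreground voxels: min(10000, 2000000) = 10000, but the 1% coverage floor is
+        # ceil(2000000 * 0.01) = 20000, so 20000 locations end up being stored for that class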
+ rndst = np.random.RandomState(seed)
+ class_locs = {}
+ for c in classes_or_regions:
+ k = c if not isinstance(c, list) else tuple(c)
+ if isinstance(c, (tuple, list)):
+ mask = seg == c[0]
+ for cc in c[1:]:
+ mask = mask | (seg == cc)
+ all_locs = np.argwhere(mask)
+ else:
+ all_locs = np.argwhere(seg == c)
+ if len(all_locs) == 0:
+ class_locs[k] = []
+ continue
+ target_num_samples = min(num_samples, len(all_locs))
+ target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))
+
+ selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
+ class_locs[k] = selected
+ if verbose:
+ print(c, target_num_samples)
+ return class_locs
+
+ def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager,
+ foreground_intensity_properties_per_channel: dict) -> np.ndarray:
+ for c in range(data.shape[0]):
+ scheme = configuration_manager.normalization_schemes[c]
+ normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"),
+ scheme,
+ 'nnunetv2.preprocessing.normalization')
+ if normalizer_class is None:
+ raise RuntimeError('Unable to locate class \'%s\' for normalization' % scheme)
+ normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c],
+ intensityproperties=foreground_intensity_properties_per_channel[str(c)])
+ data[c] = normalizer.run(data[c], seg[0])
+ return data
+
+ def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str,
+ num_processes: int):
+ """
+ data identifier = configuration name in plans. EZ.
+ """
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+
+ assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw"
+
+ plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json')
+ assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \
+ "first." % plans_file
+ plans = load_json(plans_file)
+ plans_manager = PlansManager(plans)
+ configuration_manager = plans_manager.get_configuration(configuration_name)
+
+ if self.verbose:
+ print(f'Preprocessing the following configuration: {configuration_name}')
+ if self.verbose:
+ print(configuration_manager)
+
+ dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json')
+ dataset_json = load_json(dataset_json_file)
+
+ output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier)
+
+ if isdir(output_directory):
+ shutil.rmtree(output_directory)
+
+ maybe_mkdir_p(output_directory)
+
+ dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json)
+
+ # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames]
+ # output_filenames_truncated = [join(output_directory, i) for i in identifiers]
+
+ # multiprocessing magic.
+ r = []
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ for k in dataset.keys():
+ r.append(p.starmap_async(self.run_case_save,
+ ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'],
+ plans_manager, configuration_manager,
+ dataset_json),)))
+ remaining = list(range(len(dataset)))
+ # p is pretty nifti. If we kill workers they just respawn but don't do any work.
+ # So we need to store the original pool of workers.
+ workers = [j for j in p._pool]
+ with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar:
+ while len(remaining) > 0:
+ all_alive = all([j.is_alive() for j in workers])
+ if not all_alive:
+ raise RuntimeError('Some background worker is 6 feet under. Yuck. \n'
+ 'OK jokes aside.\n'
+ 'One of your background processes is missing. This could be because of '
+ 'an error (look for an error message) or because it was killed '
+ 'by your OS due to running out of RAM. If you don\'t see '
+ 'an error message, out of RAM is likely the problem. In that case '
+ 'reducing the number of workers might help')
+ done = [i for i in remaining if r[i].ready()]
+ for _ in done:
+ pbar.update()
+ remaining = [i for i in remaining if i not in done]
+ sleep(0.1)
+
+ def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
+ configuration_manager: ConfigurationManager) -> np.ndarray:
+ # this function will be called at the end of self.run_case. Can be used to change the segmentation
+ # after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling
+ # and don't have to create a new dataset each time I modify my experiments
+ return seg
+
+
+def example_test_case_preprocessing():
+ # (paths to files may need adaptations)
+ plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json'
+ dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json'
+ input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz']
+
+ configuration = '3d_fullres'
+ pp = DefaultPreprocessor()
+
+ # _ because this position would be the segmentation if seg_file was not None (training case)
+ # even if you have the segmentation, don't put the file there! You should always evaluate in the original
+ # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as
+ # specified by plans)
+ plans_manager = PlansManager(plans_file)
+ data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager,
+ configuration_manager=plans_manager.get_configuration(configuration),
+ dataset_json=dataset_json_file)
+
+ # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO)
+ return data
+
+
+if __name__ == '__main__':
+ example_test_case_preprocessing()
+ # pp = DefaultPreprocessor()
+ # pp.run(2, '2d', 'nnUNetPlans', 8)
+
+ ###########################################################################################################
+    # how to process a test case? This is an example:
+ # example_test_case_preprocessing()
diff --git a/nnUNet/nnunetv2/preprocessing/resampling/__init__.py b/nnUNet/nnunetv2/preprocessing/resampling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py b/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py
new file mode 100644
index 0000000..0079739
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py
@@ -0,0 +1,222 @@
+from collections import OrderedDict
+from typing import Union, Tuple, List
+
+import numpy as np
+import pandas as pd
+import torch
+from batchgenerators.augmentations.utils import resize_segmentation
+from scipy.ndimage.interpolation import map_coordinates
+from skimage.transform import resize
+from nnunetv2.configuration import ANISO_THRESHOLD
+
+
+def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD):
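+    # e.g. a spacing of (3.0, 0.5, 0.5) gives a ratio of 6; with nnU-Net's default ANISO_THRESHOLD of 3 this
+    # counts as anisotropic, so resampling is later done separately along the low-resolution axis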
+    do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
+    print(f"anisotropy ratio {np.max(spacing) / np.min(spacing)} vs. threshold {anisotropy_threshold}: "
+          f"do_separate_z={do_separate_z}", flush=True)
+    return do_separate_z
+
+
+def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]):
+    axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # the axis/axes with the largest spacing, i.e. the low-resolution (out-of-plane) axis
+ return axis
+
+
+def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+ old_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+ new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray:
+ assert len(old_spacing) == len(old_shape)
+ assert len(old_shape) == len(new_spacing)
+ new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)])
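+    # e.g. old_shape (50, 200, 200), old_spacing (3.0, 1.0, 1.0), new_spacing (1.0, 1.0, 1.0)
+    # -> new_shape (150, 200, 200): each axis is scaled by old_spacing / new_spacing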
+ return new_shape
+
+
+def resample_data_or_seg_to_spacing(data: np.ndarray,
+ current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+ new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+ is_seg: bool = False,
+ order: int = 3, order_z: int = 0,
+ force_separate_z: Union[bool, None] = False,
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
+ if force_separate_z is not None:
+ do_separate_z = force_separate_z
+ if force_separate_z:
+ axis = get_lowres_axis(current_spacing)
+ else:
+ axis = None
+ else:
+ if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(current_spacing)
+ elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(new_spacing)
+ else:
+ do_separate_z = False
+ axis = None
+
+ if axis is not None:
+ if len(axis) == 3:
+ # every axis has the same spacing, this should never happen, why is this code here?
+ do_separate_z = False
+ elif len(axis) == 2:
+ # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
+ # separately in the out of plane axis
+ do_separate_z = False
+ else:
+ pass
+
+ if data is not None:
+ assert len(data.shape) == 4, "data must be c x y z"
+
+ shape = np.array(data[0].shape)
+    new_shape = compute_new_shape(shape, current_spacing, new_spacing)  # data[0].shape already excludes the channel axis
+
+ data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
+ return data_reshaped
+
+
+def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray],
+ new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
+ current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+ new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
+ is_seg: bool = False,
+ order: int = 3, order_z: int = 0,
+ force_separate_z: Union[bool, None] = False,
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD):
+ """
+    needed for segmentation export. Stupid, I know. Maybe we can fix that with Leo's new resampling functions
+ """
+ if isinstance(data, torch.Tensor):
+ data = data.cpu().numpy()
+ if force_separate_z is not None:
+ do_separate_z = force_separate_z
+ if force_separate_z:
+ axis = get_lowres_axis(current_spacing)
+ else:
+ axis = None
+ else:
+ if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(current_spacing)
+ elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(new_spacing)
+ else:
+ do_separate_z = False
+ axis = None
+
+ if axis is not None:
+ if len(axis) == 3:
+ # every axis has the same spacing, this should never happen, why is this code here?
+ do_separate_z = False
+ elif len(axis) == 2:
+ # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
+ # separately in the out of plane axis
+ do_separate_z = False
+ else:
+ pass
+
+ if data is not None:
+ assert len(data.shape) == 4, "data must be c x y z"
+
+ data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z)
+ return data_reshaped
+
+
+def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray],
+ is_seg: bool = False, axis: Union[None, int] = None, order: int = 3,
+ do_separate_z: bool = False, order_z: int = 0):
+ """
+ separate_z=True will resample with order 0 along z
+ :param data:
+ :param new_shape:
+ :param is_seg:
+ :param axis:
+ :param order:
+ :param do_separate_z:
+ :param order_z: only applies if do_separate_z is True
+ :return:
+ """
+ assert len(data.shape) == 4, "data must be (c, x, y, z)"
+ assert len(new_shape) == len(data.shape) - 1
+
+ if is_seg:
+ resize_fn = resize_segmentation
+ kwargs = OrderedDict()
+ else:
+ resize_fn = resize
+ kwargs = {'mode': 'edge', 'anti_aliasing': False}
+ dtype_data = data.dtype
+ shape = np.array(data[0].shape)
+ new_shape = np.array(new_shape)
+ if np.any(shape != new_shape):
+ data = data.astype(float)
+ if do_separate_z:
+ print("separate z, order in z is", order_z, "order inplane is", order)
+ assert len(axis) == 1, "only one anisotropic axis supported"
+ axis = axis[0]
+ if axis == 0:
+ new_shape_2d = new_shape[1:]
+ elif axis == 1:
+ new_shape_2d = new_shape[[0, 2]]
+ else:
+ new_shape_2d = new_shape[:-1]
+
+ reshaped_final_data = []
+ for c in range(data.shape[0]):
+ reshaped_data = []
+ for slice_id in range(shape[axis]):
+ if axis == 0:
+ reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs))
+ elif axis == 1:
+ reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs))
+ else:
+ reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs))
+ reshaped_data = np.stack(reshaped_data, axis)
+ if shape[axis] != new_shape[axis]:
+
+                    # The following few lines are blatantly copied and modified from scikit-image's resize()
+ rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
+ orig_rows, orig_cols, orig_dim = reshaped_data.shape
+
+ row_scale = float(orig_rows) / rows
+ col_scale = float(orig_cols) / cols
+ dim_scale = float(orig_dim) / dim
+
+ map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
+ map_rows = row_scale * (map_rows + 0.5) - 0.5
+ map_cols = col_scale * (map_cols + 0.5) - 0.5
+ map_dims = dim_scale * (map_dims + 0.5) - 0.5
+ map_rows = map_rows.astype(np.float32)
+ map_cols = map_cols.astype(np.float32)
+ map_dims = map_dims.astype(np.float32)
+ print(map_rows.dtype, flush=True)
+ print(map_cols.dtype, flush=True)
+ print(map_dims.dtype, flush=True)
+ coord_map = np.array([map_rows, map_cols, map_dims])
+ if not is_seg or order_z == 0:
+ reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z,
+ mode='nearest')[None])
+ else:
+ unique_labels = np.sort(pd.unique(reshaped_data.ravel())) # np.unique(reshaped_data)
+ reshaped = np.zeros(new_shape, dtype=dtype_data)
+
+ for i, cl in enumerate(unique_labels):
+ reshaped_multihot = np.round(
+ map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
+ mode='nearest'))
+ reshaped[reshaped_multihot > 0.5] = cl
+ reshaped_final_data.append(reshaped[None])
+ else:
+ reshaped_final_data.append(reshaped_data[None])
+ reshaped_final_data = np.vstack(reshaped_final_data)
+ else:
+ # print("no separate z, order", order)
+ reshaped = []
+ for c in range(data.shape[0]):
+ reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None])
+ reshaped_final_data = np.vstack(reshaped)
+ return reshaped_final_data.astype(dtype_data)
+ else:
+ # print("no resampling necessary")
+ return data
diff --git a/nnUNet/nnunetv2/preprocessing/resampling/utils.py b/nnUNet/nnunetv2/preprocessing/resampling/utils.py
new file mode 100644
index 0000000..0bff719
--- /dev/null
+++ b/nnUNet/nnunetv2/preprocessing/resampling/utils.py
@@ -0,0 +1,15 @@
+from typing import Callable
+
+import nnunetv2
+from batchgenerators.utilities.file_and_folder_operations import join
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+
+
+def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable:
+ ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn,
+ 'nnunetv2.preprocessing.resampling')
+ if ret is None:
+ raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the "
+ "nnunetv2.preprocessing.resampling module." % resampling_fn)
+ else:
+ return ret
diff --git a/nnUNet/nnunetv2/run/__init__.py b/nnUNet/nnunetv2/run/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/run/load_pretrained_weights.py b/nnUNet/nnunetv2/run/load_pretrained_weights.py
new file mode 100644
index 0000000..b5c51bf
--- /dev/null
+++ b/nnUNet/nnunetv2/run/load_pretrained_weights.py
@@ -0,0 +1,66 @@
+import torch
+from torch._dynamo import OptimizedModule
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+def load_pretrained_weights(network, fname, verbose=False):
+ """
+    Transfers all weights between matching keys in state_dicts. Matching is done by name and we only transfer if the
+    shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps, identified
+    by keys ending with '.seg_layers') are not transferred!
+
+    If the pretrained weights were obtained with a training outside nnU-Net and DDP or torch.compile was used,
+    you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.compile adds
+    '_orig_mod.'. You DO NOT need to worry about this if pretraining was done with nnU-Net as
+ nnUNetTrainer.save_checkpoint takes care of that!
+
+ """
+ saved_model = torch.load(fname)
+ pretrained_dict = saved_model['network_weights']
+
+ skip_strings_in_pretrained = [
+ '.seg_layers.',
+ ]
+
+ if isinstance(network, DDP):
+ mod = network.module
+ else:
+ mod = network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ model_dict = mod.state_dict()
+ # verify that all but the segmentation layers have the same shape
+ for key, _ in model_dict.items():
+ if all([i not in key for i in skip_strings_in_pretrained]):
+ assert key in pretrained_dict, \
+ f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
+ f"compatible with your network."
+ assert model_dict[key].shape == pretrained_dict[key].shape, \
+ f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
+ f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \
+ f"does not seem to be compatible with your network."
+
+ # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained
+ # encoders. Not supported by this function though (see assertions above)
+
+ # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do'
+ # pretrained_dict = {'module.' + k if is_ddp else k: v
+ # for k, v in pretrained_dict.items()
+ # if (('module.' + k if is_ddp else k) in model_dict) and
+ # all([i not in k for i in skip_strings_in_pretrained])}
+
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}
+
+ model_dict.update(pretrained_dict)
+
+ print("################### Loading pretrained weights from file ", fname, '###################')
+ if verbose:
+ print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
+ for key, value in pretrained_dict.items():
+ print(key, 'shape', value.shape)
+ print("################### Done ###################")
+ mod.load_state_dict(model_dict)
+
+
diff --git a/nnUNet/nnunetv2/run/run_finetuning_stunet.py b/nnUNet/nnunetv2/run/run_finetuning_stunet.py
new file mode 100644
index 0000000..f295194
--- /dev/null
+++ b/nnUNet/nnunetv2/run/run_finetuning_stunet.py
@@ -0,0 +1,66 @@
+from unittest.mock import patch
+from torch._dynamo import OptimizedModule
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch
+
+def load_stunet_pretrained_weights(network, fname, verbose=False):
+
+ saved_model = torch.load(fname)
+ print(saved_model.keys())
+    if fname.endswith('pth'):
+        pretrained_dict = saved_model['network_weights']
+    elif fname.endswith('model'):
+        pretrained_dict = saved_model['state_dict']
+    else:
+        raise ValueError(f"Unrecognized checkpoint file ending in {fname}: expected '.pth' or '.model'")
+
+ skip_strings_in_pretrained = [
+ 'seg_outputs',
+ 'conv_blocks_context.0',
+ ]
+
+ if isinstance(network, DDP):
+ mod = network.module
+ else:
+ mod = network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ model_dict = mod.state_dict()
+ # verify that all but the segmentation layers have the same shape
+ for key, _ in model_dict.items():
+ if all([i not in key for i in skip_strings_in_pretrained]):
+ assert key in pretrained_dict, \
+ f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
+ f"compatible with your network."
+ assert model_dict[key].shape == pretrained_dict[key].shape, \
+ f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
+ f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \
+ f"does not seem to be compatible with your network."
+
+ # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained
+ # encoders. Not supported by this function though (see assertions above)
+
+ # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do'
+ # pretrained_dict = {'module.' + k if is_ddp else k: v
+ # for k, v in pretrained_dict.items()
+ # if (('module.' + k if is_ddp else k) in model_dict) and
+ # all([i not in k for i in skip_strings_in_pretrained])}
+
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
+ if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}
+
+ model_dict.update(pretrained_dict)
+
+ print("################### Loading pretrained weights from file ", fname, '###################')
+ if verbose:
+ print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
+ for key, value in pretrained_dict.items():
+ print(key, 'shape', value.shape)
+ print("################### Done ###################")
+ mod.load_state_dict(model_dict)
+
+
+
+if __name__ == '__main__':
+ from run_training import run_training_entry
+ # with patch("run_training.load_pretrained_weights", load_stunet_pretrained_weights):
+ run_training_entry()
diff --git a/nnUNet/nnunetv2/run/run_training.py b/nnUNet/nnunetv2/run/run_training.py
new file mode 100644
index 0000000..7903590
--- /dev/null
+++ b/nnUNet/nnunetv2/run/run_training.py
@@ -0,0 +1,260 @@
+import os
+import socket
+from typing import Union, Optional
+
+import nnunetv2
+import torch.cuda
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json
+from nnunetv2.paths import nnUNet_preprocessed
+# from nnunetv2.run.load_pretrained_weights import load_pretrained_weights
+from nnunetv2.run.run_finetuning_stunet import load_stunet_pretrained_weights as load_pretrained_weights
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from torch.backends import cudnn
+
+
+def find_free_network_port() -> int:
+ """Finds a free port on localhost.
+
+ It is useful in single-node training when we don't want to connect to a real main node but have to set the
+ `MASTER_PORT` environment variable.
+ """
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ port = s.getsockname()[1]
+ s.close()
+ return port
+
+
+def get_trainer_from_args(dataset_name_or_id: Union[int, str],
+ configuration: str,
+ fold: int,
+ trainer_name: str = 'nnUNetTrainer',
+ plans_identifier: str = 'nnUNetPlans',
+ use_compressed: bool = False,
+ device: torch.device = torch.device('cuda')):
+ # load nnunet class and do sanity checks
+ nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+ trainer_name, 'nnunetv2.training.nnUNetTrainer')
+ if nnunet_trainer is None:
+ raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in '
+ f'nnunetv2.training.nnUNetTrainer ('
+ f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere '
+ f'else, please move it there.')
+ assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \
+ 'nnUNetTrainer'
+
+ # handle dataset input. If it's an ID we need to convert to int from string
+ if dataset_name_or_id.startswith('Dataset'):
+ pass
+ else:
+ try:
+ dataset_name_or_id = int(dataset_name_or_id)
+ except ValueError:
+ raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern '
+ f'DatasetXXX_YYY where XXX are the three(!) task ID digits. Your '
+ f'input: {dataset_name_or_id}')
+
+ # initialize nnunet trainer
+ preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
+ plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json')
+ plans = load_json(plans_file)
+ dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
+ nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
+ dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
+ return nnunet_trainer
+
+
+def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool,
+ pretrained_weights_file: str = None):
+ if continue_training and pretrained_weights_file is not None:
+ raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only '
+ 'be used at the beginning of the training.')
+ if continue_training:
+ expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
+ if not isfile(expected_checkpoint_file):
+ expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth')
+ # special case where --c is used to run a previously aborted validation
+ if not isfile(expected_checkpoint_file):
+ expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth')
+ if not isfile(expected_checkpoint_file):
+ print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to "
+ f"continue from. Starting a new training...")
+ expected_checkpoint_file = None
+ elif validation_only:
+ expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth')
+ if not isfile(expected_checkpoint_file):
+ raise RuntimeError(f"Cannot run validation because the training is not finished yet!")
+ else:
+ if pretrained_weights_file is not None:
+ if not nnunet_trainer.was_initialized:
+ nnunet_trainer.initialize()
+ load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True)
+ expected_checkpoint_file = None
+
+ if expected_checkpoint_file is not None:
+ nnunet_trainer.load_checkpoint(expected_checkpoint_file)
+
+
+def setup_ddp(rank, world_size):
+ # initialize the process group
+ dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup_ddp():
+ dist.destroy_process_group()
+
+
+def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val, pretrained_weights, npz, world_size):
+ setup_ddp(rank, world_size)
+ torch.cuda.set_device(torch.device('cuda', dist.get_rank()))
+
+ nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
+ use_compressed)
+
+ if disable_checkpointing:
+ nnunet_trainer.disable_checkpointing = disable_checkpointing
+
+ assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.'
+
+ maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights)
+
+ if torch.cuda.is_available():
+ cudnn.deterministic = False
+ cudnn.benchmark = True
+
+ if not val:
+ nnunet_trainer.run_training()
+
+ nnunet_trainer.perform_actual_validation(npz)
+ cleanup_ddp()
+
+
+def run_training(dataset_name_or_id: Union[str, int],
+ configuration: str, fold: Union[int, str],
+ trainer_class_name: str = 'nnUNetTrainer',
+ plans_identifier: str = 'nnUNetPlans',
+ pretrained_weights: Optional[str] = None,
+ num_gpus: int = 1,
+ use_compressed_data: bool = False,
+ export_validation_probabilities: bool = False,
+ continue_training: bool = False,
+ only_run_validation: bool = False,
+ disable_checkpointing: bool = False,
+ device: torch.device = torch.device('cuda')):
+ if isinstance(fold, str):
+ if fold != 'all':
+ try:
+ fold = int(fold)
+ except ValueError as e:
+                print(f'Unable to convert given value for fold to int: {fold}. fold must be either "all" or an integer!')
+ raise e
+
+ if num_gpus > 1:
+ assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}"
+
+ os.environ['MASTER_ADDR'] = 'localhost'
+ if 'MASTER_PORT' not in os.environ.keys():
+ port = str(find_free_network_port())
+ print(f"using port {port}")
+ os.environ['MASTER_PORT'] = port # str(port)
+
+ mp.spawn(run_ddp,
+ args=(
+ dataset_name_or_id,
+ configuration,
+ fold,
+ trainer_class_name,
+ plans_identifier,
+ use_compressed_data,
+ disable_checkpointing,
+ continue_training,
+ only_run_validation,
+ pretrained_weights,
+ export_validation_probabilities,
+ num_gpus),
+ nprocs=num_gpus,
+ join=True)
+ else:
+ nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
+ plans_identifier, use_compressed_data, device=device)
+
+ if disable_checkpointing:
+ nnunet_trainer.disable_checkpointing = disable_checkpointing
+
+ assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.'
+
+ maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
+
+ if torch.cuda.is_available():
+ cudnn.deterministic = False
+ cudnn.benchmark = True
+
+ if not only_run_validation:
+ nnunet_trainer.run_training()
+
+ nnunet_trainer.perform_actual_validation(export_validation_probabilities)
+
+
+def run_training_entry():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('dataset_name_or_id', type=str,
+ help="Dataset name or ID to train with")
+ parser.add_argument('configuration', type=str,
+ help="Configuration that should be trained")
+ parser.add_argument('fold', type=str,
+ help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.')
+ parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer',
+ help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer')
+ parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
+ help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans')
+ parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
+ help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only '
+ 'be used when actually training. Beta. Use with caution.')
+ parser.add_argument('-num_gpus', type=int, default=1, required=False,
+ help='Specify the number of GPUs to use for training')
+ parser.add_argument("--use_compressed", default=False, action="store_true", required=False,
+ help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed "
+ "data is much more CPU and (potentially) RAM intensive and should only be used if you "
+ "know what you are doing")
+ parser.add_argument('--npz', action='store_true', required=False,
+ help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted '
+ 'segmentations). Needed for finding the best ensemble.')
+ parser.add_argument('--c', action='store_true', required=False,
+ help='[OPTIONAL] Continue training from latest checkpoint')
+ parser.add_argument('--val', action='store_true', required=False,
+ help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.')
+ parser.add_argument('--disable_checkpointing', action='store_true', required=False,
+                        help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out when '
+                             'you don\'t want to flood your hard drive with checkpoints.')
+ parser.add_argument('-device', type=str, default='cuda', required=False,
+ help="Use this to set the device the training should run with. Available options are 'cuda' "
+ "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
+ "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
+ args = parser.parse_args()
+
+ assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
+ if args.device == 'cpu':
+ # let's allow torch to use hella threads
+ import multiprocessing
+ torch.set_num_threads(multiprocessing.cpu_count())
+ device = torch.device('cpu')
+ elif args.device == 'cuda':
+ # multithreading in torch doesn't help nnU-Net if run on GPU
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ device = torch.device('cuda')
+ else:
+ device = torch.device('mps')
+
+ run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
+ args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing,
+ device=device)
+
+
+if __name__ == '__main__':
+ run_training_entry()
diff --git a/nnUNet/nnunetv2/tests/__init__.py b/nnUNet/nnunetv2/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/tests/integration_tests/__init__.py b/nnUNet/nnunetv2/tests/integration_tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py b/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py
new file mode 100644
index 0000000..a1b4df1
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py
@@ -0,0 +1,33 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.paths import nnUNet_preprocessed
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+if __name__ == '__main__':
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids')
+ args = parser.parse_args()
+
+ for d in args.d:
+ dataset_name = maybe_convert_to_dataset_name(d)
+ plans = load_json(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'))
+ plans['configurations']['3d_lowres'] = {
+ "data_identifier": "nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time
+ 'inherits_from': '3d_fullres',
+ "patch_size": [20, 28, 20],
+ "median_image_size_in_voxels": [18.0, 25.0, 18.0],
+ "spacing": [2.0, 2.0, 2.0],
+ "n_conv_per_stage_encoder": [2, 2, 2],
+ "n_conv_per_stage_decoder": [2, 2],
+ "num_pool_per_axis": [2, 2, 2],
+ "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]],
+ "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]],
+ "next_stage": "3d_cascade_fullres"
+ }
+ plans['configurations']['3d_cascade_fullres'] = {
+ 'inherits_from': '3d_fullres',
+ "previous_stage": "3d_lowres"
+ }
+ save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False)
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py b/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py
new file mode 100644
index 0000000..c9fca95
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py
@@ -0,0 +1,19 @@
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import isdir, join
+
+from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed
+
+if __name__ == '__main__':
+ # deletes everything!
+ dataset_names = [
+ 'Dataset996_IntegrationTest_Hippocampus_regions_ignore',
+ 'Dataset997_IntegrationTest_Hippocampus_regions',
+ 'Dataset998_IntegrationTest_Hippocampus_ignore',
+ 'Dataset999_IntegrationTest_Hippocampus',
+ ]
+ for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]:
+ for d in dataset_names:
+ if isdir(join(fld, d)):
+ shutil.rmtree(join(fld, d))
+
diff --git a/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh b/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh
new file mode 100644
index 0000000..3888c1a
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh
@@ -0,0 +1,10 @@
+bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996"
+bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997"
+bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998"
+bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999"
+
+
+bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996"
+bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997"
+bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998"
+bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999"
diff --git a/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh b/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh
new file mode 100644
index 0000000..b5dda42
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh
@@ -0,0 +1,18 @@
+# assumes you are in the nnunet repo!
+
+# prepare raw datasets
+python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py
+python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py
+python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py
+python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py
+
+# now run experiment planning without preprocessing
+nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp
+
+# now add 3d lowres and cascade
+python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999
+
+# now preprocess everything
+nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess the cascade as it's the same data as 3d_fullres
+
+# done
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/tests/integration_tests/readme.md b/nnUNet/nnunetv2/tests/integration_tests/readme.md
new file mode 100644
index 0000000..2a44f13
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/readme.md
@@ -0,0 +1,58 @@
+# Preface
+
+I am just a mortal with many tasks and limited time. Aint nobody got time for unittests.
+
+HOWEVER, at least some integration tests should be performed testing nnU-Net from start to finish.
+
+# Introduction - What the heck is happening?
+This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with
+ignore labels). It runs the entire nnU-Net pipeline from start to finish:
+
+- fingerprint extraction
+- experiment planning
+- preprocessing
+- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV
+- automatically find the best model or ensemble
+- determine the postprocessing used for this
+- predict some test set
+- apply postprocessing to the test set
+
+To speed things up, we do the following:
+- pick Dataset004_Hippocampus because it is quadratisch praktisch gut ("square, practical, good"). MNIST of medical image segmentation
+- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more!
+- we use nnUNetTrainer_5epochs for a short training
+
+# How to run it?
+
+Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!)
+
+Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004!
+```commandline
+bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh
+```
+
+Now you can run the integration test for each of the datasets:
+```commandline
+bash nnunetv2/tests/integration_tests/run_integration_test.sh DATASET_ID
+```
+use DATASET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up.
+This will take, I dunno, something like 10-30 minutes!?
+
+Also run
+```commandline
+bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATASET_ID
+```
+to verify DDP is working (needs 2 GPUs!)
+
+# How to check if the test was successful?
+If I were not as lazy as I am, I would have programmed some automatism that checks whether Dice scores etc. are in an acceptable range.
+So you need to do the following:
+1) check that none of your runs crashed (duh)
+2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file.
+Does it make sense? If so: NICE!
+
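+For a quick look from the command line, something like this works (the path is illustrative, adapt the dataset name):
+
+```commandline
+cat nnUNet_results/Dataset999_IntegrationTest_Hippocampus/inference_information.json
+```
+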
+Once the integration test is completed you can delete all the temporary files associated with it by running:
+
+```commandline
+python nnunetv2/tests/integration_tests/cleanup_integration_test.py
+```
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh
new file mode 100644
index 0000000..ff0426c
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh
@@ -0,0 +1,27 @@
+
+
+nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz
+
+nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz
+
+nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz
+
+nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz
+nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz
+
+python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py
new file mode 100644
index 0000000..89e783e
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py
@@ -0,0 +1,75 @@
+import argparse
+
+import torch
+from batchgenerators.utilities.file_and_folder_operations import join, load_pickle
+
+from nnunetv2.ensembling.ensemble import ensemble_folders
+from nnunetv2.evaluation.find_best_configuration import find_best_configuration, \
+ dumb_trainer_config_plans_to_trained_models_dict
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.paths import nnUNet_raw, nnUNet_results
+from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.file_path_utilities import get_output_folder
+
+
+if __name__ == '__main__':
+ """
+ Predicts the imagesTs folder with the best configuration and applies postprocessing
+ """
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-d', type=int, help='dataset id')
+ args = parser.parse_args()
+ d = args.d
+
+ dataset_name = maybe_convert_to_dataset_name(d)
+ source_dir = join(nnUNet_raw, dataset_name, 'imagesTs')
+ target_dir_base = join(nnUNet_results, dataset_name)
+
+ models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'],
+ ['2d',
+ '3d_lowres',
+ '3d_cascade_fullres',
+ '3d_fullres'],
+ ['nnUNetPlans'])
+ ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True,
+ folds=(0, 1, 2, 3, 4), strict=True)
+
+ has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1
+
+ # we don't use all folds to speed stuff up
+ used_folds = (0, 3)
+ output_folders = []
+ for im in ret['best_model_or_ensemble']['selected_model_or_models']:
+ output_dir = join(target_dir_base, f"pred_{im['configuration']}")
+ model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration'])
+        # note that if the best model is the ensemble of 3d_lowres and 3d cascade then 3d_lowres will be predicted
+ # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the
+ # prediction here. Proper way would be to check for that and
+ # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in
+ # predict_from_raw_data. Since we allow for
+ # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an
+        # integration test after all. Take a closer look at how this is handled in predict_from_raw_data
+ predictor = nnUNetPredictor(verbose=False, allow_tqdm=False)
+ predictor.initialize_from_trained_model_folder(model_folder, used_folds)
+ predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True)
+ # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir,
+ # model_training_output_dir=model_folder, use_folds=used_folds,
+ # save_probabilities=has_ensemble, verbose=False, overwrite=True)
+ output_folders.append(output_dir)
+
+ # if we have an ensemble, we need to ensemble the results
+ if has_ensemble:
+ ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False)
+ folder_for_pp = join(target_dir_base, 'ensemble_predictions')
+ else:
+ folder_for_pp = output_folders[0]
+
+ # apply postprocessing
+ pp_fns, pp_fn_kwargs = load_pickle(ret['best_model_or_ensemble']['postprocessing_file'])
+ apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'),
+ pp_fns,
+ pp_fn_kwargs, plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file'])
diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh
new file mode 100644
index 0000000..5199247
--- /dev/null
+++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh
@@ -0,0 +1 @@
+nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2
diff --git a/nnUNet/nnunetv2/training/__init__.py b/nnUNet/nnunetv2/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/data_augmentation/__init__.py b/nnUNet/nnunetv2/training/data_augmentation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py b/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py
new file mode 100644
index 0000000..a772bc2
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+
+def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
+ if isinstance(rot_x, (tuple, list)):
+ rot_x = max(np.abs(rot_x))
+ if isinstance(rot_y, (tuple, list)):
+ rot_y = max(np.abs(rot_y))
+ if isinstance(rot_z, (tuple, list)):
+ rot_z = max(np.abs(rot_z))
+ rot_x = min(90 / 360 * 2. * np.pi, rot_x)
+ rot_y = min(90 / 360 * 2. * np.pi, rot_y)
+ rot_z = min(90 / 360 * 2. * np.pi, rot_z)
+ from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
+ coords = np.array(final_patch_size)
+ final_shape = np.copy(coords)
+ if len(coords) == 3:
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
+ elif len(coords) == 2:
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
+ final_shape /= min(scale_range)
+ return final_shape.astype(int)
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/__init__.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py
new file mode 100644
index 0000000..378bab2
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py
@@ -0,0 +1,136 @@
+from typing import Union, List, Tuple, Callable
+
+import numpy as np
+from acvl_utils.morphology.morphology_helper import label_with_component_sizes
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+from skimage.morphology import ball
+from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
+
+
+class MoveSegAsOneHotToData(AbstractTransform):
+ def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]],
+ key_origin="seg", key_target="data", remove_from_origin=True):
+ """
+ Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to
+ data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg].
+ """
+ self.remove_from_origin = remove_from_origin
+ self.all_labels = all_labels
+ self.key_target = key_target
+ self.key_origin = key_origin
+ self.index_in_origin = index_in_origin
+
+ def __call__(self, **data_dict):
+ seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1]
+
+ seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]),
+ dtype=data_dict[self.key_target].dtype)
+ for i, l in enumerate(self.all_labels):
+ seg_onehot[:, i][seg[:, 0] == l] = 1
+
+ data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1)
+
+ if self.remove_from_origin:
+ remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin]
+ data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels]
+
+ return data_dict
+
+
+class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
+ def __init__(self, channel_idx: Union[int, List[int]], key: str = "data", p_per_sample: float = 0.2,
+ fill_with_other_class_p: float = 0.25,
+ dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1):
+ """
+ Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components
+ smaller than dont_do_if_covers_more_than_X_percent of the sample. Also has the option of simulating
+ misclassification as another class (fill_with_other_class_p)
+ """
+ self.p_per_label = p_per_label
+ self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent
+ self.fill_with_other_class_p = fill_with_other_class_p
+ self.p_per_sample = p_per_sample
+ self.key = key
+ if not isinstance(channel_idx, (list, tuple)):
+ channel_idx = [channel_idx]
+ self.channel_idx = channel_idx
+
+ def __call__(self, **data_dict):
+ data = data_dict.get(self.key)
+ for b in range(data.shape[0]):
+ if np.random.uniform() < self.p_per_sample:
+ for c in self.channel_idx:
+ if np.random.uniform() < self.p_per_label:
+ # print(np.unique(data[b, c])) ## should be [0, 1]
+ workon = data[b, c].astype(bool)
+ if not np.any(workon):
+ continue
+ num_voxels = np.prod(workon.shape, dtype=np.uint64)
+ lab, component_sizes = label_with_component_sizes(workon.astype(bool))
+ if len(component_sizes) > 0:
+ valid_component_ids = [i for i, j in component_sizes.items() if j <
+ num_voxels*self.dont_do_if_covers_more_than_x_percent]
+ # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c,
+ # np.unique(data[b, c]), len(component_sizes), valid_component_ids,
+ # len(valid_component_ids))
+ if len(valid_component_ids) > 0:
+ random_component = np.random.choice(valid_component_ids)
+ data[b, c][lab == random_component] = 0
+ if np.random.uniform() < self.fill_with_other_class_p:
+ other_ch = [i for i in self.channel_idx if i != c]
+ if len(other_ch) > 0:
+ other_class = np.random.choice(other_ch)
+ data[b, other_class][lab == random_component] = 1
+ data_dict[self.key] = data
+ return data_dict
+
+
+class ApplyRandomBinaryOperatorTransform(AbstractTransform):
+ def __init__(self,
+ channel_idx: Union[int, List[int], Tuple[int, ...]],
+ p_per_sample: float = 0.3,
+ any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening),
+ key: str = "data",
+ strel_size: Tuple[int, int] = (1, 10),
+ p_per_label: float = 1):
+ """
+ Applies random binary operations (specified by any_of_these) with random ball size (radius is uniformly sampled
+        from interval strel_size) to specified channels. Expects the channel_idx to correspond to a one hot encoded
+ segmentation (see for example MoveSegAsOneHotToData)
+ """
+ self.p_per_label = p_per_label
+ self.strel_size = strel_size
+ self.key = key
+ self.any_of_these = any_of_these
+ self.p_per_sample = p_per_sample
+
+ if not isinstance(channel_idx, (list, tuple)):
+ channel_idx = [channel_idx]
+ self.channel_idx = channel_idx
+
+ def __call__(self, **data_dict):
+ for b in range(data_dict[self.key].shape[0]):
+ if np.random.uniform() < self.p_per_sample:
+ # this needs to be applied in random order to the channels
+ np.random.shuffle(self.channel_idx)
+ for c in self.channel_idx:
+ if np.random.uniform() < self.p_per_label:
+ operation = np.random.choice(self.any_of_these)
+ selem = ball(np.random.uniform(*self.strel_size))
+ workon = data_dict[self.key][b, c].astype(bool)
+ if not np.any(workon):
+ continue
+ # print(np.unique(workon))
+ res = operation(workon, selem).astype(data_dict[self.key].dtype)
+ # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res))
+ data_dict[self.key][b, c] = res
+
+ # if class was added, we need to remove it in ALL other channels to keep one hot encoding
+ # properties
+ other_ch = [i for i in self.channel_idx if i != c]
+ if len(other_ch) > 0:
+ was_added_mask = (res - workon) > 0
+ for oc in other_ch:
+ data_dict[self.key][b, oc][was_added_mask] = 0
+ # if class was removed, leave it at background
+ return data_dict
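+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): apply a random morphological operation
+    # to the foreground channels of a small one-hot encoded toy segmentation.
+    toy = np.zeros((1, 3, 24, 24, 24), dtype=np.float32)
+    toy[0, 1, 8:16, 8:16, 8:16] = 1  # class 1 occupies a cube
+    toy[0, 0] = 1 - toy[0, 1:].sum(0)  # background channel completes the one-hot encoding
+    transform = ApplyRandomBinaryOperatorTransform(channel_idx=[1, 2], p_per_sample=1.0, strel_size=(1, 3))
+    out = transform(data=toy)
+    print(out['data'].shape, out['data'][0, 1].sum())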
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py
new file mode 100644
index 0000000..6469ee2
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py
@@ -0,0 +1,55 @@
+from typing import Tuple, Union, List
+
+from batchgenerators.augmentations.utils import resize_segmentation
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+import numpy as np
+
+
+class DownsampleSegForDSTransform2(AbstractTransform):
+ '''
+ data_dict[output_key] will be a list of segmentations scaled according to ds_scales
+ '''
+ def __init__(self, ds_scales: Union[List, Tuple],
+ order: int = 0, input_key: str = "seg",
+ output_key: str = "seg", axes: Tuple[int] = None):
+ """
+ Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
+ output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
+ ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
+ for each axis independently
+ """
+ self.axes = axes
+ self.output_key = output_key
+ self.input_key = input_key
+ self.order = order
+ self.ds_scales = ds_scales
+
+ def __call__(self, **data_dict):
+ if self.axes is None:
+ axes = list(range(2, len(data_dict[self.input_key].shape)))
+ else:
+ axes = self.axes
+
+ output = []
+ for s in self.ds_scales:
+ if not isinstance(s, (tuple, list)):
+ s = [s] * len(axes)
+ else:
+ assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
+ f'for each axis) then the number of entries in that tuple (here ' \
+ f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'
+
+ if all([i == 1 for i in s]):
+ output.append(data_dict[self.input_key])
+ else:
+ new_shape = np.array(data_dict[self.input_key].shape).astype(float)
+ for i, a in enumerate(axes):
+ new_shape[a] *= s[i]
+ new_shape = np.round(new_shape).astype(int)
+ out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
+ for b in range(data_dict[self.input_key].shape[0]):
+ for c in range(data_dict[self.input_key].shape[1]):
+ out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)
+ output.append(out_seg)
+ data_dict[self.output_key] = output
+ return data_dict
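+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): produce two deep supervision targets,
+    # one at full resolution and one at half resolution per spatial axis.
+    seg = np.random.randint(0, 3, size=(2, 1, 8, 16, 16))
+    transform = DownsampleSegForDSTransform2(ds_scales=((1, 1, 1), (0.5, 0.5, 0.5)), order=0)
+    out = transform(seg=seg)['seg']
+    print([o.shape for o in out])  # [(2, 1, 8, 16, 16), (2, 1, 4, 8, 8)]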
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py
new file mode 100644
index 0000000..dd8368c
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py
@@ -0,0 +1,10 @@
+from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+
+
+class LimitedLenWrapper(NonDetMultiThreadedAugmenter):
+ def __init__(self, my_imaginary_length, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.len = my_imaginary_length
+
+ def __len__(self):
+ return self.len
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py
new file mode 100644
index 0000000..587acd7
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py
@@ -0,0 +1,10 @@
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+
+
+class RemoveKeyTransform(AbstractTransform):
+ def __init__(self, key_to_remove: str):
+ self.key_to_remove = key_to_remove
+
+ def __call__(self, **data_dict):
+ _ = data_dict.pop(self.key_to_remove, None)
+ return data_dict
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py
new file mode 100644
index 0000000..b009993
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py
@@ -0,0 +1,22 @@
+from typing import List
+
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+
+
+class MaskTransform(AbstractTransform):
+ def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0,
+ data_key: str = "data", seg_key: str = "seg"):
+ """
+ Sets everything outside the mask to set_outside_to (default 0). CAREFUL! Outside is defined as mask value < 0, not == 0!
+ """
+ self.apply_to_channels = apply_to_channels
+ self.seg_key = seg_key
+ self.data_key = data_key
+ self.set_outside_to = set_outside_to
+ self.mask_idx_in_seg = mask_idx_in_seg
+
+ def __call__(self, **data_dict):
+ mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0
+ for c in self.apply_to_channels:
+ data_dict[self.data_key][:, c][mask] = self.set_outside_to
+ return data_dict
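+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): zero out image voxels wherever the
+    # mask channel of seg is negative (i.e. outside the nonzero region).
+    import numpy as np
+    data = np.random.rand(1, 2, 4, 4).astype(np.float32)
+    seg = np.ones((1, 1, 4, 4), dtype=np.int16)
+    seg[0, 0, :2] = -1  # mark the first two rows as "outside"
+    out = MaskTransform(apply_to_channels=[0, 1])(data=data, seg=seg)
+    print(out['data'][0, 0])  # first two rows are now 0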
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py
new file mode 100644
index 0000000..52d2fc0
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py
@@ -0,0 +1,38 @@
+from typing import List, Tuple, Union
+
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+import numpy as np
+
+
+class ConvertSegmentationToRegionsTransform(AbstractTransform):
+ def __init__(self, regions: Union[List, Tuple],
+ seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
+ """
+ regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,
+ example:
+ regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2
+ :param regions:
+ :param seg_key:
+ :param output_key:
+ """
+ self.seg_channel = seg_channel
+ self.output_key = output_key
+ self.seg_key = seg_key
+ self.regions = regions
+
+ def __call__(self, **data_dict):
+ seg = data_dict.get(self.seg_key)
+ num_regions = len(self.regions)
+ if seg is not None:
+ seg_shp = seg.shape
+ output_shape = list(seg_shp)
+ output_shape[1] = num_regions
+ region_output = np.zeros(output_shape, dtype=seg.dtype)
+ for b in range(seg_shp[0]):
+ for region_id, region_source_labels in enumerate(self.regions):
+ if not isinstance(region_source_labels, (list, tuple)):
+ region_source_labels = (region_source_labels, )
+ for label_value in region_source_labels:
+ region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1
+ data_dict[self.output_key] = region_output
+ return data_dict
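+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net), using the example from the docstring:
+    # regions ((1, 2), (2,)) yield one channel for labels 1&2 together and one for label 2 alone.
+    seg = np.zeros((1, 1, 4, 4), dtype=np.int16)
+    seg[0, 0, :2] = 1
+    seg[0, 0, 2:] = 2
+    out = ConvertSegmentationToRegionsTransform(regions=((1, 2), (2,)))(seg=seg)['seg']
+    print(out.shape, int(out[0, 0].sum()), int(out[0, 1].sum()))  # (1, 2, 4, 4) 16 8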
diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py
new file mode 100644
index 0000000..340fce7
--- /dev/null
+++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py
@@ -0,0 +1,45 @@
+from typing import Tuple, Union, List
+
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+
+
+class Convert3DTo2DTransform(AbstractTransform):
+ def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
+ """
+ Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel
+ """
+ self.apply_to_keys = apply_to_keys
+
+ def __call__(self, **data_dict):
+ for k in self.apply_to_keys:
+ shp = data_dict[k].shape
+ assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.'
+ data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
+ shape_key = f'orig_shape_{k}'
+ assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \
+ f'It does that using the {shape_key} key. That key is ' \
+ f'already taken. Bummer.'
+ data_dict[shape_key] = shp
+ return data_dict
+
+
+class Convert2DTo3DTransform(AbstractTransform):
+ def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')):
+ """
+ Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z)
+ """
+ self.apply_to_keys = apply_to_keys
+
+ def __call__(self, **data_dict):
+ for k in self.apply_to_keys:
+ shape_key = f'orig_shape_{k}'
+ assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. ' \
+ f'Convert2DTo3DTransform only works in tandem with ' \
+ f'Convert3DTo2DTransform and you probably forgot to add ' \
+ f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \
+ f'is where the missing key is generated)'
+ original_shape = data_dict[shape_key]
+ current_shape = data_dict[k].shape
+ data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2],
+ current_shape[-2], current_shape[-1]))
+ return data_dict
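+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): fold the first spatial axis into the
+    # channel axis and restore it again, a round trip through both transforms.
+    import numpy as np
+    data = np.random.rand(2, 1, 8, 16, 16)
+    seg = np.random.randint(0, 3, (2, 1, 8, 16, 16))
+    d = Convert3DTo2DTransform()(data=data, seg=seg)
+    print(d['data'].shape)  # (2, 8, 16, 16)
+    d = Convert2DTo3DTransform()(**d)
+    print(d['data'].shape, bool(np.array_equal(d['seg'], seg)))  # (2, 1, 8, 16, 16) True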
diff --git a/nnUNet/nnunetv2/training/dataloading/__init__.py b/nnUNet/nnunetv2/training/dataloading/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/dataloading/base_data_loader.py b/nnUNet/nnunetv2/training/dataloading/base_data_loader.py
new file mode 100644
index 0000000..6a6a49f
--- /dev/null
+++ b/nnUNet/nnunetv2/training/dataloading/base_data_loader.py
@@ -0,0 +1,139 @@
+from typing import Union, Tuple
+
+from batchgenerators.dataloading.data_loader import DataLoader
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+from nnunetv2.utilities.label_handling.label_handling import LabelManager
+
+
+class nnUNetDataLoaderBase(DataLoader):
+ def __init__(self,
+ data: nnUNetDataset,
+ batch_size: int,
+ patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
+ final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
+ label_manager: LabelManager,
+ oversample_foreground_percent: float = 0.0,
+ sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,
+ pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None,
+ probabilistic_oversampling: bool = False):
+ super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities)
+ assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports dictionaries as data'
+ self.indices = list(data.keys())
+
+ self.oversample_foreground_percent = oversample_foreground_percent
+ self.final_patch_size = final_patch_size
+ self.patch_size = patch_size
+ self.list_of_keys = list(self._data.keys())
+ # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size
+ # (which is what the network will get) these patches will also cover the border of the images
+ self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
+ if pad_sides is not None:
+ if not isinstance(pad_sides, np.ndarray):
+ pad_sides = np.array(pad_sides)
+ self.need_to_pad += pad_sides
+ self.num_channels = None
+ self.pad_sides = pad_sides
+ self.data_shape, self.seg_shape = self.determine_shapes()
+ self.sampling_probabilities = sampling_probabilities
+ self.annotated_classes_key = tuple(label_manager.all_labels)
+ self.has_ignore = label_manager.has_ignore_label
+ self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \
+ else self._probabilistic_oversampling
+
+ def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
+ """
+ determines whether sample sample_idx in a minibatch needs to be guaranteed foreground
+ """
+ return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))
+
+ def _probabilistic_oversampling(self, sample_idx: int) -> bool:
+ # print('YEAH BOIIIIII')
+ return np.random.uniform() < self.oversample_foreground_percent
+
+ def determine_shapes(self):
+ # load one case
+ data, seg, properties = self._data.load_case(self.indices[0])
+ num_color_channels = data.shape[0]
+
+ data_shape = (self.batch_size, num_color_channels, *self.patch_size)
+ seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
+ return data_shape, seg_shape
+
+ def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
+ overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
+ # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have
+ # locations for the given slice
+ need_to_pad = self.need_to_pad.copy()
+ dim = len(data_shape)
+
+ for d in range(dim):
+ # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
+ # always
+ if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
+ need_to_pad[d] = self.patch_size[d] - data_shape[d]
+
+ # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
+ # define what the upper and lower bound can be to then sample form them with np.random.randint
+ lbs = [- need_to_pad[i] // 2 for i in range(dim)]
+ ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]
+
+ # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
+ # at least one of the foreground classes in the patch
+ if not force_fg and not self.has_ignore:
+ bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
+ # print('I want a random location')
+ else:
+ if not force_fg and self.has_ignore:
+ selected_class = self.annotated_classes_key
+ if len(class_locations[selected_class]) == 0:
+ # no annotated pixels in this case. Not good. But we can hardly skip it here
+ print('Warning! No annotated pixels in image!')
+ selected_class = None
+ # print(f'I have ignore labels and want to pick a labeled area. annotated_classes_key: {self.annotated_classes_key}')
+ elif force_fg:
+ assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
+ if overwrite_class is not None:
+ assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
+ 'have class_locations (missing key)'
+ # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
+ # class_locations keys can also be tuple
+ eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0]
+
+ # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
+ # strange formulation needed to circumvent
+ # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
+ tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
+ if any(tmp):
+ if len(eligible_classes_or_regions) > 1:
+ eligible_classes_or_regions.pop(np.where(tmp)[0][0])
+
+ if len(eligible_classes_or_regions) == 0:
+ # this only happens if some image does not contain foreground voxels at all
+ selected_class = None
+ if verbose:
+ print('case does not contain any foreground classes')
+ else:
+ # I hate myself. Future me aint gonna be happy to read this
+ # 2022_11_25: had to read it today. Wasn't too bad
+ selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
+ (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class
+ # print(f'I want to have foreground, selected class: {selected_class}')
+ else:
+ raise RuntimeError('unexpected combination of force_fg and has_ignore in get_bbox')
+ voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None
+
+ if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
+ selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
+ # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
+ # Make sure it is within the bounds of lb and ub
+ # i + 1 because we have first dimension 0!
+ bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]
+ else:
+ # If the image does not contain any foreground classes, we fall back to random cropping
+ bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)]
+
+ bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)]
+
+ return bbox_lbs, bbox_ubs
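+
+
+if __name__ == '__main__':
+    # Rough numeric sketch (not part of upstream nnU-Net) of the need_to_pad / bbox bounds
+    # computed in get_bbox for a single axis: patch_size 64 on an image of size 50.
+    patch_size, image_size = 64, 50
+    need_to_pad = max(0, patch_size - image_size)
+    lb = -need_to_pad // 2
+    ub = image_size + need_to_pad // 2 + need_to_pad % 2 - patch_size
+    # lb == ub == -7: the only valid patch starts 7 voxels "before" the image and ends at 57,
+    # so the whole image is covered once the data loader applies padding.
+    print(need_to_pad, lb, ub)  # 14 -7 -7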
diff --git a/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py b/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py
new file mode 100644
index 0000000..b44004f
--- /dev/null
+++ b/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py
@@ -0,0 +1,93 @@
+import numpy as np
+from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+
+
+class nnUNetDataLoader2D(nnUNetDataLoaderBase):
+ def generate_train_batch(self):
+ selected_keys = self.get_indices()
+ # preallocate memory for data and seg
+ data_all = np.zeros(self.data_shape, dtype=np.float32)
+ seg_all = np.zeros(self.seg_shape, dtype=np.int16)
+ case_properties = []
+
+ for j, current_key in enumerate(selected_keys):
+ # oversampling foreground will improve stability of model training, especially if many patches are empty
+ # (Lung for example)
+ force_fg = self.get_do_oversample(j)
+ data, seg, properties = self._data.load_case(current_key)
+
+ # select a class/region first, then a slice where this class is present, then crop to that area
+ if not force_fg:
+ if self.has_ignore:
+ selected_class_or_region = self.annotated_classes_key
+ else:
+ selected_class_or_region = None
+ else:
+ # filter out all classes that are not present here
+ eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0]
+
+ # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list
+ # strange formulation needed to circumvent
+ # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
+ tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions]
+ if any(tmp):
+ if len(eligible_classes_or_regions) > 1:
+ eligible_classes_or_regions.pop(np.where(tmp)[0][0])
+
+ selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \
+ len(eligible_classes_or_regions) > 0 else None
+ if selected_class_or_region is not None:
+ selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1])
+ else:
+ selected_slice = np.random.choice(len(data[0]))
+
+ data = data[:, selected_slice]
+ seg = seg[:, selected_slice]
+
+ # the line of death lol
+ # this needs to be a separate variable because we could otherwise permanently overwrite
+ # properties['class_locations']
+ # selected_class_or_region is:
+ # - None if we do not have an ignore label and force_fg is False OR if force_fg is True but there is no foreground in the image
+ # - A tuple of all (non-ignore) labels if there is an ignore label and force_fg is False
+ # - a class or region if force_fg is True
+ class_locations = {
+ selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)]
+ } if (selected_class_or_region is not None) else None
+
+ # print(properties)
+ shape = data.shape[1:]
+ dim = len(shape)
+ bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None,
+ class_locations, overwrite_class=selected_class_or_region)
+
+ # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
+ # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
+ # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
+ # later
+ valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)]
+ valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)]
+
+ # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
+ # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
+ # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
+ # remove label -1 in the data augmentation but this way it is less error prone)
+ this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
+ data = data[this_slice]
+
+ this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
+ seg = seg[this_slice]
+
+ padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)]
+ data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
+ seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)
+
+ return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
+
+
+if __name__ == '__main__':
+ folder = '/media/fabian/data/nnUNet_preprocessed/Dataset004_Hippocampus/2d'
+ ds = nnUNetDataset(folder, None, 1000) # this should not load the properties!
+ dl = nnUNetDataLoader2D(ds, 366, (65, 65), (56, 40), 0.33, None, None)
+ a = next(dl)
diff --git a/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py b/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py
new file mode 100644
index 0000000..ab755e3
--- /dev/null
+++ b/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py
@@ -0,0 +1,55 @@
+import numpy as np
+from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+
+
+class nnUNetDataLoader3D(nnUNetDataLoaderBase):
+ def generate_train_batch(self):
+ selected_keys = self.get_indices()
+ # preallocate memory for data and seg
+ data_all = np.zeros(self.data_shape, dtype=np.float32)
+ seg_all = np.zeros(self.seg_shape, dtype=np.int16)
+ case_properties = []
+
+ for j, i in enumerate(selected_keys):
+ # oversampling foreground will improve stability of model training, especially if many patches are empty
+ # (Lung for example)
+ force_fg = self.get_do_oversample(j)
+
+ data, seg, properties = self._data.load_case(i)
+
+ # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by
+ # self._data.load_case(i) (see nnUNetDataset.load_case)
+ shape = data.shape[1:]
+ dim = len(shape)
+ bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations'])
+
+ # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
+ # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
+ # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
+ # later
+ valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)]
+ valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)]
+
+ # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
+ # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
+ # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
+ # remove label -1 in the data augmentation but this way it is less error prone)
+ this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
+ data = data[this_slice]
+
+ this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)])
+ seg = seg[this_slice]
+
+ padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)]
+ data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0)
+ seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1)
+
+ return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys}
+
+
+if __name__ == '__main__':
+ folder = '/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/3d_fullres'
+ ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) # this should not load the properties!
+ dl = nnUNetDataLoader3D(ds, 5, (16, 16, 16), (16, 16, 16), 0.33, None, None)
+ a = next(dl)
diff --git a/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py b/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py
new file mode 100644
index 0000000..ae27fc3
--- /dev/null
+++ b/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py
@@ -0,0 +1,146 @@
+import os
+from typing import List
+
+import numpy as np
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile
+from nnunetv2.training.dataloading.utils import get_case_identifiers
+
+
+class nnUNetDataset(object):
+ def __init__(self, folder: str, case_identifiers: List[str] = None,
+ num_images_properties_loading_threshold: int = 0,
+ folder_with_segs_from_previous_stage: str = None):
+ """
+ This does not actually load the dataset. It merely creates a dictionary where the keys are training case names and
+ the values are dictionaries containing the relevant information for that case.
+ dataset[training_case] -> info
+ Info has the following key:value pairs:
+ - dataset[case_identifier]['data_file'] -> the full path to the npz file associated with the training case
+ - dataset[case_identifier]['properties_file'] -> the pkl file containing the case properties
+
+ In addition, if the total number of cases is < num_images_properties_loading_threshold we load all the pickle files
+ (containing auxiliary information). This is done for small datasets so that we don't spend too much CPU time on
+ reading pkl files on the fly during training. However, for large datasets storing all the aux info (which also
+ contains locations of foreground voxels in the images) can cause too much RAM utilization. In that
+ case it is better to load them on the fly.
+
+ If properties are loaded into the RAM, the info dicts each will have an additional entry:
+ - dataset[case_identifier]['properties'] -> pkl file content
+
+ IMPORTANT! THIS CLASS ITSELF IS READ-ONLY. YOU CANNOT ADD KEY:VALUE PAIRS WITH nnUNetDataset[key] = value
+ USE THIS INSTEAD:
+ nnUNetDataset.dataset[key] = value
+ (not sure why you'd want to do that though. So don't do it)
+ """
+ super().__init__()
+ # print('loading dataset')
+ if case_identifiers is None:
+ case_identifiers = get_case_identifiers(folder)
+ case_identifiers.sort()
+
+ self.dataset = {}
+ for c in case_identifiers:
+ self.dataset[c] = {}
+ self.dataset[c]['data_file'] = join(folder, "%s.npz" % c)
+ self.dataset[c]['properties_file'] = join(folder, "%s.pkl" % c)
+ if folder_with_segs_from_previous_stage is not None:
+ self.dataset[c]['seg_from_prev_stage_file'] = join(folder_with_segs_from_previous_stage, "%s.npz" % c)
+
+ if len(case_identifiers) <= num_images_properties_loading_threshold:
+ for i in self.dataset.keys():
+ self.dataset[i]['properties'] = load_pickle(self.dataset[i]['properties_file'])
+
+ self.keep_files_open = ('nnUNet_keep_files_open' in os.environ.keys()) and \
+ (os.environ['nnUNet_keep_files_open'].lower() in ('true', '1', 't'))
+ # print(f'nnUNetDataset.keep_files_open: {self.keep_files_open}')
+
+ def __getitem__(self, key):
+ ret = {**self.dataset[key]}
+ if 'properties' not in ret.keys():
+ ret['properties'] = load_pickle(ret['properties_file'])
+ return ret
+
+ def __setitem__(self, key, value):
+ return self.dataset.__setitem__(key, value)
+
+ def keys(self):
+ return self.dataset.keys()
+
+ def __len__(self):
+ return self.dataset.__len__()
+
+ def items(self):
+ return self.dataset.items()
+
+ def values(self):
+ return self.dataset.values()
+
+ def load_case(self, key):
+ entry = self[key]
+ if 'open_data_file' in entry.keys():
+ data = entry['open_data_file']
+ # print('using open data file')
+ elif isfile(entry['data_file'][:-4] + ".npy"):
+ data = np.load(entry['data_file'][:-4] + ".npy", 'r')
+ if self.keep_files_open:
+ self.dataset[key]['open_data_file'] = data
+ # print('saving open data file')
+ else:
+ data = np.load(entry['data_file'])['data']
+
+ if 'open_seg_file' in entry.keys():
+ seg = entry['open_seg_file']
+ # print('using open data file')
+ elif isfile(entry['data_file'][:-4] + "_seg.npy"):
+ seg = np.load(entry['data_file'][:-4] + "_seg.npy", 'r')
+ if self.keep_files_open:
+ self.dataset[key]['open_seg_file'] = seg
+ # print('saving open seg file')
+ else:
+ seg = np.load(entry['data_file'])['seg']
+
+ if 'seg_from_prev_stage_file' in entry.keys():
+ if isfile(entry['seg_from_prev_stage_file'][:-4] + ".npy"):
+ seg_prev = np.load(entry['seg_from_prev_stage_file'][:-4] + ".npy", 'r')
+ else:
+ seg_prev = np.load(entry['seg_from_prev_stage_file'])['seg']
+ seg = np.vstack((seg, seg_prev[None]))
+
+ return data, seg, entry['properties']
+
+
+if __name__ == '__main__':
+ # this is a mini test. Todo: We can move this to tests in the future (requires simulated dataset)
+
+ folder = '/media/fabian/data/nnUNet_preprocessed/Dataset003_Liver/3d_lowres'
+ ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) # this should not load the properties!
+ # this SHOULD HAVE the properties
+ ks = ds['liver_0'].keys()
+ assert 'properties' in ks
+ # amazing. I am the best.
+
+ # this should have the properties
+ ds = nnUNetDataset(folder, num_images_properties_loading_threshold=1000)
+ # now rename the properties file so that it doesnt exist anymore
+ shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl'))
+ # now we should still be able to access the properties because they have already been loaded
+ ks = ds['liver_0'].keys()
+ assert 'properties' in ks
+ # move file back
+ shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl'))
+
+ # this should not have the properties
+ ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0)
+ # now rename the properties file so that it doesnt exist anymore
+ shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl'))
+ # now this should crash
+ try:
+ ks = ds['liver_0'].keys()
+ raise RuntimeError('we should not have come here')
+ except FileNotFoundError:
+ print('all good')
+ # move file back
+ shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl'))
+
diff --git a/nnUNet/nnunetv2/training/dataloading/utils.py b/nnUNet/nnunetv2/training/dataloading/utils.py
new file mode 100644
index 0000000..bd145b4
--- /dev/null
+++ b/nnUNet/nnunetv2/training/dataloading/utils.py
@@ -0,0 +1,48 @@
+import multiprocessing
+import os
+from multiprocessing import Pool
+from typing import List
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles
+from nnunetv2.configuration import default_num_processes
+
+
+def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None:
+ try:
+ a = np.load(npz_file) # inexpensive, no compression is done here. This just reads metadata
+ if overwrite_existing or not isfile(npz_file[:-3] + "npy"):
+ np.save(npz_file[:-3] + "npy", a['data'])
+ if unpack_segmentation and (overwrite_existing or not isfile(npz_file[:-4] + "_seg.npy")):
+ np.save(npz_file[:-4] + "_seg.npy", a['seg'])
+ except KeyboardInterrupt:
+ if isfile(npz_file[:-3] + "npy"):
+ os.remove(npz_file[:-3] + "npy")
+ if isfile(npz_file[:-4] + "_seg.npy"):
+ os.remove(npz_file[:-4] + "_seg.npy")
+ raise KeyboardInterrupt
+
+
+def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False,
+ num_processes: int = default_num_processes):
+ """
+ all npz files in this folder belong to the dataset, unpack them all
+ """
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ npz_files = subfiles(folder, True, None, ".npz", True)
+ p.starmap(_convert_to_npy, zip(npz_files,
+ [unpack_segmentation] * len(npz_files),
+ [overwrite_existing] * len(npz_files))
+ )
+
+
+def get_case_identifiers(folder: str) -> List[str]:
+ """
+ finds all npz files in the given folder and reconstructs the training case names from them
+ """
+ case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz") and (i.find("segFromPrevStage") == -1)]
+ return case_identifiers
+
+
+if __name__ == '__main__':
+ unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d')
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/training/logging/__init__.py b/nnUNet/nnunetv2/training/logging/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/logging/nnunet_logger.py b/nnUNet/nnunetv2/training/logging/nnunet_logger.py
new file mode 100644
index 0000000..8409738
--- /dev/null
+++ b/nnUNet/nnunetv2/training/logging/nnunet_logger.py
@@ -0,0 +1,103 @@
+import matplotlib
+from batchgenerators.utilities.file_and_folder_operations import join
+
+matplotlib.use('agg')
+import seaborn as sns
+import matplotlib.pyplot as plt
+
+
+class nnUNetLogger(object):
+ """
+ This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems
+ arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a
+ little
+
+ YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS!
+ """
+ def __init__(self, verbose: bool = False):
+ self.my_fantastic_logging = {
+ 'mean_fg_dice': list(),
+ 'ema_fg_dice': list(),
+ 'dice_per_class_or_region': list(),
+ 'train_losses': list(),
+ 'val_losses': list(),
+ 'lrs': list(),
+ 'epoch_start_timestamps': list(),
+ 'epoch_end_timestamps': list()
+ }
+ self.verbose = verbose
+ # shut up, this logging is great
+
+ def log(self, key, value, epoch: int):
+ """
+ sometimes the logged lists and the epoch counter get out of sync. We try to catch that here
+ """
+ assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \
+ 'This function is only intended to log stuff to lists and to have one entry per epoch'
+
+ if self.verbose: print(f'logging {key}: {value} for epoch {epoch}')
+
+ if len(self.my_fantastic_logging[key]) < (epoch + 1):
+ self.my_fantastic_logging[key].append(value)
+ else:
+ assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \
+ 'lists length is off by more than 1'
+ print(f'maybe some logging issue!? logging {key} and {value}')
+ self.my_fantastic_logging[key][epoch] = value
+
+ # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice
+ if key == 'mean_fg_dice':
+ new_ema_pseudo_dice = self.my_fantastic_logging['ema_fg_dice'][epoch - 1] * 0.9 + 0.1 * value \
+ if len(self.my_fantastic_logging['ema_fg_dice']) > 0 else value
+ self.log('ema_fg_dice', new_ema_pseudo_dice, epoch)
+
+ def plot_progress_png(self, output_folder):
+ # we infer the epoch from our internal logging
+ epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1
+ sns.set(font_scale=2.5)
+ fig, ax_all = plt.subplots(3, 1, figsize=(30, 54))
+ # regular progress.png as we are used to from previous nnU-Net versions
+ ax = ax_all[0]
+ ax2 = ax.twinx()
+ x_values = list(range(epoch + 1))
+ ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4)
+ ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4)
+ ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice",
+ linewidth=3)
+ ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. avg.)",
+ linewidth=4)
+ ax.set_xlabel("epoch")
+ ax.set_ylabel("loss")
+ ax2.set_ylabel("pseudo dice")
+ ax.legend(loc=(0, 1))
+ ax2.legend(loc=(0.2, 1))
+
+ # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs
+ # clogging up the system)
+ ax = ax_all[1]
+ ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1],
+ self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b',
+ ls='-', label="epoch duration", linewidth=4)
+ ylim = [0] + [ax.get_ylim()[1]]
+ ax.set(ylim=ylim)
+ ax.set_xlabel("epoch")
+ ax.set_ylabel("time [s]")
+ ax.legend(loc=(0, 1))
+
+ # learning rate
+ ax = ax_all[2]
+ ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4)
+ ax.set_xlabel("epoch")
+ ax.set_ylabel("learning rate")
+ ax.legend(loc=(0, 1))
+
+ plt.tight_layout()
+
+ fig.savefig(join(output_folder, "progress.png"))
+ plt.close()
+
+ def get_checkpoint(self):
+ return self.my_fantastic_logging
+
+ def load_checkpoint(self, checkpoint: dict):
+ self.my_fantastic_logging = checkpoint
diff --git a/nnUNet/nnunetv2/training/loss/__init__.py b/nnUNet/nnunetv2/training/loss/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/loss/compound_losses.py b/nnUNet/nnunetv2/training/loss/compound_losses.py
new file mode 100644
index 0000000..9db0a42
--- /dev/null
+++ b/nnUNet/nnunetv2/training/loss/compound_losses.py
@@ -0,0 +1,151 @@
+import torch
+from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss
+from nnunetv2.utilities.helpers import softmax_helper_dim1
+from torch import nn
+
+
+class DC_and_CE_loss(nn.Module):
+ def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None,
+ dice_class=SoftDiceLoss):
+ """
+ Weights for CE and Dice do not need to sum to one. You can set whatever you want.
+ :param soft_dice_kwargs:
+ :param ce_kwargs:
+ :param aggregate:
+ :param square_dice:
+ :param weight_ce:
+ :param weight_dice:
+ """
+ super(DC_and_CE_loss, self).__init__()
+ if ignore_label is not None:
+ ce_kwargs['ignore_index'] = ignore_label
+
+ self.weight_dice = weight_dice
+ self.weight_ce = weight_ce
+ self.ignore_label = ignore_label
+
+ self.ce = RobustCrossEntropyLoss(**ce_kwargs)
+ self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
+
+ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
+ """
+ target must be b, c, x, y(, z) with c=1
+ :param net_output:
+ :param target:
+ :return:
+ """
+ if self.ignore_label is not None:
+ assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
+ '(DC_and_CE_loss)'
+ mask = (target != self.ignore_label).bool()
+ # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
+ # ignore gradients in those areas anyway
+ target_dice = torch.clone(target)
+ target_dice[target == self.ignore_label] = 0
+ num_fg = mask.sum()
+ else:
+ target_dice = target
+ mask = None
+
+ dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
+ if self.weight_dice != 0 else 0
+ ce_loss = self.ce(net_output, target[:, 0].long()) \
+ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
+
+ result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
+ return result
+
+
+class DC_and_BCE_loss(nn.Module):
+ def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False,
+ dice_class=MemoryEfficientSoftDiceLoss):
+ """
+ DO NOT APPLY NONLINEARITY IN YOUR NETWORK!
+
+ target must be one hot encoded
+ IMPORTANT: if use_ignore_label is True, the ignore channel is assumed to be target[:, -1]!
+
+ :param soft_dice_kwargs:
+ :param bce_kwargs:
+ :param aggregate:
+ """
+ super(DC_and_BCE_loss, self).__init__()
+ if use_ignore_label:
+ bce_kwargs['reduction'] = 'none'
+
+ self.weight_dice = weight_dice
+ self.weight_ce = weight_ce
+ self.use_ignore_label = use_ignore_label
+
+ self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
+ self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)
+
+ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
+ if self.use_ignore_label:
+ # target is one hot encoded here. invert it so that it is True wherever we can compute the loss
+ mask = (1 - target[:, -1:]).bool()
+ # remove ignore channel now that we have the mask
+ target_regions = torch.clone(target[:, :-1])
+ else:
+ target_regions = target
+ mask = None
+
+ dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
+ if mask is not None:
+ ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)
+ else:
+ ce_loss = self.ce(net_output, target_regions)
+ result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
+ return result
+
+
+class DC_and_topk_loss(nn.Module):
+ def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None):
+ """
+ Weights for CE and Dice do not need to sum to one. You can set whatever you want.
+ :param soft_dice_kwargs:
+ :param ce_kwargs:
+ :param aggregate:
+ :param square_dice:
+ :param weight_ce:
+ :param weight_dice:
+ """
+ super().__init__()
+ if ignore_label is not None:
+ ce_kwargs['ignore_index'] = ignore_label
+
+ self.weight_dice = weight_dice
+ self.weight_ce = weight_ce
+ self.ignore_label = ignore_label
+
+ self.ce = TopKLoss(**ce_kwargs)
+ self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs)
+
+ def forward(self, net_output: torch.Tensor, target: torch.Tensor):
+ """
+ target must be b, c, x, y(, z) with c=1
+ :param net_output:
+ :param target:
+ :return:
+ """
+ if self.ignore_label is not None:
+ assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \
+ '(DC_and_topk_loss)'
+ mask = (target != self.ignore_label).bool()
+ # remove ignore label from target, replace with one of the known labels. It doesn't matter because we
+ # ignore gradients in those areas anyway
+ target_dice = torch.clone(target)
+ target_dice[target == self.ignore_label] = 0
+ num_fg = mask.sum()
+ else:
+ target_dice = target
+ mask = None
+
+ dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \
+ if self.weight_dice != 0 else 0
+ ce_loss = self.ce(net_output, target) \
+ if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0
+
+ result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
+ return result
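+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): combined Dice + CE on random logits and
+    # an integer label map of shape (b, 1, x, y, z).
+    pred = torch.rand((2, 4, 16, 16, 16))
+    target = torch.randint(0, 4, (2, 1, 16, 16, 16))
+    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'ddp': False}, {},
+                          weight_ce=1, weight_dice=1)
+    print(loss(pred, target))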
diff --git a/nnUNet/nnunetv2/training/loss/deep_supervision.py b/nnUNet/nnunetv2/training/loss/deep_supervision.py
new file mode 100644
index 0000000..db71e80
--- /dev/null
+++ b/nnUNet/nnunetv2/training/loss/deep_supervision.py
@@ -0,0 +1,35 @@
+from torch import nn
+
+
+class DeepSupervisionWrapper(nn.Module):
+ def __init__(self, loss, weight_factors=None):
+ """
+ Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of
+ inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then
+ applied to each entry like this:
+ l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ...
+ If weights are None, all w will be 1.
+ """
+ super(DeepSupervisionWrapper, self).__init__()
+ self.weight_factors = weight_factors
+ self.loss = loss
+
+ def forward(self, *args):
+ for i in args:
+ assert isinstance(i, (tuple, list)), "all args must be either tuple or list, got %s" % type(i)
+ # we could check for equal lengths here as well but we really shouldn't overdo it with checks because
+ # this code is executed a lot of times!
+
+ if self.weight_factors is None:
+ weights = [1] * len(args[0])
+ else:
+ weights = self.weight_factors
+
+ # we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's
+ # really necessary
+ l = weights[0] * self.loss(*[j[0] for j in args])
+ for i, inputs in enumerate(zip(*args)):
+ if i == 0:
+ continue
+ l += weights[i] * self.loss(*inputs)
+ return l
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/training/loss/dice.py b/nnUNet/nnunetv2/training/loss/dice.py
new file mode 100644
index 0000000..84c1e16
--- /dev/null
+++ b/nnUNet/nnunetv2/training/loss/dice.py
@@ -0,0 +1,194 @@
+from typing import Callable
+
+import torch
+from nnunetv2.utilities.ddp_allgather import AllGatherGrad
+from nnunetv2.utilities.tensor_utilities import sum_tensor
+from torch import nn
+
+
+class SoftDiceLoss(nn.Module):
+ def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
+ ddp: bool = True, clip_tp: float = None):
+ """
+ """
+ super(SoftDiceLoss, self).__init__()
+
+ self.do_bg = do_bg
+ self.batch_dice = batch_dice
+ self.apply_nonlin = apply_nonlin
+ self.smooth = smooth
+ self.clip_tp = clip_tp
+ self.ddp = ddp
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x = x.shape
+
+ if self.batch_dice:
+ axes = [0] + list(range(2, len(shp_x)))
+ else:
+ axes = list(range(2, len(shp_x)))
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
+
+ if self.ddp and self.batch_dice:
+ tp = AllGatherGrad.apply(tp).sum(0)
+ fp = AllGatherGrad.apply(fp).sum(0)
+ fn = AllGatherGrad.apply(fn).sum(0)
+
+ if self.clip_tp is not None:
+ tp = torch.clip(tp, min=self.clip_tp , max=None)
+
+ nominator = 2 * tp
+ denominator = 2 * tp + fp + fn
+
+ dc = (nominator + self.smooth) / (torch.clip(denominator + self.smooth, 1e-8))
+
+ if not self.do_bg:
+ if self.batch_dice:
+ dc = dc[1:]
+ else:
+ dc = dc[:, 1:]
+ dc = dc.mean()
+
+ return -dc
+
+
+class MemoryEfficientSoftDiceLoss(nn.Module):
+ def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.,
+ ddp: bool = True):
+ """
+ saves 1.6 GB on Dataset017 3d_lowres
+ """
+ super(MemoryEfficientSoftDiceLoss, self).__init__()
+
+ self.do_bg = do_bg
+ self.batch_dice = batch_dice
+ self.apply_nonlin = apply_nonlin
+ self.smooth = smooth
+ self.ddp = ddp
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x, shp_y = x.shape, y.shape
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ if not self.do_bg:
+ x = x[:, 1:]
+
+ # make everything shape (b, c)
+ axes = list(range(2, len(shp_x)))
+
+ with torch.no_grad():
+ if len(shp_x) != len(shp_y):
+ y = y.view((shp_y[0], 1, *shp_y[1:]))
+
+ if all([i == j for i, j in zip(shp_x, shp_y)]):
+ # if this is the case then gt is probably already a one hot encoding
+ y_onehot = y
+ else:
+ gt = y.long()
+ y_onehot = torch.zeros(shp_x, device=x.device, dtype=torch.bool)
+ y_onehot.scatter_(1, gt, 1)
+
+ if not self.do_bg:
+ y_onehot = y_onehot[:, 1:]
+ sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes)
+
+ intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes)
+ sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes)
+
+ if self.ddp and self.batch_dice:
+ intersect = AllGatherGrad.apply(intersect).sum(0)
+ sum_pred = AllGatherGrad.apply(sum_pred).sum(0)
+ sum_gt = AllGatherGrad.apply(sum_gt).sum(0)
+
+ if self.batch_dice:
+ intersect = intersect.sum(0)
+ sum_pred = sum_pred.sum(0)
+ sum_gt = sum_gt.sum(0)
+
+ dc = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8))
+
+ dc = dc.mean()
+ return -dc
+
+
+def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
+ """
+ net_output must be (b, c, x, y(, z))
+ gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
+ if mask is provided it must have shape (b, 1, x, y(, z))
+ :param net_output:
+ :param gt:
+ :param axes: can be (, ) = no summation
+ :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
+ :param square: if True then fp, tp and fn will be squared before summation
+ :return:
+ """
+ if axes is None:
+ axes = tuple(range(2, len(net_output.size())))
+
+ shp_x = net_output.shape
+ shp_y = gt.shape
+
+ with torch.no_grad():
+ if len(shp_x) != len(shp_y):
+ gt = gt.view((shp_y[0], 1, *shp_y[1:]))
+
+ if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
+ # if this is the case then gt is probably already a one hot encoding
+ y_onehot = gt
+ else:
+ gt = gt.long()
+ y_onehot = torch.zeros(shp_x, device=net_output.device)
+ y_onehot.scatter_(1, gt, 1)
+
+ tp = net_output * y_onehot
+ fp = net_output * (1 - y_onehot)
+ fn = (1 - net_output) * y_onehot
+ tn = (1 - net_output) * (1 - y_onehot)
+
+ if mask is not None:
+ with torch.no_grad():
+ mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))]))
+ tp *= mask_here
+ fp *= mask_here
+ fn *= mask_here
+ tn *= mask_here
+ # benchmark whether tiling the mask would be faster (torch.tile). It probably is for large batch sizes
+ # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram
+ # (using nnUNetv2_train 998 3d_fullres 0)
+ # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
+ # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
+ # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
+ # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)
+
+ if square:
+ tp = tp ** 2
+ fp = fp ** 2
+ fn = fn ** 2
+ tn = tn ** 2
+
+ if len(axes) > 0:
+ tp = sum_tensor(tp, axes, keepdim=False)
+ fp = sum_tensor(fp, axes, keepdim=False)
+ fn = sum_tensor(fn, axes, keepdim=False)
+ tn = sum_tensor(tn, axes, keepdim=False)
+
+ return tp, fp, fn, tn
+
+
+if __name__ == '__main__':
+ from nnunetv2.utilities.helpers import softmax_helper_dim1
+ pred = torch.rand((2, 3, 32, 32, 32))
+ ref = torch.randint(0, 3, (2, 32, 32, 32))
+
+ dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)
+ dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False)
+ res_old = dl_old(pred, ref)
+ res_new = dl_new(pred, ref)
+ print(res_old, res_new)
diff --git a/nnUNet/nnunetv2/training/loss/robust_ce_loss.py b/nnUNet/nnunetv2/training/loss/robust_ce_loss.py
new file mode 100644
index 0000000..ad46659
--- /dev/null
+++ b/nnUNet/nnunetv2/training/loss/robust_ce_loss.py
@@ -0,0 +1,33 @@
+import torch
+from torch import nn, Tensor
+import numpy as np
+
+
+class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
+ """
+ this is just a compatibility layer because my target tensor is float and has an extra dimension
+
+ input must be logits, not probabilities!
+ """
+ def forward(self, input: Tensor, target: Tensor) -> Tensor:
+ if len(target.shape) == len(input.shape):
+ assert target.shape[1] == 1
+ target = target[:, 0]
+ return super().forward(input, target.long())
+
+
+class TopKLoss(RobustCrossEntropyLoss):
+ """
+ input must be logits, not probabilities!
+ """
+ def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0):
+ self.k = k
+ super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing)
+
+ def forward(self, inp, target):
+ target = target[:, 0].long()
+ res = super(TopKLoss, self).forward(inp, target)
+ num_voxels = np.prod(res.shape, dtype=np.int64)
+ res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
+ return res.mean()
+
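+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): TopK CE keeps only the hardest k% of voxels.
+    logits = torch.rand((2, 3, 16, 16))
+    target = torch.randint(0, 3, (2, 1, 16, 16))
+    print(RobustCrossEntropyLoss()(logits, target), TopKLoss(k=10)(logits, target))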
diff --git a/nnUNet/nnunetv2/training/lr_scheduler/__init__.py b/nnUNet/nnunetv2/training/lr_scheduler/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/lr_scheduler/polylr.py b/nnUNet/nnunetv2/training/lr_scheduler/polylr.py
new file mode 100644
index 0000000..44857b5
--- /dev/null
+++ b/nnUNet/nnunetv2/training/lr_scheduler/polylr.py
@@ -0,0 +1,20 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class PolyLRScheduler(_LRScheduler):
+ def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None):
+ self.optimizer = optimizer
+ self.initial_lr = initial_lr
+ self.max_steps = max_steps
+ self.exponent = exponent
+ self.ctr = 0
+ super().__init__(optimizer, current_step if current_step is not None else -1, False)
+
+ def step(self, current_step=None):
+ if current_step is None or current_step == -1:
+ current_step = self.ctr
+ self.ctr += 1
+
+ new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = new_lr
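+
+
+if __name__ == '__main__':
+    # Rough usage sketch (not part of upstream nnU-Net): poly decay of the learning rate
+    # from 1e-2 over 1000 steps.
+    import torch
+    net = torch.nn.Linear(4, 2)
+    optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
+    scheduler = PolyLRScheduler(optimizer, initial_lr=1e-2, max_steps=1000)
+    for step in range(3):
+        scheduler.step(step)
+        print(step, optimizer.param_groups[0]['lr'])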
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py b/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py
new file mode 100644
index 0000000..85587ac
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py
@@ -0,0 +1,254 @@
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+import torch
+from torch import nn
+
+class STUNetTrainer(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 1000
+ self.initial_lr = 1e-2
+
+ @staticmethod
+ def build_network_architecture(plans_manager,
+ dataset_json,
+ configuration_manager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_classes=label_manager.num_segmentation_heads
+ kernel_sizes = [[3,3,3]] * 6
+ strides=configuration_manager.pool_op_kernel_sizes[1:]
+ if len(strides)>5:
+ strides = strides[:5]
+ while len(strides)<5:
+ strides.append([1,1,1])
+ return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [32 * x for x in [1, 2, 4, 8, 16, 16]],
+ pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)
+
+class STUNetTrainer_small(STUNetTrainer):
+ @staticmethod
+ def build_network_architecture(plans_manager,
+ dataset_json,
+ configuration_manager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_classes=label_manager.num_segmentation_heads
+ kernel_sizes = [[3,3,3]] * 6
+ strides=configuration_manager.pool_op_kernel_sizes[1:]
+ if len(strides)>5:
+ strides = strides[:5]
+ while len(strides)<5:
+ strides.append([1,1,1])
+ return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [16 * x for x in [1, 2, 4, 8, 16, 16]],
+ pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)
+
+class STUNetTrainer_small_ft(STUNetTrainer_small):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+ self.num_epochs = 1000
+
+
+class STUNetTrainer_base(STUNetTrainer):
+ @staticmethod
+ def build_network_architecture(plans_manager,
+ dataset_json,
+ configuration_manager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_classes=label_manager.num_segmentation_heads
+ kernel_sizes = [[3,3,3]] * 6
+ strides=configuration_manager.pool_op_kernel_sizes[1:]
+ if len(strides)>5:
+ strides = strides[:5]
+ while len(strides)<5:
+ strides.append([1,1,1])
+ return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [32 * x for x in [1, 2, 4, 8, 16, 16]],
+ pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)
+
+class STUNetTrainer_base_ft(STUNetTrainer_base):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+ self.num_epochs = 1000
+
+
+class STUNetTrainer_large(STUNetTrainer):
+ @staticmethod
+ def build_network_architecture(plans_manager,
+ dataset_json,
+ configuration_manager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_classes=label_manager.num_segmentation_heads
+ kernel_sizes = [[3,3,3]] * 6
+ strides=configuration_manager.pool_op_kernel_sizes[1:]
+ if len(strides)>5:
+ strides = strides[:5]
+ while len(strides)<5:
+ strides.append([1,1,1])
+ return STUNet(num_input_channels, num_classes, depth=[2]*6, dims= [64 * x for x in [1, 2, 4, 8, 16, 16]],
+ pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)
+
+class STUNetTrainer_large_ft(STUNetTrainer_large):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+ self.num_epochs = 1000
+
+
+class STUNetTrainer_huge(STUNetTrainer):
+ @staticmethod
+ def build_network_architecture(plans_manager,
+ dataset_json,
+ configuration_manager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_classes=label_manager.num_segmentation_heads
+ kernel_sizes = [[3,3,3]] * 6
+ strides=configuration_manager.pool_op_kernel_sizes[1:]
+ if len(strides)>5:
+ strides = strides[:5]
+ while len(strides)<5:
+ strides.append([1,1,1])
+ return STUNet(num_input_channels, num_classes, depth=[3]*6, dims= [96 * x for x in [1, 2, 4, 8, 16, 16]],
+ pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision)
+
+class STUNetTrainer_huge_ft(STUNetTrainer_huge):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+ self.num_epochs = 1000
+
+
+
+class Decoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.deep_supervision = True
+
+class STUNet(nn.Module):
+ def __init__(self, input_channels, num_classes, depth=[1,1,1,1,1,1], dims=[32, 64, 128, 256, 512, 512],
+ pool_op_kernel_sizes=None, conv_kernel_sizes=None, enable_deep_supervision=True):
+ super().__init__()
+ self.conv_op = nn.Conv3d
+ self.input_channels = input_channels
+ self.num_classes = num_classes
+
+ self.final_nonlin = lambda x:x
+ self.decoder = Decoder()
+ self.decoder.deep_supervision = enable_deep_supervision
+ self.upscale_logits = False
+
+ self.pool_op_kernel_sizes = pool_op_kernel_sizes
+ self.conv_kernel_sizes = conv_kernel_sizes
+ self.conv_pad_sizes = []
+ for krnl in self.conv_kernel_sizes:
+ self.conv_pad_sizes.append([i // 2 for i in krnl])
+
+
+ num_pool = len(pool_op_kernel_sizes)
+
+ assert num_pool == len(dims) - 1
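+ # one resolution stage per entry in dims; the num_pool transitions between consecutive stages downsample
+ # with the corresponding pool_op_kernel_sizes entry in the encoder and are mirrored by upsampling in the decoder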
+
+ # encoder
+ self.conv_blocks_context = nn.ModuleList()
+ stage = nn.Sequential(BasicResBlock(input_channels, dims[0], self.conv_kernel_sizes[0], self.conv_pad_sizes[0], use_1x1conv=True),
+ *[BasicResBlock(dims[0], dims[0], self.conv_kernel_sizes[0], self.conv_pad_sizes[0]) for _ in range(depth[0]-1)])
+ self.conv_blocks_context.append(stage)
+ for d in range(1, num_pool+1):
+ stage = nn.Sequential(BasicResBlock(dims[d-1], dims[d], self.conv_kernel_sizes[d], self.conv_pad_sizes[d], stride=self.pool_op_kernel_sizes[d-1], use_1x1conv=True),
+ *[BasicResBlock(dims[d], dims[d], self.conv_kernel_sizes[d], self.conv_pad_sizes[d]) for _ in range(depth[d]-1)])
+ self.conv_blocks_context.append(stage)
+
+ # upsample_layers
+ self.upsample_layers = nn.ModuleList()
+ for u in range(num_pool):
+ upsample_layer = Upsample_Layer_nearest(dims[-1-u], dims[-2-u], pool_op_kernel_sizes[-1-u])
+ self.upsample_layers.append(upsample_layer)
+
+ # decoder
+ self.conv_blocks_localization = nn.ModuleList()
+ for u in range(num_pool):
+ stage = nn.Sequential(BasicResBlock(dims[-2-u] * 2, dims[-2-u], self.conv_kernel_sizes[-2-u], self.conv_pad_sizes[-2-u], use_1x1conv=True),
+ *[BasicResBlock(dims[-2-u], dims[-2-u], self.conv_kernel_sizes[-2-u], self.conv_pad_sizes[-2-u]) for _ in range(depth[-2-u]-1)])
+ self.conv_blocks_localization.append(stage)
+
+ # outputs
+ self.seg_outputs = nn.ModuleList()
+ for ds in range(len(self.conv_blocks_localization)):
+ self.seg_outputs.append(nn.Conv3d(dims[-2-ds], num_classes, kernel_size=1))
+
+ self.upscale_logits_ops = []
+ for usl in range(num_pool - 1):
+ self.upscale_logits_ops.append(lambda x: x)
+
+
+ def forward(self, x):
+ skips = []
+ seg_outputs = []
+
+ for d in range(len(self.conv_blocks_context) - 1):
+ x = self.conv_blocks_context[d](x)
+ skips.append(x)
+
+ x = self.conv_blocks_context[-1](x)
+
+ for u in range(len(self.conv_blocks_localization)):
+ x = self.upsample_layers[u](x)
+ x = torch.cat((x, skips[-(u + 1)]), dim=1)
+ x = self.conv_blocks_localization[u](x)
+ seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x)))
+
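+ # with deep supervision enabled, return the full-resolution logits first, followed by the lower-resolution
+ # heads in decreasing resolution; otherwise only the full-resolution output is returned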
+ if self.decoder.deep_supervision:
+ return tuple([seg_outputs[-1]] + [i(j) for i, j in
+ zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])])
+ else:
+ return seg_outputs[-1]
+
+
+class BasicResBlock(nn.Module):
+ def __init__(self, input_channels, output_channels, kernel_size=3, padding=1, stride=1, use_1x1conv=False):
+ super().__init__()
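+ # residual block: conv -> InstanceNorm -> LeakyReLU, conv -> InstanceNorm, add the (optionally 1x1x1-projected)
+ # input, then a final LeakyReLU; the 1x1x1 projection is needed whenever stride or channel count changes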
+ self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size, stride=stride, padding=padding)
+ self.norm1 = nn.InstanceNorm3d(output_channels, affine=True)
+ self.act1 = nn.LeakyReLU(inplace=True)
+
+ self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size, padding=padding)
+ self.norm2 = nn.InstanceNorm3d(output_channels, affine=True)
+ self.act2 = nn.LeakyReLU(inplace=True)
+
+ if use_1x1conv:
+ self.conv3 = nn.Conv3d(input_channels, output_channels, kernel_size=1, stride=stride)
+ else:
+ self.conv3 = None
+
+ def forward(self, x):
+ y = self.conv1(x)
+ y = self.act1(self.norm1(y))
+ y = self.norm2(self.conv2(y))
+ if self.conv3:
+ x = self.conv3(x)
+ y += x
+ return self.act2(y)
+
+class Upsample_Layer_nearest(nn.Module):
+ def __init__(self, input_channels, output_channels, pool_op_kernel_size, mode='nearest'):
+ super().__init__()
+ self.conv = nn.Conv3d(input_channels, output_channels, kernel_size=1)
+ self.pool_op_kernel_size = pool_op_kernel_size
+ self.mode = mode
+
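+ # nearest-neighbour upsampling by the pooling kernel size, followed by a 1x1x1 conv to map
+ # input_channels to output_channels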
+ def forward(self, x):
+ x = nn.functional.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode)
+ x = self.conv(x)
+ return x
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
new file mode 100644
index 0000000..6dc1277
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -0,0 +1,1241 @@
+import inspect
+import multiprocessing
+import os
+import shutil
+import sys
+import warnings
+from copy import deepcopy
+from datetime import datetime
+from time import time, sleep
+from typing import Union, Tuple, List
+
+import numpy as np
+import torch
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
+from torch._dynamo import OptimizedModule
+
+from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
+from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
+from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
+ DownsampleSegForDSTransform2
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
+from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
+ ConvertSegmentationToRegionsTransform
+from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \
+ Convert3DTo2DTransform
+from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
+from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
+from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
+from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.utilities.collate_outputs import collate_outputs
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
+from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from sklearn.model_selection import KFold
+from torch import autocast, nn
+from torch import distributed as dist
+from torch.cuda import device_count
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class nnUNetTrainer(object):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ # From https://grugbrain.dev/. Worth a read ya big brains ;-)
+
+ # apex predator of grug is complexity
+ # complexity bad
+ # say again:
+ # complexity very bad
+ # you say now:
+ # complexity very, very bad
+ # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex
+ # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime
+ # one day code base understandable and grug can get work done, everything good!
+ # next day impossible: complexity demon spirit has entered code and very dangerous situation!
+
+ # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz
+
+ self.is_ddp = dist.is_available() and dist.is_initialized()
+ self.local_rank = 0 if not self.is_ddp else dist.get_rank()
+
+ self.device = device
+
+ # print what device we are using
+ if self.is_ddp: # implicitly it's clear that we use cuda in this case
+ print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is "
+ f"{dist.get_world_size()}."
+ f"Setting device to {self.device}")
+ self.device = torch.device(type='cuda', index=self.local_rank)
+ else:
+ if self.device.type == 'cuda':
+ # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
+ self.device = torch.device(type='cuda', index=0)
+ print(f"Using device: {self.device}")
+
+ # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
+ # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
+ # need. So let's save the init args
+ self.my_init_kwargs = {}
+ for k in inspect.signature(self.__init__).parameters.keys():
+ self.my_init_kwargs[k] = locals()[k]
+
+ ### Saving all the init args into class variables for later access
+ self.plans_manager = PlansManager(plans)
+ self.configuration_manager = self.plans_manager.get_configuration(configuration)
+ self.configuration_name = configuration
+ self.dataset_json = dataset_json
+ self.fold = fold
+ self.unpack_dataset = unpack_dataset
+
+ ### Setting all the folder names. We need to make sure things don't crash in case we are just running
+ # inference and some of the folders may not be defined!
+ self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \
+ if nnUNet_preprocessed is not None else None
+ self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \
+ if nnUNet_results is not None else None
+ self.output_folder = join(self.output_folder_base, f'fold_{fold}')
+
+ self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,
+ self.configuration_manager.data_identifier)
+ # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to
+ # be a different configuration in the same plans
+ # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using
+ # "previous_stage" and "next_stage"). Otherwise it won't work!
+ self.is_cascaded = self.configuration_manager.previous_stage_name is not None
+ self.folder_with_segs_from_previous_stage = \
+ join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" +
+ self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \
+ if self.is_cascaded else None
+
+ ### Some hyperparameters for you to fiddle with
+ self.initial_lr = 1e-2
+ self.weight_decay = 3e-5
+ self.oversample_foreground_percent = 0.33
+ self.num_iterations_per_epoch = 250
+ self.num_val_iterations_per_epoch = 50
+ self.num_epochs = 1000
+ self.current_epoch = 0
+
+ ### Dealing with labels/regions
+ self.label_manager = self.plans_manager.get_label_manager(dataset_json)
+ # labels can either be a list of int (regular training) or a list of tuples of int (region-based training)
+ # needed for predictions. We do sigmoid in case of (overlapping) regions
+
+ self.num_input_channels = None # -> self.initialize()
+ self.network = None # -> self._get_network()
+ self.optimizer = self.lr_scheduler = None # -> self.initialize
+ self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None
+ self.loss = None # -> self.initialize
+
+ ### Simple logging. Don't take that away from me!
+ # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning
+ # logging
+ timestamp = datetime.now()
+ maybe_mkdir_p(self.output_folder)
+ self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
+ (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
+ timestamp.second))
+ self.logger = nnUNetLogger()
+
+ ### placeholders
+ self.dataloader_train = self.dataloader_val = None # see on_train_start
+
+ ### initializing stuff for remembering things and such
+ self._best_ema = None
+
+ ### inference things
+ self.inference_allowed_mirroring_axes = None # this variable is set in
+ # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints
+
+ ### checkpoint saving stuff
+ self.save_every = 50
+ self.disable_checkpointing = False
+
+ ## DDP batch size and oversampling can differ between workers and needs adaptation
+ # we need to change the batch size in DDP because we don't use any of those distributed samplers
+ self._set_batch_size_and_oversample()
+
+ self.was_initialized = False
+
+ self.print_to_log_file("\n#######################################################################\n"
+ "Please cite the following paper when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n"
+ "#######################################################################\n",
+ also_print_to_console=True, add_timestamp=False)
+
+ def initialize(self):
+ if not self.was_initialized:
+ self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+
+ self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
+ self.configuration_manager,
+ self.num_input_channels,
+ enable_deep_supervision=True).to(self.device)
+ # compile network for free speedup
+ if ('nnUNet_compile' in os.environ.keys()) and (
+ os.environ['nnUNet_compile'].lower() in ('true', '1', 't')):
+ self.print_to_log_file('Compiling network...')
+ self.network = torch.compile(self.network)
+
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+ # if ddp, wrap in DDP wrapper
+ if self.is_ddp:
+ self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ self.loss = self._build_loss()
+ self.was_initialized = True
+ else:
+ raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
+ "That should not happen.")
+
+ def _save_debug_information(self):
+ # saving some debug information
+ if self.local_rank == 0:
+ dct = {}
+ for k in self.__dir__():
+ if not k.startswith("__"):
+ if not callable(getattr(self, k)) or k in ['loss', ]:
+ dct[k] = str(getattr(self, k))
+ elif k in ['network', ]:
+ dct[k] = str(getattr(self, k).__class__.__name__)
+ else:
+ # print(k)
+ pass
+ if k in ['dataloader_train', 'dataloader_val']:
+ if hasattr(getattr(self, k), 'generator'):
+ dct[k + '.generator'] = str(getattr(self, k).generator)
+ if hasattr(getattr(self, k), 'num_processes'):
+ dct[k + '.num_processes'] = str(getattr(self, k).num_processes)
+ if hasattr(getattr(self, k), 'transform'):
+ dct[k + '.transform'] = str(getattr(self, k).transform)
+ import subprocess
+ hostname = subprocess.getoutput('hostname')
+ dct['hostname'] = hostname
+ torch_version = torch.__version__
+ if self.device.type == 'cuda':
+ gpu_name = torch.cuda.get_device_name()
+ dct['gpu_name'] = gpu_name
+ cudnn_version = torch.backends.cudnn.version()
+ else:
+ cudnn_version = 'None'
+ dct['device'] = str(self.device)
+ dct['torch_version'] = torch_version
+ dct['cudnn_version'] = cudnn_version
+ save_json(dct, join(self.output_folder, "debug.json"))
+
+ @staticmethod
+ def build_network_architecture(plans_manager: PlansManager,
+ dataset_json,
+ configuration_manager: ConfigurationManager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ """
+ This is where you build the architecture according to the plans. There is no obligation to use
+ get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
+ you want. Even ignore the plans and just return something static (as long as it can process the requested
+ patch size)
+ but don't bug us with your bugs arising from fiddling with this :-P
+ This is the function that is called in inference as well! This is needed so that all network architecture
+ variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
+ training, so if you change the network architecture during training by deriving a new trainer class then
+ inference will know about it).
+
+ If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
+ > label_manager = plans_manager.get_label_manager(dataset_json)
+ > label_manager.num_segmentation_heads
+ (why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
+ the number of outputs is != the number of classes. Also there is the ignore label for which no output
+ should be generated. label_manager takes care of all that for you.)
+
+ """
+ return get_network_from_plans(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, deep_supervision=enable_deep_supervision)
+
+ def _get_deep_supervision_scales(self):
+ deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(
+ self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]
+ return deep_supervision_scales
+
+ def _set_batch_size_and_oversample(self):
+ if not self.is_ddp:
+ # set batch size to what the plan says, leave oversample untouched
+ self.batch_size = self.configuration_manager.batch_size
+ else:
+ # batch size is distributed over DDP workers and we need to change oversample_percent for each worker
+ batch_sizes = []
+ oversample_percents = []
+
+ world_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+
+ global_batch_size = self.configuration_manager.batch_size
+ assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
+ 'GPUs... Duh.'
+
+ batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int)
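+ # e.g. global_batch_size=5 on 2 GPUs: batch_size_per_GPU=3, rank 0 gets 3 samples, rank 1 gets the remaining 2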
+
+ for rank in range(world_size):
+ if (rank + 1) * batch_size_per_GPU > global_batch_size:
+ batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size)
+ else:
+ batch_size = batch_size_per_GPU
+
+ batch_sizes.append(batch_size)
+
+ sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
+ sample_id_high = np.sum(batch_sizes)
+
+ if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):
+ oversample_percents.append(0.0)
+ elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):
+ oversample_percents.append(1.0)
+ else:
+ percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size
+ oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
+ sample_id_low / global_batch_size) / percent_covered_by_this_rank)
+ oversample_percents.append(oversample_percent_here)
+
+ print("worker", my_rank, "oversample", oversample_percents[my_rank])
+ print("worker", my_rank, "batch_size", batch_sizes[my_rank])
+ # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank])
+ # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank])
+
+ self.batch_size = batch_sizes[my_rank]
+ self.oversample_foreground_percent = oversample_percents[my_rank]
+
+ def _build_loss(self):
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+ weights[-1] = 0
+
+ # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
+ weights = weights / weights.sum()
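+ # e.g. with 5 deep supervision outputs the raw weights are [1, 0.5, 0.25, 0.125, 0], which normalize to
+ # roughly [0.533, 0.267, 0.133, 0.067, 0]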
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ """
+ This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.
+ """
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
+ if dim == 2:
+ do_dummy_2d_data_aug = False
+ # todo revisit this parametrization
+ if max(patch_size) / min(patch_size) > 1.5:
+ rotation_for_DA = {
+ 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ mirror_axes = (0, 1)
+ elif dim == 3:
+ # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
+ # order of the axes is determined by spacing, not image size
+ do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
+ if do_dummy_2d_data_aug:
+ # why do we rotate 180 deg here all the time? We should also restrict it
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ }
+ mirror_axes = (1, 2)
+ else:
+ raise RuntimeError()
+
+ # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
+ # old nnunet for now)
+ initial_patch_size = get_patch_size(patch_size[-dim:],
+ *rotation_for_DA.values(),
+ (0.85, 1.25))
+ if do_dummy_2d_data_aug:
+ initial_patch_size[0] = patch_size[0]
+
+ self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
+ self.inference_allowed_mirroring_axes = mirror_axes
+
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+ def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
+ if self.local_rank == 0:
+ timestamp = time()
+ dt_object = datetime.fromtimestamp(timestamp)
+
+ if add_timestamp:
+ args = ("%s:" % dt_object, *args)
+
+ successful = False
+ max_attempts = 5
+ ctr = 0
+ while not successful and ctr < max_attempts:
+ try:
+ with open(self.log_file, 'a+') as f:
+ for a in args:
+ f.write(str(a))
+ f.write(" ")
+ f.write("\n")
+ successful = True
+ except IOError:
+ print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
+ sleep(0.5)
+ ctr += 1
+ if also_print_to_console:
+ print(*args)
+ elif also_print_to_console:
+ print(*args)
+
+ def print_plans(self):
+ if self.local_rank == 0:
+ dct = deepcopy(self.plans_manager.plans)
+ del dct['configurations']
+ self.print_to_log_file(f"\nThis is the configuration used by this "
+ f"training:\nConfiguration name: {self.configuration_name}\n",
+ self.configuration_manager, '\n', add_timestamp=False)
+ self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False)
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=3.)
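+ # note: exponent=3. makes the learning rate decay faster late in training than the more common poly exponent of 0.9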
+ return optimizer, lr_scheduler
+
+ def plot_network_architecture(self):
+ if self.local_rank == 0:
+ try:
+ # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')
+ # pip install git+https://github.com/saugatkandel/hiddenlayer.git
+
+ # from torchviz import make_dot
+ # # not viable.
+ # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,
+ # *self.configuration_manager.patch_size),
+ # device=self.device)))).render(
+ # join(self.output_folder, "network_architecture.pdf"), format='pdf')
+ # self.optimizer.zero_grad()
+
+ # broken.
+
+ import hiddenlayer as hl
+ g = hl.build_graph(self.network,
+ torch.rand((1, self.num_input_channels,
+ *self.configuration_manager.patch_size),
+ device=self.device),
+ transforms=None)
+ g.save(join(self.output_folder, "network_architecture.pdf"))
+ del g
+ except Exception as e:
+ self.print_to_log_file("Unable to plot network architecture:")
+ self.print_to_log_file(e)
+
+ # self.print_to_log_file("\nprinting the network instead:\n")
+ # self.print_to_log_file(self.network)
+ # self.print_to_log_file("\n")
+ finally:
+ empty_cache(self.device)
+
+ def do_split(self):
+ """
+ The default split is a 5-fold CV on all available training cases. nnU-Net will create a split (it is seeded,
+ so always the same) and save it as a splits_final.json file in the preprocessed data directory.
+ Sometimes you may want to create your own split for various reasons. For this you will need to create your own
+ splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
+ it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
+ and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
+ use a random 80:20 data split.
+ :return:
+ """
+ if self.fold == "all":
+ # if fold==all then we use all images for training and validation
+ case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder)
+ tr_keys = case_identifiers
+ val_keys = tr_keys
+ else:
+ splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
+ dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None,
+ num_images_properties_loading_threshold=0,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
+ # if the split file does not exist we need to create it
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new 5-fold cross-validation split...")
+ splits = []
+ all_keys_sorted = np.sort(list(dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append({})
+ splits[-1]['train'] = list(train_keys)
+ splits[-1]['val'] = list(test_keys)
+ save_json(splits, splits_file)
+
+ else:
+ self.print_to_log_file("Using splits from existing split file:", splits_file)
+ splits = load_json(splits_file)
+ self.print_to_log_file("The split file contains %d splits." % len(splits))
+
+ self.print_to_log_file("Desired fold for training: %d" % self.fold)
+ if self.fold < len(splits):
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+ self.print_to_log_file("This split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ else:
+ self.print_to_log_file("INFO: You requested fold %d for training but splits "
+ "contain only %d folds. I am now creating a "
+ "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
+ # if we request a fold that is not in the split file, create a random 80:20 split
+ rnd = np.random.RandomState(seed=12345 + self.fold)
+ keys = np.sort(list(dataset.keys()))
+ idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
+ idx_val = [i for i in range(len(keys)) if i not in idx_tr]
+ tr_keys = [keys[i] for i in idx_tr]
+ val_keys = [keys[i] for i in idx_val]
+ self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ if any([i in val_keys for i in tr_keys]):
+ self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '
+ 'splits.json or ignore if this is intentional.')
+ return tr_keys, val_keys
+
+ def get_tr_and_val_datasets(self):
+ # create dataset split
+ tr_keys, val_keys = self.do_split()
+
+ # load the datasets for training and validation. Note that we always draw random samples so we really don't
+ # care about distributing training cases across GPUs.
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ return dataset_tr, dataset_val
+
+ def get_dataloaders(self):
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=1,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
+ num_processes=allowed_num_processes, num_cached=6, seeds=None,
+ pin_memory=self.device.type == 'cuda', wait_time=0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
+ transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
+ num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
+ wait_time=0.02)
+ return mt_gen_train, mt_gen_val
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ dataset_tr, dataset_val = self.get_tr_and_val_datasets()
+
+ if dim == 2:
+ dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ else:
+ dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ return dl_tr, dl_val
+
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 3,
+ order_resampling_seg: int = 1,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ tr_transforms = []
+ if do_dummy_2d_data_aug:
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None,
+ do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
+ do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'],
+ p_rot_per_axis=1, # todo experiment with this
+ do_scale=True, scale=(0.7, 1.4),
+ border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg,
+ random_crop=False, # random cropping is part of our dataloaders
+ p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2,
+ independent_scale_for_each_axis=False # todo experiment with this
+ ))
+
+ if do_dummy_2d_data_aug:
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
+ tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1))
+ tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3))
+
+ if mirror_axes is not None and len(mirror_axes) > 0:
+ tr_transforms.append(MirrorTransform(mirror_axes))
+
+ if use_mask_for_norm is not None and any(use_mask_for_norm):
+ tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
+ mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
+ tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ p_per_sample=0.4,
+ key="data",
+ strel_size=(1, 8),
+ p_per_label=1))
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ key="data",
+ p_per_sample=0.2,
+ fill_with_other_class_p=0,
+ dont_do_if_covers_more_than_x_percent=0.15))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+ return tr_transforms
+
+ @staticmethod
+ def get_validation_transforms(deep_supervision_scales: Union[List, Tuple],
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+ return val_transforms
+
+ def set_deep_supervision_enabled(self, enabled: bool):
+ """
+ This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
+ chances you need to change this as well!
+ """
+ if self.is_ddp:
+ self.network.module.decoder.deep_supervision = enabled
+ else:
+ self.network.decoder.deep_supervision = enabled
+
+ def on_train_start(self):
+ if not self.was_initialized:
+ self.initialize()
+
+ maybe_mkdir_p(self.output_folder)
+
+ # make sure deep supervision is on in the network
+ self.set_deep_supervision_enabled(True)
+
+ self.print_plans()
+ empty_cache(self.device)
+
+ # maybe unpack
+ if self.unpack_dataset and self.local_rank == 0:
+ self.print_to_log_file('unpacking dataset...')
+ unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
+ num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
+ self.print_to_log_file('unpacking done...')
+
+ if self.is_ddp:
+ dist.barrier()
+
+ # dataloaders must be instantiated here because they need access to the training data which may not be present
+ # when doing inference
+ self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
+ # copy plans and dataset.json so that they can be used for restoring everything we need for inference
+ save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
+ save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
+
+ # we don't really need the fingerprint but its still handy to have it with the others
+ shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
+ join(self.output_folder_base, 'dataset_fingerprint.json'))
+
+ # produces a pdf in output folder
+ self.plot_network_architecture()
+
+ self._save_debug_information()
+
+ # print(f"batch size: {self.batch_size}")
+ # print(f"oversample: {self.oversample_foreground_percent}")
+
+ def on_train_end(self):
+ self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
+ # now we can delete latest
+ if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
+ os.remove(join(self.output_folder, "checkpoint_latest.pth"))
+
+ # shut down dataloaders
+ old_stdout = sys.stdout
+ with open(os.devnull, 'w') as f:
+ sys.stdout = f
+ if self.dataloader_train is not None:
+ self.dataloader_train._finish()
+ if self.dataloader_val is not None:
+ self.dataloader_val._finish()
+ sys.stdout = old_stdout
+
+ empty_cache(self.device)
+ self.print_to_log_file("Training done.")
+
+ def on_train_epoch_start(self):
+ self.network.train()
+ self.lr_scheduler.step(self.current_epoch)
+ self.print_to_log_file('')
+ self.print_to_log_file(f'Epoch {self.current_epoch}')
+ self.print_to_log_file(
+ f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")
+ # lrs are the same for all workers so we don't need to gather them in case of DDP training
+ self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
+
+ def train_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ self.optimizer.zero_grad()
+ # autocast behaves inconsistently across devices:
+ # - on 'cpu' it is very slow and therefore needs to stay disabled
+ # - on 'mps' it complains that mps is not implemented even when enabled=False is passed (which is why we don't rely on enabled=False)
+ # so autocast is only active when running on a CUDA device
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ # del data
+ l = self.loss(output, target)
+
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(l).backward()
+ self.grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+ return {'loss': l.detach().cpu().numpy()}
+
+ def on_train_epoch_end(self, train_outputs: List[dict]):
+ outputs = collate_outputs(train_outputs)
+
+ if self.is_ddp:
+ losses_tr = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(losses_tr, outputs['loss'])
+ loss_here = np.vstack(losses_tr).mean()
+ else:
+ loss_here = np.mean(outputs['loss'])
+
+ self.logger.log('train_losses', loss_here, self.current_epoch)
+
+ def on_validation_epoch_start(self):
+ self.network.eval()
+
+ def validation_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ # autocast behaves inconsistently across devices:
+ # - on 'cpu' it is very slow and therefore needs to stay disabled
+ # - on 'mps' it complains that mps is not implemented even when enabled=False is passed (which is why we don't rely on enabled=False)
+ # so autocast is only active when running on a CUDA device
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ # we only need the output with the highest output resolution
+ output = output[0]
+ target = target[0]
+
+ # the following is needed for online evaluation. Fake dice (green line)
+ axes = [0] + list(range(2, len(output.shape)))
+
+ if self.label_manager.has_regions:
+ predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
+ else:
+ # no need for softmax
+ output_seg = output.argmax(1)[:, None]
+ predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
+ predicted_segmentation_onehot.scatter_(1, output_seg, 1)
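+ # scatter_ places a 1 at the argmax channel of every voxel, giving a one-hot encoding without
+ # computing an explicit softmax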
+ del output_seg
+
+ if self.label_manager.has_ignore_label:
+ if not self.label_manager.has_regions:
+ mask = (target != self.label_manager.ignore_label).float()
+ # CAREFUL that you don't rely on target after this line!
+ target[target == self.label_manager.ignore_label] = 0
+ else:
+ mask = 1 - target[:, -1:]
+ # CAREFUL that you don't rely on target after this line!
+ target = target[:, :-1]
+ else:
+ mask = None
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+ if not self.label_manager.has_regions:
+ # if we train with regions, all segmentation heads predict some kind of foreground. In conventional
+ # (softmax) training there needs to be one output for the background. We are not interested in the
+ # background Dice
+ # [1:] in order to remove background
+ tp_hard = tp_hard[1:]
+ fp_hard = fp_hard[1:]
+ fn_hard = fn_hard[1:]
+
+ return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
+
+ def on_validation_epoch_end(self, val_outputs: List[dict]):
+ outputs_collated = collate_outputs(val_outputs)
+ tp = np.sum(outputs_collated['tp_hard'], 0)
+ fp = np.sum(outputs_collated['fp_hard'], 0)
+ fn = np.sum(outputs_collated['fn_hard'], 0)
+
+ if self.is_ddp:
+ world_size = dist.get_world_size()
+
+ tps = [None for _ in range(world_size)]
+ dist.all_gather_object(tps, tp)
+ tp = np.vstack([i[None] for i in tps]).sum(0)
+
+ fps = [None for _ in range(world_size)]
+ dist.all_gather_object(fps, fp)
+ fp = np.vstack([i[None] for i in fps]).sum(0)
+
+ fns = [None for _ in range(world_size)]
+ dist.all_gather_object(fns, fn)
+ fn = np.vstack([i[None] for i in fns]).sum(0)
+
+ losses_val = [None for _ in range(world_size)]
+ dist.all_gather_object(losses_val, outputs_collated['loss'])
+ loss_here = np.vstack(losses_val).mean()
+ else:
+ loss_here = np.mean(outputs_collated['loss'])
+
+ global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
+ zip(tp, fp, fn)]]
+ mean_fg_dice = np.nanmean(global_dc_per_class)
+ self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
+ self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
+ self.logger.log('val_losses', loss_here, self.current_epoch)
+
+ def on_epoch_start(self):
+ self.logger.log('epoch_start_timestamps', time(), self.current_epoch)
+
+ def on_epoch_end(self):
+ self.logger.log('epoch_end_timestamps', time(), self.current_epoch)
+
+ # todo: find a cleaner way to access and report these logged metrics
+ self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4))
+ self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4))
+ self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
+ self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]])
+ self.print_to_log_file(
+ f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s")
+
+ # handling periodic checkpointing
+ current_epoch = self.current_epoch
+ if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))
+
+ # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this
+ if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema:
+ self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1]
+ self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))
+
+ if self.local_rank == 0:
+ self.logger.plot_progress_png(self.output_folder)
+
+ self.current_epoch += 1
+
+ def save_checkpoint(self, filename: str) -> None:
+ if self.local_rank == 0:
+ if not self.disable_checkpointing:
+ if self.is_ddp:
+ mod = self.network.module
+ else:
+ mod = self.network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ checkpoint = {
+ 'network_weights': mod.state_dict(),
+ 'optimizer_state': self.optimizer.state_dict(),
+ 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,
+ 'logging': self.logger.get_checkpoint(),
+ '_best_ema': self._best_ema,
+ 'current_epoch': self.current_epoch + 1,
+ 'init_args': self.my_init_kwargs,
+ 'trainer_name': self.__class__.__name__,
+ 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,
+ }
+ torch.save(checkpoint, filename)
+ else:
+ self.print_to_log_file('No checkpoint written, checkpointing is disabled')
+
+ def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
+ if not self.was_initialized:
+ self.initialize()
+
+ if isinstance(filename_or_checkpoint, str):
+ checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
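+ # e.g. a (hypothetical) key 'module.encoder.stages.0.conv.weight' becomes 'encoder.stages.0.conv.weight'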
+ new_state_dict = {}
+ for k, value in checkpoint['network_weights'].items():
+ key = k
+ if key not in self.network.state_dict().keys() and key.startswith('module.'):
+ key = key[7:]
+ new_state_dict[key] = value
+
+ self.my_init_kwargs = checkpoint['init_args']
+ self.current_epoch = checkpoint['current_epoch']
+ self.logger.load_checkpoint(checkpoint['logging'])
+ self._best_ema = checkpoint['_best_ema']
+ self.inference_allowed_mirroring_axes = checkpoint[
+ 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes
+
+ # messing with state dict naming schemes. Facepalm.
+ if self.is_ddp:
+ if isinstance(self.network.module, OptimizedModule):
+ self.network.module._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.module.load_state_dict(new_state_dict)
+ else:
+ if isinstance(self.network, OptimizedModule):
+ self.network._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.load_state_dict(new_state_dict)
+ self.optimizer.load_state_dict(checkpoint['optimizer_state'])
+ if self.grad_scaler is not None:
+ if checkpoint['grad_scaler_state'] is not None:
+ self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])
+
+ def perform_actual_validation(self, save_probabilities: bool = False):
+ self.set_deep_supervision_enabled(False)
+ self.network.eval()
+
+ predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
+ perform_everything_on_gpu=True, device=self.device, verbose=False,
+ verbose_preprocessing=False, allow_tqdm=False)
+ predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
+ self.dataset_json, self.__class__.__name__,
+ self.inference_allowed_mirroring_axes)
+
+ with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool:
+ worker_list = [i for i in segmentation_export_pool._pool]
+ validation_output_folder = join(self.output_folder, 'validation')
+ maybe_mkdir_p(validation_output_folder)
+
+ # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute
+ # the validation keys across the workers.
+ _, val_keys = self.do_split()
+ if self.is_ddp:
+ val_keys = val_keys[self.local_rank:: dist.get_world_size()]
+
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+
+ next_stages = self.configuration_manager.next_stage_names
+
+ if next_stages is not None:
+ _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]
+
+ results = []
+
+ for k in dataset_val.keys():
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+ while not proceed:
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+
+ self.print_to_log_file(f"predicting {k}")
+ data, seg, properties = dataset_val.load_case(k)
+
+ if self.is_cascaded:
+ data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels,
+ output_dtype=data.dtype)))
+ with warnings.catch_warnings():
+ # ignore 'The given NumPy array is not writable' warning
+ warnings.simplefilter("ignore")
+ data = torch.from_numpy(data)
+
+ output_filename_truncated = join(validation_output_folder, k)
+
+ try:
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ except RuntimeError:
+ predictor.perform_everything_on_gpu = False
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ predictor.perform_everything_on_gpu = True
+
+ prediction = prediction.cpu()
+
+ # this needs to go into background processes
+ results.append(
+ segmentation_export_pool.starmap_async(
+ export_prediction_from_logits, (
+ (prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, output_filename_truncated, save_probabilities),
+ )
+ )
+ )
+ # for debug purposes
+ # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json,
+ # output_filename_truncated, save_probabilities)
+
+ # if needed, export the softmax prediction for the next stage
+ if next_stages is not None:
+ for n in next_stages:
+ next_stage_config_manager = self.plans_manager.get_configuration(n)
+ expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,
+ next_stage_config_manager.data_identifier)
+
+ try:
+ # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented
+ tmp = nnUNetDataset(expected_preprocessed_folder, [k],
+ num_images_properties_loading_threshold=0)
+ d, s, p = tmp.load_case(k)
+ except FileNotFoundError:
+ self.print_to_log_file(
+ f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! "
+ f"Run the preprocessing for this configuration first!")
+ continue
+
+ target_shape = d.shape[1:]
+ output_folder = join(self.output_folder_base, 'predicted_next_stage', n)
+ output_file = join(output_folder, k + '.npz')
+
+ # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties,
+ # self.dataset_json)
+ results.append(segmentation_export_pool.starmap_async(
+ resample_and_save, (
+ (prediction, target_shape, output_file, self.plans_manager,
+ self.configuration_manager,
+ properties,
+ self.dataset_json),
+ )
+ ))
+
+ _ = [r.get() for r in results]
+
+ if self.is_ddp:
+ dist.barrier()
+
+ if self.local_rank == 0:
+ metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),
+ validation_output_folder,
+ join(validation_output_folder, 'summary.json'),
+ self.plans_manager.image_reader_writer_class(),
+ self.dataset_json["file_ending"],
+ self.label_manager.foreground_regions if self.label_manager.has_regions else
+ self.label_manager.foreground_labels,
+ self.label_manager.ignore_label, chill=True)
+ self.print_to_log_file("Validation complete", also_print_to_console=True)
+ self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True)
+
+ self.set_deep_supervision_enabled(True)
+ compute_gaussian.cache_clear()
+
+ def run_training(self):
+ self.on_train_start()
+
+ for epoch in range(self.current_epoch, self.num_epochs):
+ self.on_epoch_start()
+
+ self.on_train_epoch_start()
+ train_outputs = []
+ for batch_id in range(self.num_iterations_per_epoch):
+ train_outputs.append(self.train_step(next(self.dataloader_train)))
+ self.on_train_epoch_end(train_outputs)
+
+ with torch.no_grad():
+ self.on_validation_epoch_start()
+ val_outputs = []
+ for batch_id in range(self.num_val_iterations_per_epoch):
+ val_outputs.append(self.validation_step(next(self.dataloader_val)))
+ self.on_validation_epoch_end(val_outputs)
+
+ self.on_epoch_end()
+
+ self.on_train_end()
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py
new file mode 100644
index 0000000..8b376eb
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py
@@ -0,0 +1,1181 @@
+import inspect
+import multiprocessing
+import os
+import shutil
+import sys
+import warnings
+from copy import deepcopy
+from datetime import datetime
+from time import time, sleep
+from typing import Union, Tuple, List
+
+import numpy as np
+import torch
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
+from torch._dynamo import OptimizedModule
+
+from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
+from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
+from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
+ DownsampleSegForDSTransform2
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
+from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
+ ConvertSegmentationToRegionsTransform
+from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \
+ Convert3DTo2DTransform
+from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
+from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
+from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
+from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.utilities.collate_outputs import collate_outputs
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
+from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from sklearn.model_selection import KFold
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from torch import autocast, nn
+from torch import distributed as dist
+from torch.cuda import device_count
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class nnUNetTrainer_aghiles(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ # From https://grugbrain.dev/. Worth a read ya big brains ;-)
+
+ # apex predator of grug is complexity
+ # complexity bad
+ # say again:
+ # complexity very bad
+ # you say now:
+ # complexity very, very bad
+ # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex
+ # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime
+ # one day code base understandable and grug can get work done, everything good!
+ # next day impossible: complexity demon spirit has entered code and very dangerous situation!
+
+ # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz
+
+ self.is_ddp = dist.is_available() and dist.is_initialized()
+ self.local_rank = 0 if not self.is_ddp else dist.get_rank()
+
+ self.device = device
+
+ # print what device we are using
+ if self.is_ddp: # implicitly it's clear that we use cuda in this case
+ print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is "
+ f"{dist.get_world_size()}."
+ f"Setting device to {self.device}")
+ self.device = torch.device(type='cuda', index=self.local_rank)
+ else:
+ if self.device.type == 'cuda':
+ # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
+ self.device = torch.device(type='cuda', index=0)
+ print(f"Using device: {self.device}")
+
+ # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
+ # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
+ # need. So let's save the init args
+ self.my_init_kwargs = {}
+ for k in inspect.signature(self.__init__).parameters.keys():
+ self.my_init_kwargs[k] = locals()[k]
+
+ ### Saving all the init args into class variables for later access
+ self.plans_manager = PlansManager(plans)
+ self.configuration_manager = self.plans_manager.get_configuration(configuration)
+ self.configuration_name = configuration
+ self.dataset_json = dataset_json
+ self.fold = fold
+ self.unpack_dataset = unpack_dataset
+
+ ### Setting all the folder names. We need to make sure things don't crash in case we are just running
+ # inference and some of the folders may not be defined!
+ self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \
+ if nnUNet_preprocessed is not None else None
+ self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \
+ if nnUNet_results is not None else None
+ self.output_folder = join(self.output_folder_base, f'fold_{fold}')
+
+ self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,
+ self.configuration_manager.data_identifier)
+ # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to
+ # be a different configuration in the same plans
+ # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using
+ # "previous_stage" and "next_stage"). Otherwise it won't work!
+ self.is_cascaded = self.configuration_manager.previous_stage_name is not None
+ self.folder_with_segs_from_previous_stage = \
+ join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" +
+ self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \
+ if self.is_cascaded else None
+
+ ### Some hyperparameters for you to fiddle with
+ self.initial_lr = 1e-2
+ self.weight_decay = 3e-5
+ self.oversample_foreground_percent = 0.33
+ self.num_iterations_per_epoch = 50
+ self.num_val_iterations_per_epoch = 30
+ self.num_epochs = 40
+ self.current_epoch = 0
+
+ ### Dealing with labels/regions
+ self.label_manager = self.plans_manager.get_label_manager(dataset_json)
+ # labels can either be a list of int (regular training) or a list of tuples of int (region-based training)
+ # needed for predictions. We do sigmoid in case of (overlapping) regions
+
+ self.num_input_channels = None # -> self.initialize()
+ self.network = None # -> self._get_network()
+ self.optimizer = self.lr_scheduler = None # -> self.initialize
+ self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None
+ self.loss = None # -> self.initialize
+
+ ### Simple logging. Don't take that away from me!
+ # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning
+ # logging
+ timestamp = datetime.now()
+ maybe_mkdir_p(self.output_folder)
+ self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
+ (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
+ timestamp.second))
+ self.logger = nnUNetLogger()
+
+ ### placeholders
+ self.dataloader_train = self.dataloader_val = None # see on_train_start
+
+ ### initializing stuff for remembering things and such
+ self._best_ema = None
+
+ ### inference things
+ self.inference_allowed_mirroring_axes = None # this variable is set in
+ # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints
+
+ ### checkpoint saving stuff
+ self.save_every = 50
+ self.disable_checkpointing = False
+
+ ## DDP batch size and oversampling can differ between workers and needs adaptation
+ # we need to change the batch size in DDP because we don't use any of those distributed samplers
+ self._set_batch_size_and_oversample()
+
+ self.was_initialized = False
+
+ self.print_to_log_file("\n#######################################################################\n"
+ "Please cite the following paper when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n"
+ "#######################################################################\n",
+ also_print_to_console=True, add_timestamp=False)
+
+ def initialize(self):
+ if not self.was_initialized:
+ self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+
+ self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
+ self.configuration_manager,
+ self.num_input_channels,
+ enable_deep_supervision=False).to(self.device)
+ # compile network for free speedup
+ if ('nnUNet_compile' in os.environ.keys()) and (
+ os.environ['nnUNet_compile'].lower() in ('true', '1', 't')):
+ self.print_to_log_file('Compiling network...')
+ self.network = torch.compile(self.network)
+
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+ # if ddp, wrap in DDP wrapper
+ if self.is_ddp:
+ self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ self.loss = self._build_loss()
+ self.was_initialized = True
+ else:
+ raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
+ "That should not happen.")
+
+ def _save_debug_information(self):
+ # saving some debug information
+ if self.local_rank == 0:
+ dct = {}
+ for k in self.__dir__():
+ if not k.startswith("__"):
+ if not callable(getattr(self, k)) or k in ['loss', ]:
+ dct[k] = str(getattr(self, k))
+ elif k in ['network', ]:
+ dct[k] = str(getattr(self, k).__class__.__name__)
+ else:
+ # print(k)
+ pass
+ if k in ['dataloader_train', 'dataloader_val']:
+ if hasattr(getattr(self, k), 'generator'):
+ dct[k + '.generator'] = str(getattr(self, k).generator)
+ if hasattr(getattr(self, k), 'num_processes'):
+ dct[k + '.num_processes'] = str(getattr(self, k).num_processes)
+ if hasattr(getattr(self, k), 'transform'):
+ dct[k + '.transform'] = str(getattr(self, k).transform)
+ import subprocess
+ hostname = subprocess.getoutput(['hostname'])
+ dct['hostname'] = hostname
+ torch_version = torch.__version__
+ if self.device.type == 'cuda':
+ gpu_name = torch.cuda.get_device_name()
+ dct['gpu_name'] = gpu_name
+ cudnn_version = torch.backends.cudnn.version()
+ else:
+ cudnn_version = 'None'
+ dct['device'] = str(self.device)
+ dct['torch_version'] = torch_version
+ dct['cudnn_version'] = cudnn_version
+ save_json(dct, join(self.output_folder, "debug.json"))
+
+ @staticmethod
+ def build_network_architecture(plans_manager: PlansManager,
+ dataset_json,
+ configuration_manager: ConfigurationManager,
+ num_input_channels,
+ enable_deep_supervision: bool = False) -> nn.Module:
+ """
+        This is where you build the architecture according to the plans. There is no obligation to use
+ get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
+ you want. Even ignore the plans and just return something static (as long as it can process the requested
+ patch size)
+ but don't bug us with your bugs arising from fiddling with this :-P
+ This is the function that is called in inference as well! This is needed so that all network architecture
+ variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
+ training, so if you change the network architecture during training by deriving a new trainer class then
+ inference will know about it).
+
+ If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
+ > label_manager = plans_manager.get_label_manager(dataset_json)
+ > label_manager.num_segmentation_heads
+ (why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
+ the number of outputs is != the number of classes. Also there is the ignore label for which no output
+ should be generated. label_manager takes care of all that for you.)
+
+ """
+ return get_network_from_plans(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, deep_supervision=enable_deep_supervision)
+
+ def _get_deep_supervision_scales(self):
+ deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(
+ self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]
+ return deep_supervision_scales
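+    # Illustrative sketch for the method above (assumed pool kernel sizes, not from any real plans file):
+    #   >>> import numpy as np
+    #   >>> pool_op_kernel_sizes = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
+    #   >>> list(list(i) for i in 1 / np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0))[:-1]
+    #   [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]]
+    # i.e. one per-axis downscaling factor per deep-supervision output, with the lowest resolution dropped.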
+
+ def _set_batch_size_and_oversample(self):
+ if not self.is_ddp:
+ # set batch size to what the plan says, leave oversample untouched
+ self.batch_size = self.configuration_manager.batch_size
+ else:
+ # batch size is distributed over DDP workers and we need to change oversample_percent for each worker
+ batch_sizes = []
+ oversample_percents = []
+
+ world_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+
+ global_batch_size = self.configuration_manager.batch_size
+ assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
+ 'GPUs... Duh.'
+
+ batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int)
+
+ for rank in range(world_size):
+ if (rank + 1) * batch_size_per_GPU > global_batch_size:
+ batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size)
+ else:
+ batch_size = batch_size_per_GPU
+
+ batch_sizes.append(batch_size)
+
+ sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
+ sample_id_high = np.sum(batch_sizes)
+
+ if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):
+ oversample_percents.append(0.0)
+ elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):
+ oversample_percents.append(1.0)
+ else:
+ percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size
+ oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
+ sample_id_low / global_batch_size) / percent_covered_by_this_rank)
+ oversample_percents.append(oversample_percent_here)
+
+ print("worker", my_rank, "oversample", oversample_percents[my_rank])
+ print("worker", my_rank, "batch_size", batch_sizes[my_rank])
+ # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank])
+ # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank])
+
+ self.batch_size = batch_sizes[my_rank]
+ self.oversample_foreground_percent = oversample_percents[my_rank]
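+    # Worked example for the method above (assumed numbers): with a global batch size of 4 on 2 GPUs,
+    # each rank gets a batch of 2. With oversample_foreground_percent = 0.33 the last third of the
+    # *global* batch should be foreground-oversampled, so rank 0 (samples 0-1) gets 0.0 and rank 1
+    # (samples 2-3) gets 1 - ((0.67 - 0.5) / 0.5) = 0.66, which averages back to ~0.33 overall.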
+
+ def _build_loss(self):
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+ weights[-1] = 0
+
+ # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
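+    # Worked example (assuming 5 deep-supervision outputs): the raw weights are
+    # [1, 0.5, 0.25, 0.125, 0.0625]; after zeroing the lowest-resolution output and normalizing,
+    # they become roughly [0.533, 0.267, 0.133, 0.067, 0.0], so the full-resolution output
+    # dominates the total loss.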
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ """
+ This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.
+ """
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
+ if dim == 2:
+ do_dummy_2d_data_aug = False
+ # todo revisit this parametrization
+ if max(patch_size) / min(patch_size) > 1.5:
+ rotation_for_DA = {
+ 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ mirror_axes = (0, 1)
+ elif dim == 3:
+ # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
+ # order of the axes is determined by spacing, not image size
+ do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
+ if do_dummy_2d_data_aug:
+ # why do we rotate 180 deg here all the time? We should also restrict it
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ }
+ mirror_axes = (0, 1, 2)
+ else:
+ raise RuntimeError()
+
+ # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
+ # old nnunet for now)
+ initial_patch_size = get_patch_size(patch_size[-dim:],
+ *rotation_for_DA.values(),
+ (0.85, 1.25))
+ if do_dummy_2d_data_aug:
+ initial_patch_size[0] = patch_size[0]
+
+ self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
+ self.inference_allowed_mirroring_axes = mirror_axes
+
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
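+    # Note (illustrative): the ranges above are in radians, e.g. -30. / 360 * 2. * np.pi is about
+    # -0.524 rad, i.e. a +/-30 degree rotation range per axis for near-isotropic 3D patches, and
+    # +/-180 degrees around x only when the dummy 2D augmentation is active.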
+
+ def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
+ if self.local_rank == 0:
+ timestamp = time()
+ dt_object = datetime.fromtimestamp(timestamp)
+
+ if add_timestamp:
+ args = ("%s:" % dt_object, *args)
+
+ successful = False
+ max_attempts = 5
+ ctr = 0
+ while not successful and ctr < max_attempts:
+ try:
+ with open(self.log_file, 'a+') as f:
+ for a in args:
+ f.write(str(a))
+ f.write(" ")
+ f.write("\n")
+ successful = True
+ except IOError:
+ print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
+ sleep(0.5)
+ ctr += 1
+ if also_print_to_console:
+ print(*args)
+ elif also_print_to_console:
+ print(*args)
+
+ def print_plans(self):
+ if self.local_rank == 0:
+ dct = deepcopy(self.plans_manager.plans)
+ del dct['configurations']
+ self.print_to_log_file(f"\nThis is the configuration used by this "
+ f"training:\nConfiguration name: {self.configuration_name}\n",
+ self.configuration_manager, '\n', add_timestamp=False)
+ self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False)
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=3.)
+ return optimizer, lr_scheduler
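+    # Illustrative sketch (assuming the standard polynomial decay implemented by PolyLRScheduler):
+    # lr(epoch) = initial_lr * (1 - epoch / num_epochs) ** exponent, so with initial_lr = 1e-2,
+    # num_epochs = 40 and exponent = 3, epoch 20 would run at 1e-2 * 0.5 ** 3 = 1.25e-3.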
+
+ def plot_network_architecture(self):
+ if self.local_rank == 0:
+ try:
+ # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')
+ # pip install git+https://github.com/saugatkandel/hiddenlayer.git
+
+ # from torchviz import make_dot
+ # # not viable.
+ # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,
+ # *self.configuration_manager.patch_size),
+ # device=self.device)))).render(
+ # join(self.output_folder, "network_architecture.pdf"), format='pdf')
+ # self.optimizer.zero_grad()
+
+ # broken.
+
+ import hiddenlayer as hl
+ g = hl.build_graph(self.network,
+ torch.rand((1, self.num_input_channels,
+ *self.configuration_manager.patch_size),
+ device=self.device),
+ transforms=None)
+ g.save(join(self.output_folder, "network_architecture.pdf"))
+ del g
+ except Exception as e:
+ self.print_to_log_file("Unable to plot network architecture:")
+ self.print_to_log_file(e)
+
+ # self.print_to_log_file("\nprinting the network instead:\n")
+ # self.print_to_log_file(self.network)
+ # self.print_to_log_file("\n")
+ finally:
+ empty_cache(self.device)
+
+ def do_split(self):
+ """
+        The default split is a 5-fold CV on all available training cases. nnU-Net will create a split (it is seeded,
+        so always the same) and save it as a splits_final.json file in the preprocessed data directory.
+        Sometimes you may want to create your own split for various reasons. For this you will need to create your own
+        splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
+ it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
+ and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
+ use a random 80:20 data split.
+ :return:
+ """
+ if self.fold == "all":
+ # if fold==all then we use all images for training and validation
+ case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder)
+ tr_keys = case_identifiers
+ val_keys = tr_keys
+ else:
+ splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
+ dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None,
+ num_images_properties_loading_threshold=0,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
+ # if the split file does not exist we need to create it
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new 5-fold cross-validation split...")
+ splits = []
+ all_keys_sorted = np.sort(list(dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append({})
+ splits[-1]['train'] = list(train_keys)
+ splits[-1]['val'] = list(test_keys)
+ save_json(splits, splits_file)
+
+ else:
+ self.print_to_log_file("Using splits from existing split file:", splits_file)
+ splits = load_json(splits_file)
+ self.print_to_log_file("The split file contains %d splits." % len(splits))
+
+ self.print_to_log_file("Desired fold for training: %d" % self.fold)
+ if self.fold < len(splits):
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+ self.print_to_log_file("This split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ else:
+ self.print_to_log_file("INFO: You requested fold %d for training but splits "
+ "contain only %d folds. I am now creating a "
+ "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
+ # if we request a fold that is not in the split file, create a random 80:20 split
+ rnd = np.random.RandomState(seed=12345 + self.fold)
+ keys = np.sort(list(dataset.keys()))
+ idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
+ idx_val = [i for i in range(len(keys)) if i not in idx_tr]
+ tr_keys = [keys[i] for i in idx_tr]
+ val_keys = [keys[i] for i in idx_val]
+ self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ if any([i in val_keys for i in tr_keys]):
+ self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '
+ 'splits.json or ignore if this is intentional.')
+ return tr_keys, val_keys
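+    # Illustrative sketch of the splits_final.json layout this method reads and writes (case names assumed):
+    #   [
+    #       {"train": ["case_000", "case_001", ...], "val": ["case_004", ...]},  # fold 0
+    #       {"train": [...], "val": [...]},                                      # fold 1
+    #       ...
+    #   ]
+    # i.e. a list with one {"train", "val"} dict per fold, which is exactly what the KFold branch
+    # above writes via save_json.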
+
+ def get_tr_and_val_datasets(self):
+ # create dataset split
+ tr_keys, val_keys = self.do_split()
+
+ # load the datasets for training and validation. Note that we always draw random samples so we really don't
+ # care about distributing training cases across GPUs.
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ return dataset_tr, dataset_val
+
+ def get_dataloaders(self):
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=1,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
+ num_processes=allowed_num_processes, num_cached=6, seeds=None,
+ pin_memory=self.device.type == 'cuda', wait_time=0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
+ transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
+ num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
+ wait_time=0.02)
+ return mt_gen_train, mt_gen_val
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ dataset_tr, dataset_val = self.get_tr_and_val_datasets()
+
+ if dim == 2:
+ dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ else:
+ dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ return dl_tr, dl_val
+
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 3,
+ order_resampling_seg: int = 1,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ tr_transforms = []
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+ return tr_transforms
+
+ @staticmethod
+ def get_validation_transforms(deep_supervision_scales: Union[List, Tuple],
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+ return val_transforms
+
+ def set_deep_supervision_enabled(self, enabled: bool):
+ """
+ This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
+ chances you need to change this as well!
+ """
+ if self.is_ddp:
+ self.network.module.decoder.deep_supervision = enabled
+ else:
+ self.network.decoder.deep_supervision = enabled
+
+ def on_train_start(self):
+ if not self.was_initialized:
+ self.initialize()
+
+ maybe_mkdir_p(self.output_folder)
+
+ # make sure deep supervision is on in the network
+ self.set_deep_supervision_enabled(True)
+
+ self.print_plans()
+ empty_cache(self.device)
+
+ # maybe unpack
+ if self.unpack_dataset and self.local_rank == 0:
+ self.print_to_log_file('unpacking dataset...')
+ unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
+ num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
+ self.print_to_log_file('unpacking done...')
+
+ if self.is_ddp:
+ dist.barrier()
+
+ # dataloaders must be instantiated here because they need access to the training data which may not be present
+ # when doing inference
+ self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
+ # copy plans and dataset.json so that they can be used for restoring everything we need for inference
+ save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
+ save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
+
+ # we don't really need the fingerprint but its still handy to have it with the others
+ shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
+ join(self.output_folder_base, 'dataset_fingerprint.json'))
+
+ # produces a pdf in output folder
+ self.plot_network_architecture()
+
+ self._save_debug_information()
+
+ # print(f"batch size: {self.batch_size}")
+ # print(f"oversample: {self.oversample_foreground_percent}")
+
+ def on_train_end(self):
+ self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
+ # now we can delete latest
+ if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
+ os.remove(join(self.output_folder, "checkpoint_latest.pth"))
+
+ # shut down dataloaders
+ old_stdout = sys.stdout
+ with open(os.devnull, 'w') as f:
+ sys.stdout = f
+ if self.dataloader_train is not None:
+ self.dataloader_train._finish()
+ if self.dataloader_val is not None:
+ self.dataloader_val._finish()
+ sys.stdout = old_stdout
+
+ empty_cache(self.device)
+ self.print_to_log_file("Training done.")
+
+ def on_train_epoch_start(self):
+ self.network.train()
+ self.lr_scheduler.step(self.current_epoch)
+ self.print_to_log_file('')
+ self.print_to_log_file(f'Epoch {self.current_epoch}')
+ self.print_to_log_file(
+ f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")
+ # lrs are the same for all workers so we don't need to gather them in case of DDP training
+ self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
+
+ def train_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ self.optimizer.zero_grad()
+        # Autocast needs different handling per device type:
+        # - on 'cpu' it is very slow, so it should stay disabled
+        # - on 'mps' it complains that autocast is not implemented even when enabled=False is passed,
+        #   which is why we do not simply pass enabled=False
+        # So autocast is only active when running on a cuda device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ # del data
+ l = self.loss(output, target)
+
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(l).backward()
+ self.grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+ return {'loss': l.detach().cpu().numpy()}
+
+ def on_train_epoch_end(self, train_outputs: List[dict]):
+ outputs = collate_outputs(train_outputs)
+
+ if self.is_ddp:
+ losses_tr = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(losses_tr, outputs['loss'])
+ loss_here = np.vstack(losses_tr).mean()
+ else:
+ loss_here = np.mean(outputs['loss'])
+
+ self.logger.log('train_losses', loss_here, self.current_epoch)
+
+ def on_validation_epoch_start(self):
+ self.network.eval()
+
+ def validation_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+        # Autocast needs different handling per device type:
+        # - on 'cpu' it is very slow, so it should stay disabled
+        # - on 'mps' it complains that autocast is not implemented even when enabled=False is passed,
+        #   which is why we do not simply pass enabled=False
+        # So autocast is only active when running on a cuda device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ # we only need the output with the highest output resolution
+ output = output[0]
+ target = target[0]
+
+ # the following is needed for online evaluation. Fake dice (green line)
+ axes = [0] + list(range(2, len(output.shape)))
+
+ if self.label_manager.has_regions:
+ predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
+ else:
+ # no need for softmax
+ output_seg = output.argmax(1)[:, None]
+ predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
+ predicted_segmentation_onehot.scatter_(1, output_seg, 1)
+ del output_seg
+
+ if self.label_manager.has_ignore_label:
+ if not self.label_manager.has_regions:
+ mask = (target != self.label_manager.ignore_label).float()
+ # CAREFUL that you don't rely on target after this line!
+ target[target == self.label_manager.ignore_label] = 0
+ else:
+ mask = 1 - target[:, -1:]
+ # CAREFUL that you don't rely on target after this line!
+ target = target[:, :-1]
+ else:
+ mask = None
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+ if not self.label_manager.has_regions:
+ # if we train with regions all segmentation heads predict some kind of foreground. In conventional
+            # (softmax training) there needs to be one output for the background. We are not interested in the
+ # background Dice
+ # [1:] in order to remove background
+ tp_hard = tp_hard[1:]
+ fp_hard = fp_hard[1:]
+ fn_hard = fn_hard[1:]
+
+ return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
+
+ def on_validation_epoch_end(self, val_outputs: List[dict]):
+ outputs_collated = collate_outputs(val_outputs)
+ tp = np.sum(outputs_collated['tp_hard'], 0)
+ fp = np.sum(outputs_collated['fp_hard'], 0)
+ fn = np.sum(outputs_collated['fn_hard'], 0)
+
+ if self.is_ddp:
+ world_size = dist.get_world_size()
+
+ tps = [None for _ in range(world_size)]
+ dist.all_gather_object(tps, tp)
+ tp = np.vstack([i[None] for i in tps]).sum(0)
+
+ fps = [None for _ in range(world_size)]
+ dist.all_gather_object(fps, fp)
+ fp = np.vstack([i[None] for i in fps]).sum(0)
+
+ fns = [None for _ in range(world_size)]
+ dist.all_gather_object(fns, fn)
+ fn = np.vstack([i[None] for i in fns]).sum(0)
+
+ losses_val = [None for _ in range(world_size)]
+ dist.all_gather_object(losses_val, outputs_collated['loss'])
+ loss_here = np.vstack(losses_val).mean()
+ else:
+ loss_here = np.mean(outputs_collated['loss'])
+
+ global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
+ zip(tp, fp, fn)]]
+ mean_fg_dice = np.nanmean(global_dc_per_class)
+ self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
+ self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
+ self.logger.log('val_losses', loss_here, self.current_epoch)
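+    # Worked example (assumed counts): the per-class pseudo Dice above is 2*tp / (2*tp + fp + fn),
+    # e.g. tp = 80, fp = 10, fn = 20 gives 160 / 190, roughly 0.842; mean_fg_dice is the nanmean
+    # over all foreground classes (or regions).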
+
+ def on_epoch_start(self):
+ self.logger.log('epoch_start_timestamps', time(), self.current_epoch)
+
+ def on_epoch_end(self):
+ self.logger.log('epoch_end_timestamps', time(), self.current_epoch)
+
+        # todo: find a cleaner way to report these per-epoch metrics
+ self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4))
+ self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4))
+ self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
+ self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]])
+ self.print_to_log_file(
+ f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s")
+
+ # handling periodic checkpointing
+ current_epoch = self.current_epoch
+ if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))
+
+ # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this
+ if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema:
+ self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1]
+ self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))
+
+ if self.local_rank == 0:
+ self.logger.plot_progress_png(self.output_folder)
+
+ self.current_epoch += 1
+
+ def save_checkpoint(self, filename: str) -> None:
+ if self.local_rank == 0:
+ if not self.disable_checkpointing:
+ if self.is_ddp:
+ mod = self.network.module
+ else:
+ mod = self.network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ checkpoint = {
+ 'network_weights': mod.state_dict(),
+ 'optimizer_state': self.optimizer.state_dict(),
+ 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,
+ 'logging': self.logger.get_checkpoint(),
+ '_best_ema': self._best_ema,
+ 'current_epoch': self.current_epoch + 1,
+ 'init_args': self.my_init_kwargs,
+ 'trainer_name': self.__class__.__name__,
+ 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,
+ }
+ torch.save(checkpoint, filename)
+ else:
+ self.print_to_log_file('No checkpoint written, checkpointing is disabled')
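+    # Minimal sketch (illustrative, file name assumed) for inspecting a checkpoint written by
+    # save_checkpoint outside of the trainer:
+    #   >>> import torch
+    #   >>> ckpt = torch.load('checkpoint_best.pth', map_location='cpu')
+    #   >>> sorted(ckpt.keys())
+    #   ['_best_ema', 'current_epoch', 'grad_scaler_state', 'inference_allowed_mirroring_axes',
+    #    'init_args', 'logging', 'network_weights', 'optimizer_state', 'trainer_name']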
+
+ def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
+ if not self.was_initialized:
+ self.initialize()
+
+ if isinstance(filename_or_checkpoint, str):
+ checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
+ new_state_dict = {}
+ for k, value in checkpoint['network_weights'].items():
+ key = k
+ if key not in self.network.state_dict().keys() and key.startswith('module.'):
+ key = key[7:]
+ new_state_dict[key] = value
+
+ self.my_init_kwargs = checkpoint['init_args']
+ self.current_epoch = checkpoint['current_epoch']
+ self.logger.load_checkpoint(checkpoint['logging'])
+ self._best_ema = checkpoint['_best_ema']
+ self.inference_allowed_mirroring_axes = checkpoint[
+ 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes
+
+ # messing with state dict naming schemes. Facepalm.
+ if self.is_ddp:
+ if isinstance(self.network.module, OptimizedModule):
+ self.network.module._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.module.load_state_dict(new_state_dict)
+ else:
+ if isinstance(self.network, OptimizedModule):
+ self.network._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.load_state_dict(new_state_dict)
+ self.optimizer.load_state_dict(checkpoint['optimizer_state'])
+ if self.grad_scaler is not None:
+ if checkpoint['grad_scaler_state'] is not None:
+ self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])
+
+ def perform_actual_validation(self, save_probabilities: bool = False):
+ self.set_deep_supervision_enabled(False)
+ self.network.eval()
+
+ predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
+ perform_everything_on_gpu=True, device=self.device, verbose=False,
+ verbose_preprocessing=False, allow_tqdm=False)
+ predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
+ self.dataset_json, self.__class__.__name__,
+ self.inference_allowed_mirroring_axes)
+
+ with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool:
+ worker_list = [i for i in segmentation_export_pool._pool]
+ validation_output_folder = join(self.output_folder, 'validation')
+ maybe_mkdir_p(validation_output_folder)
+
+ # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute
+ # the validation keys across the workers.
+ _, val_keys = self.do_split()
+ if self.is_ddp:
+ val_keys = val_keys[self.local_rank:: dist.get_world_size()]
+
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+
+ next_stages = self.configuration_manager.next_stage_names
+
+ if next_stages is not None:
+ _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]
+
+ results = []
+
+ for k in dataset_val.keys():
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+ while not proceed:
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+
+ self.print_to_log_file(f"predicting {k}")
+ data, seg, properties = dataset_val.load_case(k)
+
+ if self.is_cascaded:
+ data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels,
+ output_dtype=data.dtype)))
+ with warnings.catch_warnings():
+ # ignore 'The given NumPy array is not writable' warning
+ warnings.simplefilter("ignore")
+ data = torch.from_numpy(data)
+
+ output_filename_truncated = join(validation_output_folder, k)
+
+ try:
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ except RuntimeError:
+ predictor.perform_everything_on_gpu = False
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ predictor.perform_everything_on_gpu = True
+
+ prediction = prediction.cpu()
+
+ # this needs to go into background processes
+ results.append(
+ segmentation_export_pool.starmap_async(
+ export_prediction_from_logits, (
+ (prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, output_filename_truncated, save_probabilities),
+ )
+ )
+ )
+ # for debug purposes
+ # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json,
+ # output_filename_truncated, save_probabilities)
+
+ # if needed, export the softmax prediction for the next stage
+ if next_stages is not None:
+ for n in next_stages:
+ next_stage_config_manager = self.plans_manager.get_configuration(n)
+ expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,
+ next_stage_config_manager.data_identifier)
+
+ try:
+ # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented
+ tmp = nnUNetDataset(expected_preprocessed_folder, [k],
+ num_images_properties_loading_threshold=0)
+ d, s, p = tmp.load_case(k)
+ except FileNotFoundError:
+ self.print_to_log_file(
+ f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! "
+ f"Run the preprocessing for this configuration first!")
+ continue
+
+ target_shape = d.shape[1:]
+ output_folder = join(self.output_folder_base, 'predicted_next_stage', n)
+ output_file = join(output_folder, k + '.npz')
+
+ # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties,
+ # self.dataset_json)
+ results.append(segmentation_export_pool.starmap_async(
+ resample_and_save, (
+ (prediction, target_shape, output_file, self.plans_manager,
+ self.configuration_manager,
+ properties,
+ self.dataset_json),
+ )
+ ))
+
+ _ = [r.get() for r in results]
+
+ if self.is_ddp:
+ dist.barrier()
+
+ if self.local_rank == 0:
+ metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),
+ validation_output_folder,
+ join(validation_output_folder, 'summary.json'),
+ self.plans_manager.image_reader_writer_class(),
+ self.dataset_json["file_ending"],
+ self.label_manager.foreground_regions if self.label_manager.has_regions else
+ self.label_manager.foreground_labels,
+ self.label_manager.ignore_label, chill=True)
+ self.print_to_log_file("Validation complete", also_print_to_console=True)
+ self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True)
+
+ self.set_deep_supervision_enabled(True)
+ compute_gaussian.cache_clear()
+
+ def run_training(self):
+ self.on_train_start()
+
+ for epoch in range(self.current_epoch, self.num_epochs):
+ self.on_epoch_start()
+
+ self.on_train_epoch_start()
+ train_outputs = []
+ for batch_id in range(self.num_iterations_per_epoch):
+ train_outputs.append(self.train_step(next(self.dataloader_train)))
+ self.on_train_epoch_end(train_outputs)
+
+ with torch.no_grad():
+ self.on_validation_epoch_start()
+ val_outputs = []
+ for batch_id in range(self.num_val_iterations_per_epoch):
+ val_outputs.append(self.validation_step(next(self.dataloader_val)))
+ self.on_validation_epoch_end(val_outputs)
+
+ self.on_epoch_end()
+
+ self.on_train_end()
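+
+
+# Minimal usage sketch (dataset id assumed, illustrative only): this trainer can be selected via
+# nnU-Net's -tr flag, e.g.
+#   nnUNetv2_train 555 3d_fullres 0 -tr nnUNetTrainer_aghiles
+# nnU-Net resolves the trainer class by name from the nnunetv2.training.nnUNetTrainer package,
+# which is why this file lives there.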
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py
new file mode 100644
index 0000000..51b8f0d
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py
@@ -0,0 +1,1339 @@
+import inspect
+import multiprocessing
+import os
+import shutil
+import sys
+import warnings
+from copy import deepcopy
+from datetime import datetime
+from time import time, sleep
+from typing import Union, Tuple, List
+
+import numpy as np
+import torch
+from torch.nn import BCEWithLogitsLoss
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
+from torch._dynamo import OptimizedModule
+
+from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
+from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
+from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
+ DownsampleSegForDSTransform2
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
+from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
+ ConvertSegmentationToRegionsTransform
+from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \
+ Convert3DTo2DTransform
+from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
+from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
+from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
+from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.utilities.collate_outputs import collate_outputs
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
+from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from sklearn.model_selection import KFold
+from torch import autocast, nn
+from torch import distributed as dist
+from torch.cuda import device_count
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from monai.networks.nets import ViT
+from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
+
+
+class AutoPETNet(nn.Module):
+ def __init__(self, encoder, decoder, cl_a, cl_c, cl_s, classifier, fs_a, fs_c, fs_s):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.cl_a = cl_a
+ self.cl_c = cl_c
+ self.cl_s = cl_s
+ self.classifier = classifier
+ self.fs_a = fs_a
+ self.fs_c = fs_c
+ self.fs_s = fs_s
+ self.hidden_size = 768
+ spatial_dims = 3
+ self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
+
+ def proj_feat(self, x, feat_size):
+ self.proj_view_shape = list(feat_size) + [self.hidden_size]
+ new_view = [x.size(0)] + self.proj_view_shape
+ x = x.view(new_view)
+ x = x.permute(self.proj_axes).contiguous()
+ return x
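+ # Illustrative shape walk-through (hypothetical sizes, not taken from any plan): with
+ # hidden_size=768 and feat_size=(8, 8, 1), a ViT token tensor of shape (B, 8*8*1, 768)
+ # is viewed as (B, 8, 8, 1, 768) and then permuted with proj_axes=(0, 4, 1, 2, 3) into the
+ # channels-first layout (B, 768, 8, 8, 1) expected by the pooling in forward().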
+
+ def forward(self, x):
+
+ mip_axial = torch.cat([torch.max(x[idx].unsqueeze(0), 4, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0)
+ mip_coro = torch.cat([torch.max(x[idx].unsqueeze(0), 3, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0)
+ mip_sagi = torch.cat([torch.max(x[idx].unsqueeze(0), 2, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0)
+ skips = self.encoder(x)
+ output = self.decoder(skips)
+ feature_a, hs_a = self.cl_a(mip_axial)
+ feature_c, hs_c = self.cl_c(mip_coro)
+ feature_s, hs_s = self.cl_s(mip_sagi)
+ features = torch.nn.AvgPool3d((4, 4, 4))(skips[-1])
+ feature_a = torch.nn.AvgPool3d((8, 8, 1))(self.proj_feat(feature_a, self.fs_a))
+ feature_c = torch.nn.AvgPool3d((8, 1, 8))(self.proj_feat(feature_c, self.fs_c))
+ feature_s = torch.nn.AvgPool3d((1, 8, 8))(self.proj_feat(feature_s, self.fs_s))
+ all_features = torch.cat([features, feature_a, feature_c, feature_s], dim=1).squeeze(-1).squeeze(-1).squeeze(-1)
+ classif = torch.softmax(self.classifier(all_features), dim=1)
+ if self.training:
+ return output, classif
+ else:
+ if isinstance(output, list):
+ for i in range(output[0].shape[0]):
+ output[0][i] = output[0][i] * torch.argmax(classif, dim=1)[i]
+ return output
+ else:
+ for i in range(output.shape[0]):
+ output[i] = output[i] * torch.argmax(classif, dim=1)[i]
+ return output
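+ # Note on the eval-mode branch above: torch.argmax(classif, dim=1) yields 0 or 1 per
+ # sample, so multiplying the segmentation logits by it is apparently meant to suppress
+ # foreground predictions for samples the MIP classifier considers negative (a reading of
+ # the code, not documented upstream).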
+
+ def compute_conv_feature_map_size(self, input_size):
+ # Ok this is not good. Later ?
+ assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
+ "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
+ "Give input_size=(x, y(, z))!"
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)
+
+
+class nnUNetTrainer_autopet(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ # From https://grugbrain.dev/. Worth a read ya big brains ;-)
+
+ # apex predator of grug is complexity
+ # complexity bad
+ # say again:
+ # complexity very bad
+ # you say now:
+ # complexity very, very bad
+ # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex
+ # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime
+ # one day code base understandable and grug can get work done, everything good!
+ # next day impossible: complexity demon spirit has entered code and very dangerous situation!
+
+ # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz
+
+ self.is_ddp = dist.is_available() and dist.is_initialized()
+ self.local_rank = 0 if not self.is_ddp else dist.get_rank()
+
+ self.device = device
+
+ # print what device we are using
+ if self.is_ddp: # implicitly it's clear that we use cuda in this case
+ print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is "
+ f"{dist.get_world_size()}."
+ f"Setting device to {self.device}")
+ self.device = torch.device(type='cuda', index=self.local_rank)
+ else:
+ if self.device.type == 'cuda':
+ # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
+ self.device = torch.device(type='cuda', index=0)
+ print(f"Using device: {self.device}")
+
+ # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
+ # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
+ # need. So let's save the init args
+ self.my_init_kwargs = {}
+ for k in inspect.signature(self.__init__).parameters.keys():
+ self.my_init_kwargs[k] = locals()[k]
+
+ ### Saving all the init args into class variables for later access
+ self.plans_manager = PlansManager(plans)
+ self.configuration_manager = self.plans_manager.get_configuration(configuration)
+ self.configuration_name = configuration
+ self.dataset_json = dataset_json
+ self.fold = fold
+ self.unpack_dataset = unpack_dataset
+
+ ### Setting all the folder names. We need to make sure things don't crash in case we are just running
+ # inference and some of the folders may not be defined!
+ self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \
+ if nnUNet_preprocessed is not None else None
+ self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \
+ if nnUNet_results is not None else None
+ self.output_folder = join(self.output_folder_base, f'fold_{fold}')
+
+ self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,
+ self.configuration_manager.data_identifier)
+ # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to
+ # be a different configuration in the same plans
+ # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using
+ # "previous_stage" and "next_stage"). Otherwise it won't work!
+ self.is_cascaded = self.configuration_manager.previous_stage_name is not None
+ self.folder_with_segs_from_previous_stage = \
+ join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" +
+ self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \
+ if self.is_cascaded else None
+
+ ### Some hyperparameters for you to fiddle with
+ self.initial_lr = 1e-4
+ self.weight_decay = 3e-5
+ self.oversample_foreground_percent = 0.33
+ self.num_iterations_per_epoch = 825
+ self.num_val_iterations_per_epoch = 200
+ self.num_epochs = 250
+ self.current_epoch = 0
+
+ ### Dealing with labels/regions
+ self.label_manager = self.plans_manager.get_label_manager(dataset_json)
+ # labels can either be a list of int (regular training) or a list of tuples of int (region-based training)
+ # needed for predictions. We do sigmoid in case of (overlapping) regions
+
+ self.num_input_channels = None # -> self.initialize()
+ self.network = None # -> self._get_network()
+ self.optimizer = self.lr_scheduler = None # -> self.initialize
+ self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None
+ self.loss = None # -> self.initialize
+
+ ### Simple logging. Don't take that away from me!
+ # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning
+ # logging
+ timestamp = datetime.now()
+ maybe_mkdir_p(self.output_folder)
+ self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
+ (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
+ timestamp.second))
+ self.logger = nnUNetLogger()
+
+ ### placeholders
+ self.dataloader_train = self.dataloader_val = None # see on_train_start
+
+ ### initializing stuff for remembering things and such
+ self._best_ema = None
+
+ ### inference things
+ self.inference_allowed_mirroring_axes = None # this variable is set in
+ # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints
+
+ ### checkpoint saving stuff
+ self.save_every = 50
+ self.disable_checkpointing = False
+
+ ## DDP batch size and oversampling can differ between workers and needs adaptation
+ # we need to change the batch size in DDP because we don't use any of those distributed samplers
+ self._set_batch_size_and_oversample()
+ self.hidden_size = 768
+ spatial_dims = 3
+ self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims))
+
+ self.was_initialized = False
+
+ self.print_to_log_file("\n#######################################################################\n"
+ "Please cite the following paper when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n"
+ "#######################################################################\n",
+ also_print_to_console=True, add_timestamp=False)
+
+ def proj_feat(self, x, feat_size):
+ self.proj_view_shape = list(feat_size) + [self.hidden_size]
+ new_view = [x.size(0)] + self.proj_view_shape
+ x = x.view(new_view)
+ x = x.permute(self.proj_axes).contiguous()
+ return x
+
+ def initialize(self):
+ if not self.was_initialized:
+ self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+
+ self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
+ self.configuration_manager,
+ self.num_input_channels,
+ enable_deep_supervision=True).to(self.device)
+ # compile network for free speedup
+ if ('nnUNet_compile' in os.environ.keys()) and (
+ os.environ['nnUNet_compile'].lower() in ('true', '1', 't')):
+ self.print_to_log_file('Compiling network...')
+ self.network = torch.compile(self.network)
+
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+ # if ddp, wrap in DDP wrapper
+ if self.is_ddp:
+ self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ self.loss = self._build_loss()
+ self.was_initialized = True
+ else:
+ raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
+ "That should not happen.")
+
+ def _save_debug_information(self):
+ # saving some debug information
+ if self.local_rank == 0:
+ dct = {}
+ for k in self.__dir__():
+ if not k.startswith("__"):
+ if not callable(getattr(self, k)) or k in ['loss', ]:
+ dct[k] = str(getattr(self, k))
+ elif k in ['network', ]:
+ dct[k] = str(getattr(self, k).__class__.__name__)
+ else:
+ # print(k)
+ pass
+ if k in ['dataloader_train', 'dataloader_val']:
+ if hasattr(getattr(self, k), 'generator'):
+ dct[k + '.generator'] = str(getattr(self, k).generator)
+ if hasattr(getattr(self, k), 'num_processes'):
+ dct[k + '.num_processes'] = str(getattr(self, k).num_processes)
+ if hasattr(getattr(self, k), 'transform'):
+ dct[k + '.transform'] = str(getattr(self, k).transform)
+ import subprocess
+ hostname = subprocess.getoutput(['hostname'])
+ dct['hostname'] = hostname
+ torch_version = torch.__version__
+ if self.device.type == 'cuda':
+ gpu_name = torch.cuda.get_device_name()
+ dct['gpu_name'] = gpu_name
+ cudnn_version = torch.backends.cudnn.version()
+ else:
+ cudnn_version = 'None'
+ dct['device'] = str(self.device)
+ dct['torch_version'] = torch_version
+ dct['cudnn_version'] = cudnn_version
+ save_json(dct, join(self.output_folder, "debug.json"))
+
+ @staticmethod
+ def build_network_architecture(plans_manager: PlansManager,
+ dataset_json,
+ configuration_manager: ConfigurationManager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ """
+ This is where you build the architecture according to the plans. There is no obligation to use
+ get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
+ you want. Even ignore the plans and just return something static (as long as it can process the requested
+ patch size)
+ but don't bug us with your bugs arising from fiddling with this :-P
+ This is the function that is called in inference as well! This is needed so that all network architecture
+ variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
+ training, so if you change the network architecture during training by deriving a new trainer class then
+ inference will know about it).
+
+ If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
+ > label_manager = plans_manager.get_label_manager(dataset_json)
+ > label_manager.num_segmentation_heads
+ (why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
+ the number of outputs is != the number of classes. Also there is the ignore label for which no output
+ should be generated. label_manager takes care of all that for you.)
+
+ """
+
+ axial_ps = (configuration_manager.patch_size[0], configuration_manager.patch_size[1], 1)
+ coro_ps = (configuration_manager.patch_size[0], 1, configuration_manager.patch_size[2])
+ sagi_ps = (1, configuration_manager.patch_size[1], configuration_manager.patch_size[2])
+ feat_size_axial = tuple(img_d // p_d for img_d, p_d in zip(axial_ps, (16, 16, 1)))
+ feat_size_coro = tuple(img_d // p_d for img_d, p_d in zip(coro_ps, (16, 1, 16)))
+ feat_size_sagi = tuple(img_d // p_d for img_d, p_d in zip(sagi_ps, (1, 16, 16)))
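+ # For illustration (hypothetical plan values): a patch size of (128, 128, 128) gives
+ # axial_ps=(128, 128, 1) and, with ViT patch size (16, 16, 1), feat_size_axial=(8, 8, 1);
+ # the coronal and sagittal counterparts are (8, 1, 8) and (1, 8, 8).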
+ model_classiff_axial = ViT(in_channels=num_input_channels,
+ img_size=axial_ps,
+ patch_size=(16, 16, 1), classification=False)
+ model_classiff_coro = ViT(in_channels=num_input_channels,
+ img_size=coro_ps,
+ patch_size=(16, 1, 16), classification=False)
+ model_classiff_sagi = ViT(in_channels=num_input_channels,
+ img_size=sagi_ps,
+ patch_size=(1, 16, 16), classification=False)
+ classifier = nn.Linear(3 * 768 + configuration_manager.unet_max_num_features, 2)
+ network = get_network_from_plans(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, deep_supervision=enable_deep_supervision)
+ net = AutoPETNet(network.encoder, network.decoder, model_classiff_axial, model_classiff_coro, model_classiff_sagi,
+ classifier, feat_size_axial, feat_size_coro, feat_size_sagi)
+ return net
+
+ def _get_deep_supervision_scales(self):
+ deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(
+ self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]
+ return deep_supervision_scales
+
+ def _set_batch_size_and_oversample(self):
+ if not self.is_ddp:
+ # set batch size to what the plan says, leave oversample untouched
+ self.batch_size = self.configuration_manager.batch_size
+ else:
+ # batch size is distributed over DDP workers and we need to change oversample_percent for each worker
+ batch_sizes = []
+ oversample_percents = []
+
+ world_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+
+ global_batch_size = self.configuration_manager.batch_size
+ assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
+ 'GPUs... Duh.'
+
+ batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int)
+
+ for rank in range(world_size):
+ if (rank + 1) * batch_size_per_GPU > global_batch_size:
+ batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size)
+ else:
+ batch_size = batch_size_per_GPU
+
+ batch_sizes.append(batch_size)
+
+ sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
+ sample_id_high = np.sum(batch_sizes)
+
+ if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):
+ oversample_percents.append(0.0)
+ elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):
+ oversample_percents.append(1.0)
+ else:
+ percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size
+ oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
+ sample_id_low / global_batch_size) / percent_covered_by_this_rank)
+ oversample_percents.append(oversample_percent_here)
+
+ print("worker", my_rank, "oversample", oversample_percents[my_rank])
+ print("worker", my_rank, "batch_size", batch_sizes[my_rank])
+ # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank])
+ # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank])
+
+ self.batch_size = batch_sizes[my_rank]
+ self.oversample_foreground_percent = oversample_percents[my_rank]
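+ # Worked example (illustrative numbers): with a global batch size of 5, 2 GPUs and
+ # oversample_foreground_percent=0.33, the per-rank batch sizes become [3, 2]. Rank 0
+ # covers the first 60% of the global batch and gets oversample 0.0, rank 1 covers the
+ # remaining 40% and gets ~0.825, so the expected foreground-forced fraction over the
+ # whole global batch remains 0.33.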
+
+ def _build_loss(self):
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
+ self.classif_loss = BCEWithLogitsLoss()
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+ weights[-1] = 0
+
+ # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
+ weights = weights / weights.sum()
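+ # Example (assuming 5 deep supervision outputs): the raw weights are
+ # [1, 0.5, 0.25, 0.125, 0.0625]; zeroing the last entry and normalizing yields roughly
+ # [0.533, 0.267, 0.133, 0.067, 0.0], so the lowest-resolution output does not contribute.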
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ """
+ This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.
+ """
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
+ if dim == 2:
+ do_dummy_2d_data_aug = False
+ # todo revisit this parametrization
+ if max(patch_size) / min(patch_size) > 1.5:
+ rotation_for_DA = {
+ 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ mirror_axes = (0, 1)
+ elif dim == 3:
+ # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
+ # order of the axes is determined by spacing, not image size
+ do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
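+ # e.g. a (hypothetical) patch size of (48, 192, 192) gives 192 / 48 = 4 > ANISO_THRESHOLD
+ # (3 by default), so dummy 2D augmentation would be enabled for such anisotropic data.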
+ if do_dummy_2d_data_aug:
+ # why do we rotate 180 deg here all the time? We should also restrict it
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ }
+ mirror_axes = (0, 1, 2)
+ else:
+ raise RuntimeError()
+
+ # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
+ # old nnunet for now)
+ initial_patch_size = get_patch_size(patch_size[-dim:],
+ *rotation_for_DA.values(),
+ (0.85, 1.25))
+ if do_dummy_2d_data_aug:
+ initial_patch_size[0] = patch_size[0]
+
+ self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
+ self.inference_allowed_mirroring_axes = mirror_axes
+
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+ def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
+ if self.local_rank == 0:
+ timestamp = time()
+ dt_object = datetime.fromtimestamp(timestamp)
+
+ if add_timestamp:
+ args = ("%s:" % dt_object, *args)
+
+ successful = False
+ max_attempts = 5
+ ctr = 0
+ while not successful and ctr < max_attempts:
+ try:
+ with open(self.log_file, 'a+') as f:
+ for a in args:
+ f.write(str(a))
+ f.write(" ")
+ f.write("\n")
+ successful = True
+ except IOError:
+ print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
+ sleep(0.5)
+ ctr += 1
+ if also_print_to_console:
+ print(*args)
+ elif also_print_to_console:
+ print(*args)
+
+ def print_plans(self):
+ if self.local_rank == 0:
+ dct = deepcopy(self.plans_manager.plans)
+ del dct['configurations']
+ self.print_to_log_file(f"\nThis is the configuration used by this "
+ f"training:\nConfiguration name: {self.configuration_name}\n",
+ self.configuration_manager, '\n', add_timestamp=False)
+ self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False)
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
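+ # PolyLRScheduler decays the learning rate polynomially, lr = initial_lr * (1 - epoch / num_epochs) ** 0.9
+ # in nnU-Net's implementation; e.g. with initial_lr=1e-4 and num_epochs=250, epoch 125 runs at ~5.4e-5.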
+ return optimizer, lr_scheduler
+
+ def plot_network_architecture(self):
+ if self.local_rank == 0:
+ try:
+ # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')
+ # pip install git+https://github.com/saugatkandel/hiddenlayer.git
+
+ # from torchviz import make_dot
+ # # not viable.
+ # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,
+ # *self.configuration_manager.patch_size),
+ # device=self.device)))).render(
+ # join(self.output_folder, "network_architecture.pdf"), format='pdf')
+ # self.optimizer.zero_grad()
+
+ # broken.
+
+ import hiddenlayer as hl
+ g = hl.build_graph(self.network,
+ torch.rand((1, self.num_input_channels,
+ *self.configuration_manager.patch_size),
+ device=self.device),
+ transforms=None)
+ g.save(join(self.output_folder, "network_architecture.pdf"))
+ del g
+ except Exception as e:
+ self.print_to_log_file("Unable to plot network architecture:")
+ self.print_to_log_file(e)
+
+ # self.print_to_log_file("\nprinting the network instead:\n")
+ # self.print_to_log_file(self.network)
+ # self.print_to_log_file("\n")
+ finally:
+ empty_cache(self.device)
+
+ def do_split(self):
+ """
+ The default split is a 5-fold CV on all available training cases. nnU-Net will create a split (it is seeded,
+ so always the same) and save it as a splits_final.json file in the preprocessed data directory.
+ Sometimes you may want to create your own split for various reasons. For this you will need to create your own
+ splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
+ it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
+ and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
+ use a random 80:20 data split.
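+ As an illustration (hypothetical case identifiers), a splits_final.json with two folds could look like:
+ [
+ {"train": ["case_000", "case_001", "case_002"], "val": ["case_003"]},
+ {"train": ["case_000", "case_002", "case_003"], "val": ["case_001"]}
+ ]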
+ :return:
+ """
+ if self.fold == "all":
+ # if fold==all then we use all images for training and validation
+ case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder)
+ tr_keys = case_identifiers
+ val_keys = tr_keys
+ else:
+ splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
+ dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None,
+ num_images_properties_loading_threshold=0,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
+ # if the split file does not exist we need to create it
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new 5-fold cross-validation split...")
+ splits = []
+ all_keys_sorted = np.sort(list(dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append({})
+ splits[-1]['train'] = list(train_keys)
+ splits[-1]['val'] = list(test_keys)
+ save_json(splits, splits_file)
+
+ else:
+ self.print_to_log_file("Using splits from existing split file:", splits_file)
+ splits = load_json(splits_file)
+ self.print_to_log_file("The split file contains %d splits." % len(splits))
+
+ self.print_to_log_file("Desired fold for training: %d" % self.fold)
+ if self.fold < len(splits):
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+ self.print_to_log_file("This split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ else:
+ self.print_to_log_file("INFO: You requested fold %d for training but splits "
+ "contain only %d folds. I am now creating a "
+ "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
+ # if we request a fold that is not in the split file, create a random 80:20 split
+ rnd = np.random.RandomState(seed=12345 + self.fold)
+ keys = np.sort(list(dataset.keys()))
+ idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
+ idx_val = [i for i in range(len(keys)) if i not in idx_tr]
+ tr_keys = [keys[i] for i in idx_tr]
+ val_keys = [keys[i] for i in idx_val]
+ self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ if any([i in val_keys for i in tr_keys]):
+ self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '
+ 'splits.json or ignore if this is intentional.')
+ return tr_keys, val_keys
+
+ def get_tr_and_val_datasets(self):
+ # create dataset split
+ tr_keys, val_keys = self.do_split()
+
+ # load the datasets for training and validation. Note that we always draw random samples so we really don't
+ # care about distributing training cases across GPUs.
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ return dataset_tr, dataset_val
+
+ def get_dataloaders(self):
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=1,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
+ num_processes=allowed_num_processes, num_cached=6, seeds=None,
+ pin_memory=self.device.type == 'cuda', wait_time=0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
+ transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
+ num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
+ wait_time=0.02)
+ return mt_gen_train, mt_gen_val
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ dataset_tr, dataset_val = self.get_tr_and_val_datasets()
+
+ if dim == 2:
+ dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ else:
+ dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ return dl_tr, dl_val
+
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 3,
+ order_resampling_seg: int = 1,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ tr_transforms = []
+ if do_dummy_2d_data_aug:
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None,
+ do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
+ do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'],
+ p_rot_per_axis=1, # todo experiment with this
+ do_scale=True, scale=(0.7, 1.4),
+ border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg,
+ random_crop=False, # random cropping is part of our dataloaders
+ p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2,
+ independent_scale_for_each_axis=False # todo experiment with this
+ ))
+
+ if do_dummy_2d_data_aug:
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
+ tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1))
+ tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3))
+
+ if mirror_axes is not None and len(mirror_axes) > 0:
+ tr_transforms.append(MirrorTransform(mirror_axes))
+
+ if use_mask_for_norm is not None and any(use_mask_for_norm):
+ tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
+ mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
+ tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ p_per_sample=0.4,
+ key="data",
+ strel_size=(1, 8),
+ p_per_label=1))
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ key="data",
+ p_per_sample=0.2,
+ fill_with_other_class_p=0,
+ dont_do_if_covers_more_than_x_percent=0.15))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+ return tr_transforms
+
+ @staticmethod
+ def get_validation_transforms(deep_supervision_scales: Union[List, Tuple],
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+ return val_transforms
+
+ def set_deep_supervision_enabled(self, enabled: bool):
+ """
+ This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
+ chances you need to change this as well!
+ """
+ if self.is_ddp:
+ self.network.module.decoder.deep_supervision = enabled
+ else:
+ self.network.decoder.deep_supervision = enabled
+
+ def on_train_start(self):
+ if not self.was_initialized:
+ self.initialize()
+
+ maybe_mkdir_p(self.output_folder)
+
+ # make sure deep supervision is on in the network
+ self.set_deep_supervision_enabled(True)
+
+ self.print_plans()
+ empty_cache(self.device)
+
+ # maybe unpack
+ if self.unpack_dataset and self.local_rank == 0:
+ self.print_to_log_file('unpacking dataset...')
+ unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
+ num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
+ self.print_to_log_file('unpacking done...')
+
+ if self.is_ddp:
+ dist.barrier()
+
+ # dataloaders must be instantiated here because they need access to the training data which may not be present
+ # when doing inference
+ self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
+ # copy plans and dataset.json so that they can be used for restoring everything we need for inference
+ save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
+ save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
+
+ # we don't really need the fingerprint but its still handy to have it with the others
+ shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
+ join(self.output_folder_base, 'dataset_fingerprint.json'))
+
+ # produces a pdf in output folder
+ self.plot_network_architecture()
+
+ self._save_debug_information()
+
+ # print(f"batch size: {self.batch_size}")
+ # print(f"oversample: {self.oversample_foreground_percent}")
+
+ def on_train_end(self):
+ self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
+ # now we can delete latest
+ if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
+ os.remove(join(self.output_folder, "checkpoint_latest.pth"))
+
+ # shut down dataloaders
+ old_stdout = sys.stdout
+ with open(os.devnull, 'w') as f:
+ sys.stdout = f
+ if self.dataloader_train is not None:
+ self.dataloader_train._finish()
+ if self.dataloader_val is not None:
+ self.dataloader_val._finish()
+ sys.stdout = old_stdout
+
+ empty_cache(self.device)
+ self.print_to_log_file("Training done.")
+
+ def on_train_epoch_start(self):
+ self.network.train()
+ self.lr_scheduler.step(self.current_epoch)
+ self.print_to_log_file('')
+ self.print_to_log_file(f'Epoch {self.current_epoch}')
+ self.print_to_log_file(
+ f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")
+ # lrs are the same for all workers so we don't need to gather them in case of DDP training
+ self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
+
+ def train_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+ target_class = torch.nn.functional.one_hot(
+ torch.cat([torch.max(target[0][idx]).long().unsqueeze(0) for idx in range(np.shape(target[0])[0])], dim=0), num_classes=2)
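+ # The classification target is derived from the highest-resolution segmentation target:
+ # per sample, torch.max(...) is 1 if the patch contains any foreground voxel and 0
+ # otherwise, and one_hot turns that into a two-class target for the MIP classifier
+ # (this assumes a binary label set, as in AutoPET).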
+
+ self.optimizer.zero_grad()
+ # Autocast behaves very differently depending on the device type.
+ # On 'cpu' it is very slow and needs to be disabled.
+ # On 'mps' it complains that mps is not implemented even when enabled=False is set (which is why enabled=False is not used here).
+ # So autocast is only made active when running on a cuda device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ outputs, classif = self.network(data)
+ # del data
+ l = self.loss(outputs, target) + self.classif_loss(classif, target_class.float())
+
+ # Backward pass on the combined segmentation + classification loss
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(l).backward()
+ self.grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ return {'loss': l.detach().cpu().numpy()}
+
+ def on_train_epoch_end(self, train_outputs: List[dict]):
+ outputs = collate_outputs(train_outputs)
+
+ if self.is_ddp:
+ losses_tr = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(losses_tr, outputs['loss'])
+ loss_here = np.vstack(losses_tr).mean()
+ else:
+ loss_here = np.mean(outputs['loss'])
+
+ self.logger.log('train_losses', loss_here, self.current_epoch)
+
+ def on_validation_epoch_start(self):
+ self.network.eval()
+
+ def validation_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ target_class = torch.nn.functional.one_hot(
+ torch.cat([torch.max(target[0][idx]).long().unsqueeze(0) for idx in range(np.shape(target[0])[0])], dim=0), num_classes=2)
+ # Autocast behaves very differently depending on the device type.
+ # On 'cpu' it is very slow and needs to be disabled.
+ # On 'mps' it complains that mps is not implemented even when enabled=False is set (which is why enabled=False is not used here).
+ # So autocast is only made active when running on a cuda device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ # we only need the output with the highest output resolution
+ output = output[0]
+ target = target[0]
+
+ # the following is needed for online evaluation. Fake dice (green line)
+ axes = [0] + list(range(2, len(output.shape)))
+
+ if self.label_manager.has_regions:
+ predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
+ else:
+ # no need for softmax
+ output_seg = output.argmax(1)[:, None]
+ predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
+ predicted_segmentation_onehot.scatter_(1, output_seg, 1)
+ del output_seg
+
+ if self.label_manager.has_ignore_label:
+ if not self.label_manager.has_regions:
+ mask = (target != self.label_manager.ignore_label).float()
+ # CAREFUL that you don't rely on target after this line!
+ target[target == self.label_manager.ignore_label] = 0
+ else:
+ mask = 1 - target[:, -1:]
+ # CAREFUL that you don't rely on target after this line!
+ target = target[:, :-1]
+ else:
+ mask = None
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+ if not self.label_manager.has_regions:
+ # if we train with regions all segmentation heads predict some kind of foreground. In conventional
+ # (softmax training) there needs to be one output for the background. We are not interested in the
+ # background Dice
+ # [1:] in order to remove background
+ tp_hard = tp_hard[1:]
+ fp_hard = fp_hard[1:]
+ fn_hard = fn_hard[1:]
+
+ return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
+
+ def on_validation_epoch_end(self, val_outputs: List[dict]):
+ outputs_collated = collate_outputs(val_outputs)
+ tp = np.sum(outputs_collated['tp_hard'], 0)
+ fp = np.sum(outputs_collated['fp_hard'], 0)
+ fn = np.sum(outputs_collated['fn_hard'], 0)
+
+ if self.is_ddp:
+ world_size = dist.get_world_size()
+
+ tps = [None for _ in range(world_size)]
+ dist.all_gather_object(tps, tp)
+ tp = np.vstack([i[None] for i in tps]).sum(0)
+
+ fps = [None for _ in range(world_size)]
+ dist.all_gather_object(fps, fp)
+ fp = np.vstack([i[None] for i in fps]).sum(0)
+
+ fns = [None for _ in range(world_size)]
+ dist.all_gather_object(fns, fn)
+ fn = np.vstack([i[None] for i in fns]).sum(0)
+
+ losses_val = [None for _ in range(world_size)]
+ dist.all_gather_object(losses_val, outputs_collated['loss'])
+ loss_here = np.vstack(losses_val).mean()
+ else:
+ loss_here = np.mean(outputs_collated['loss'])
+
+ global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
+ zip(tp, fp, fn)]]
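+ # e.g. (illustrative numbers) tp=80, fp=10, fn=20 for one class gives a pseudo Dice of
+ # 2 * 80 / (2 * 80 + 10 + 20) ≈ 0.842 on the sampled validation patches.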
+ mean_fg_dice = np.nanmean(global_dc_per_class)
+ self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
+ self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
+ self.logger.log('val_losses', loss_here, self.current_epoch)
+
+ def on_epoch_start(self):
+ self.logger.log('epoch_start_timestamps', time(), self.current_epoch)
+
+ def on_epoch_end(self):
+ self.logger.log('epoch_end_timestamps', time(), self.current_epoch)
+
+ # todo find a cleaner way to access the most recent logger values than indexing my_fantastic_logging directly
+ self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4))
+ self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4))
+ self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
+ self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]])
+ self.print_to_log_file(
+ f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s")
+
+ # handling periodic checkpointing
+ current_epoch = self.current_epoch
+ if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))
+
+ # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this
+ if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema:
+ self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1]
+ self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))
+
+ if self.local_rank == 0:
+ self.logger.plot_progress_png(self.output_folder)
+
+ self.current_epoch += 1
+
+ def save_checkpoint(self, filename: str) -> None:
+ if self.local_rank == 0:
+ if not self.disable_checkpointing:
+ if self.is_ddp:
+ mod = self.network.module
+ else:
+ mod = self.network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ checkpoint = {
+ 'network_weights': mod.state_dict(),
+ 'optimizer_state': self.optimizer.state_dict(),
+ 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,
+ 'logging': self.logger.get_checkpoint(),
+ '_best_ema': self._best_ema,
+ 'current_epoch': self.current_epoch + 1,
+ 'init_args': self.my_init_kwargs,
+ 'trainer_name': self.__class__.__name__,
+ 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,
+ }
+ torch.save(checkpoint, filename)
+ else:
+ self.print_to_log_file('No checkpoint written, checkpointing is disabled')
+
+ def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
+ if not self.was_initialized:
+ self.initialize()
+
+ if isinstance(filename_or_checkpoint, str):
+ checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
+ new_state_dict = {}
+ for k, value in checkpoint['network_weights'].items():
+ key = k
+ if key not in self.network.state_dict().keys() and key.startswith('module.'):
+ key = key[7:]
+ new_state_dict[key] = value
+
+ self.my_init_kwargs = checkpoint['init_args']
+ self.current_epoch = checkpoint['current_epoch']
+ self.logger.load_checkpoint(checkpoint['logging'])
+ self._best_ema = checkpoint['_best_ema']
+ self.inference_allowed_mirroring_axes = checkpoint[
+ 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes
+
+ # messing with state dict naming schemes. Facepalm.
+ if self.is_ddp:
+ if isinstance(self.network.module, OptimizedModule):
+ self.network.module._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.module.load_state_dict(new_state_dict)
+ else:
+ if isinstance(self.network, OptimizedModule):
+ self.network._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.load_state_dict(new_state_dict)
+ self.optimizer.load_state_dict(checkpoint['optimizer_state'])
+ if self.grad_scaler is not None:
+ if checkpoint['grad_scaler_state'] is not None:
+ self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])
+
+ def perform_actual_validation(self, save_probabilities: bool = False):
+ self.set_deep_supervision_enabled(False)
+ self.network.eval()
+
+ predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
+ perform_everything_on_gpu=True, device=self.device, verbose=False,
+ verbose_preprocessing=False, allow_tqdm=False)
+ predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
+ self.dataset_json, self.__class__.__name__,
+ self.inference_allowed_mirroring_axes)
+
+ with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool:
+ worker_list = [i for i in segmentation_export_pool._pool]
+ validation_output_folder = join(self.output_folder, 'validation')
+ maybe_mkdir_p(validation_output_folder)
+
+ # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute
+ # the validation keys across the workers.
+ _, val_keys = self.do_split()
+ if self.is_ddp:
+ val_keys = val_keys[self.local_rank:: dist.get_world_size()]
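+ # e.g. with world_size=2 and val_keys=[k0, k1, k2, k3, k4], rank 0 predicts
+ # [k0, k2, k4] and rank 1 predicts [k1, k3] (strided split of the validation keys).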
+
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+
+ next_stages = self.configuration_manager.next_stage_names
+
+ if next_stages is not None:
+ _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]
+
+ results = []
+ for k in dataset_val.keys():
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+ while not proceed:
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+
+ self.print_to_log_file(f"predicting {k}")
+ data, seg, properties = dataset_val.load_case(k)
+
+ if self.is_cascaded:
+ data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels,
+ output_dtype=data.dtype)))
+ with warnings.catch_warnings():
+ # ignore 'The given NumPy array is not writable' warning
+ warnings.simplefilter("ignore")
+ data = torch.from_numpy(data)
+
+ output_filename_truncated = join(validation_output_folder, k)
+
+ try:
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ except RuntimeError:
+ predictor.perform_everything_on_gpu = False
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ predictor.perform_everything_on_gpu = True
+
+ prediction = prediction.cpu()
+
+ # this needs to go into background processes
+ results.append(
+ segmentation_export_pool.starmap_async(
+ export_prediction_from_logits, (
+ (prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, output_filename_truncated, save_probabilities),
+ )
+ )
+ )
+ # for debug purposes
+ # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json,
+ # output_filename_truncated, save_probabilities)
+
+ # if needed, export the softmax prediction for the next stage
+ if next_stages is not None:
+ for n in next_stages:
+ next_stage_config_manager = self.plans_manager.get_configuration(n)
+ expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,
+ next_stage_config_manager.data_identifier)
+
+ try:
+ # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented
+ tmp = nnUNetDataset(expected_preprocessed_folder, [k],
+ num_images_properties_loading_threshold=0)
+ d, s, p = tmp.load_case(k)
+ except FileNotFoundError:
+ self.print_to_log_file(
+ f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! "
+ f"Run the preprocessing for this configuration first!")
+ continue
+
+ target_shape = d.shape[1:]
+ output_folder = join(self.output_folder_base, 'predicted_next_stage', n)
+ output_file = join(output_folder, k + '.npz')
+
+ # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties,
+ # self.dataset_json)
+ results.append(segmentation_export_pool.starmap_async(
+ resample_and_save, (
+ (prediction, target_shape, output_file, self.plans_manager,
+ self.configuration_manager,
+ properties,
+ self.dataset_json),
+ )
+ ))
+
+ _ = [r.get() for r in results]
+
+ if self.is_ddp:
+ dist.barrier()
+
+ if self.local_rank == 0:
+ metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),
+ validation_output_folder,
+ join(validation_output_folder, 'summary.json'),
+ self.plans_manager.image_reader_writer_class(),
+ self.dataset_json["file_ending"],
+ self.label_manager.foreground_regions if self.label_manager.has_regions else
+ self.label_manager.foreground_labels,
+ self.label_manager.ignore_label, chill=True)
+ self.print_to_log_file("Validation complete", also_print_to_console=True)
+ self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True)
+
+ self.set_deep_supervision_enabled(True)
+ compute_gaussian.cache_clear()
+
+ def run_training(self):
+ self.on_train_start()
+
+ for epoch in range(self.current_epoch, self.num_epochs):
+ self.on_epoch_start()
+
+ self.on_train_epoch_start()
+ train_outputs = []
+ for batch_id in range(self.num_iterations_per_epoch):
+ train_outputs.append(self.train_step(next(self.dataloader_train)))
+ self.on_train_epoch_end(train_outputs)
+
+ with torch.no_grad():
+ self.on_validation_epoch_start()
+ val_outputs = []
+ for batch_id in range(self.num_val_iterations_per_epoch):
+ val_outputs.append(self.validation_step(next(self.dataloader_val)))
+ self.on_validation_epoch_end(val_outputs)
+
+ self.on_epoch_end()
+
+ self.on_train_end()
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py
new file mode 100644
index 0000000..6d95b26
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py
@@ -0,0 +1,1242 @@
+import inspect
+import multiprocessing
+import os
+import shutil
+import sys
+import warnings
+from copy import deepcopy
+from datetime import datetime
+from time import time, sleep
+from typing import Union, Tuple, List
+
+import numpy as np
+import torch
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p
+from torch._dynamo import OptimizedModule
+
+from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
+from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder
+from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+from nnunetv2.inference.sliding_window_prediction import compute_gaussian
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results
+from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
+from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
+ DownsampleSegForDSTransform2
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
+from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
+ ConvertSegmentationToRegionsTransform
+from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \
+ Convert3DTo2DTransform
+from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
+from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
+from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
+from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset
+from nnunetv2.training.logging.nnunet_logger import nnUNetLogger
+from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.utilities.collate_outputs import collate_outputs
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
+from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
+from nnunetv2.utilities.helpers import empty_cache, dummy_context
+from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+from sklearn.model_selection import KFold
+from torch import autocast, nn
+from torch import distributed as dist
+from torch.cuda import device_count
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class nnUNetTrainer_segrap(object):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ # From https://grugbrain.dev/. Worth a read ya big brains ;-)
+
+ # apex predator of grug is complexity
+ # complexity bad
+ # say again:
+ # complexity very bad
+ # you say now:
+ # complexity very, very bad
+ # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex
+ # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime
+ # one day code base understandable and grug can get work done, everything good!
+ # next day impossible: complexity demon spirit has entered code and very dangerous situation!
+
+ # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz
+
+ self.is_ddp = dist.is_available() and dist.is_initialized()
+ self.local_rank = 0 if not self.is_ddp else dist.get_rank()
+
+ self.device = device
+
+ # print what device we are using
+ if self.is_ddp: # implicitly it's clear that we use cuda in this case
+ print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is "
+                  f"{dist.get_world_size()}. "
+ f"Setting device to {self.device}")
+ self.device = torch.device(type='cuda', index=self.local_rank)
+ else:
+ if self.device.type == 'cuda':
+ # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X
+ self.device = torch.device(type='cuda', index=0)
+ print(f"Using device: {self.device}")
+
+ # loading and saving this class for continuing from checkpoint should not happen based on pickling. This
+ # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we
+ # need. So let's save the init args
+ self.my_init_kwargs = {}
+ for k in inspect.signature(self.__init__).parameters.keys():
+ self.my_init_kwargs[k] = locals()[k]
+
+ ### Saving all the init args into class variables for later access
+ self.plans_manager = PlansManager(plans)
+ self.configuration_manager = self.plans_manager.get_configuration(configuration)
+ self.configuration_name = configuration
+ self.dataset_json = dataset_json
+ self.fold = fold
+ self.unpack_dataset = unpack_dataset
+
+ ### Setting all the folder names. We need to make sure things don't crash in case we are just running
+ # inference and some of the folders may not be defined!
+ self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \
+ if nnUNet_preprocessed is not None else None
+ self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \
+ if nnUNet_results is not None else None
+ self.output_folder = join(self.output_folder_base, f'fold_{fold}')
+
+ self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base,
+ self.configuration_manager.data_identifier)
+        # unlike the previous nnU-Net, folder_with_segs_from_previous_stage is now part of the plans. For now it has to
+ # be a different configuration in the same plans
+ # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using
+ # "previous_stage" and "next_stage"). Otherwise it won't work!
+ self.is_cascaded = self.configuration_manager.previous_stage_name is not None
+ self.folder_with_segs_from_previous_stage = \
+ join(nnUNet_results, self.plans_manager.dataset_name,
+ self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" +
+ self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \
+ if self.is_cascaded else None
+
+ ### Some hyperparameters for you to fiddle with
+ self.initial_lr = 1e-2
+ self.weight_decay = 3e-5
+ self.oversample_foreground_percent = 0.33
+ self.num_iterations_per_epoch = 250
+ self.num_val_iterations_per_epoch = 50
+ self.num_epochs = 1000
+ self.current_epoch = 0
+
+ ### Dealing with labels/regions
+ self.label_manager = self.plans_manager.get_label_manager(dataset_json)
+ # labels can either be a list of int (regular training) or a list of tuples of int (region-based training)
+ # needed for predictions. We do sigmoid in case of (overlapping) regions
+
+ self.num_input_channels = None # -> self.initialize()
+ self.network = None # -> self._get_network()
+ self.optimizer = self.lr_scheduler = None # -> self.initialize
+ self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None
+ self.loss = None # -> self.initialize
+
+ ### Simple logging. Don't take that away from me!
+ # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning
+ # logging
+ timestamp = datetime.now()
+ maybe_mkdir_p(self.output_folder)
+ self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
+ (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
+ timestamp.second))
+ self.logger = nnUNetLogger()
+
+ ### placeholders
+ self.dataloader_train = self.dataloader_val = None # see on_train_start
+
+ ### initializing stuff for remembering things and such
+ self._best_ema = None
+
+ ### inference things
+ self.inference_allowed_mirroring_axes = None # this variable is set in
+ # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints
+
+ ### checkpoint saving stuff
+ self.save_every = 50
+ self.disable_checkpointing = False
+
+        ## DDP batch size and oversampling can differ between workers and need adaptation
+ # we need to change the batch size in DDP because we don't use any of those distributed samplers
+ self._set_batch_size_and_oversample()
+
+ self.was_initialized = False
+
+ self.print_to_log_file("\n#######################################################################\n"
+ "Please cite the following paper when using nnU-Net:\n"
+ "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). "
+ "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. "
+ "Nature methods, 18(2), 203-211.\n"
+ "#######################################################################\n",
+ also_print_to_console=True, add_timestamp=False)
+
+ def initialize(self):
+ if not self.was_initialized:
+ self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+
+ self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
+ self.configuration_manager,
+ self.num_input_channels,
+ enable_deep_supervision=True).to(self.device)
+            self.reconstruction_decoder = deepcopy(self.network.decoder)
+ # compile network for free speedup
+ if ('nnUNet_compile' in os.environ.keys()) and (
+ os.environ['nnUNet_compile'].lower() in ('true', '1', 't')):
+ self.print_to_log_file('Compiling network...')
+ self.network = torch.compile(self.network)
+
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+ # if ddp, wrap in DDP wrapper
+ if self.is_ddp:
+ self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ self.loss = self._build_loss()
+ self.was_initialized = True
+ else:
+ raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
+ "That should not happen.")
+
+ def _save_debug_information(self):
+ # saving some debug information
+ if self.local_rank == 0:
+ dct = {}
+ for k in self.__dir__():
+ if not k.startswith("__"):
+ if not callable(getattr(self, k)) or k in ['loss', ]:
+ dct[k] = str(getattr(self, k))
+ elif k in ['network', ]:
+ dct[k] = str(getattr(self, k).__class__.__name__)
+ else:
+ # print(k)
+ pass
+ if k in ['dataloader_train', 'dataloader_val']:
+ if hasattr(getattr(self, k), 'generator'):
+ dct[k + '.generator'] = str(getattr(self, k).generator)
+ if hasattr(getattr(self, k), 'num_processes'):
+ dct[k + '.num_processes'] = str(getattr(self, k).num_processes)
+ if hasattr(getattr(self, k), 'transform'):
+ dct[k + '.transform'] = str(getattr(self, k).transform)
+ import subprocess
+ hostname = subprocess.getoutput(['hostname'])
+ dct['hostname'] = hostname
+ torch_version = torch.__version__
+ if self.device.type == 'cuda':
+ gpu_name = torch.cuda.get_device_name()
+ dct['gpu_name'] = gpu_name
+ cudnn_version = torch.backends.cudnn.version()
+ else:
+ cudnn_version = 'None'
+ dct['device'] = str(self.device)
+ dct['torch_version'] = torch_version
+ dct['cudnn_version'] = cudnn_version
+ save_json(dct, join(self.output_folder, "debug.json"))
+
+ @staticmethod
+ def build_network_architecture(plans_manager: PlansManager,
+ dataset_json,
+ configuration_manager: ConfigurationManager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ """
+        This is where you build the architecture according to the plans. There is no obligation to use
+ get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what
+ you want. Even ignore the plans and just return something static (as long as it can process the requested
+ patch size)
+ but don't bug us with your bugs arising from fiddling with this :-P
+ This is the function that is called in inference as well! This is needed so that all network architecture
+ variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for
+ training, so if you change the network architecture during training by deriving a new trainer class then
+ inference will know about it).
+
+ If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet:
+ > label_manager = plans_manager.get_label_manager(dataset_json)
+ > label_manager.num_segmentation_heads
+ (why so complicated? -> We can have either classical training (classes) or regions. If we have regions,
+ the number of outputs is != the number of classes. Also there is the ignore label for which no output
+ should be generated. label_manager takes care of all that for you.)
+
+ """
+ return get_network_from_plans(plans_manager, dataset_json, configuration_manager,
+ num_input_channels, deep_supervision=enable_deep_supervision)
+
+ def _get_deep_supervision_scales(self):
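+        # Illustrative example (hypothetical pooling config, not taken from any specific plans file):
+        # pool_op_kernel_sizes = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
+        # -> cumprod rows [1, 1, 1], [2, 2, 2], [4, 4, 4], [8, 8, 8]
+        # -> reciprocals [1, 1, 1], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]
+        # -> [:-1] drops the coarsest level, leaving the scales used to downsample the DS targets.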
+ deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack(
+ self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1]
+ return deep_supervision_scales
+
+ def _set_batch_size_and_oversample(self):
+ if not self.is_ddp:
+ # set batch size to what the plan says, leave oversample untouched
+ self.batch_size = self.configuration_manager.batch_size
+ else:
+ # batch size is distributed over DDP workers and we need to change oversample_percent for each worker
+ batch_sizes = []
+ oversample_percents = []
+
+ world_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+
+ global_batch_size = self.configuration_manager.batch_size
+ assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \
+ 'GPUs... Duh.'
+
+ batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int)
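+            # Worked example (illustrative numbers): global_batch_size = 5 and world_size = 2 give
+            # batch_size_per_GPU = ceil(5 / 2) = 3, so rank 0 receives 3 samples and rank 1 receives
+            # 3 - (2 * 3 - 5) = 2. The per-rank oversample fractions computed below are then chosen so
+            # that, summed over all ranks, roughly oversample_foreground_percent of the global batch is
+            # foreground-oversampled (here: rank 0 -> 0.0, rank 1 -> 1 - (0.67 - 0.6) / 0.4 = 0.825).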
+
+ for rank in range(world_size):
+ if (rank + 1) * batch_size_per_GPU > global_batch_size:
+ batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size)
+ else:
+ batch_size = batch_size_per_GPU
+
+ batch_sizes.append(batch_size)
+
+ sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
+ sample_id_high = np.sum(batch_sizes)
+
+ if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent):
+ oversample_percents.append(0.0)
+ elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent):
+ oversample_percents.append(1.0)
+ else:
+ percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size
+ oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
+ sample_id_low / global_batch_size) / percent_covered_by_this_rank)
+ oversample_percents.append(oversample_percent_here)
+
+ print("worker", my_rank, "oversample", oversample_percents[my_rank])
+ print("worker", my_rank, "batch_size", batch_sizes[my_rank])
+ # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank])
+ # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank])
+
+ self.batch_size = batch_sizes[my_rank]
+ self.oversample_foreground_percent = oversample_percents[my_rank]
+
+ def _build_loss(self):
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+ weights[-1] = 0
+
+ # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
+ weights = weights / weights.sum()
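+        # Worked example (illustrative): with 5 deep supervision outputs the raw weights are
+        # [1, 0.5, 0.25, 0.125, 0.0625]; zeroing the last and normalizing gives approximately
+        # [0.533, 0.267, 0.133, 0.067, 0], i.e. the full-resolution output dominates the loss.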
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ """
+ This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it.
+ """
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
+ if dim == 2:
+ do_dummy_2d_data_aug = False
+ # todo revisit this parametrization
+ if max(patch_size) / min(patch_size) > 1.5:
+ rotation_for_DA = {
+ 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ mirror_axes = (0, 1)
+ elif dim == 3:
+ # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
+ # order of the axes is determined by spacing, not image size
+ do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
+ if do_dummy_2d_data_aug:
+ # why do we rotate 180 deg here all the time? We should also restrict it
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ }
+ mirror_axes = (0, 1, 2)
+ else:
+ raise RuntimeError()
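+        # Illustrative example (hypothetical patch sizes): an isotropic patch such as (128, 128, 128)
+        # has max/patch_size[0] == 1, which is below ANISO_THRESHOLD, so no dummy 2D augmentation is
+        # used and rotations of +-30 degrees are applied about all axes. A strongly anisotropic patch
+        # such as (28, 256, 256) exceeds the threshold, so only rotations about the first axis
+        # (i.e. in-plane rotations) are used.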
+
+ # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the
+ # old nnunet for now)
+ initial_patch_size = get_patch_size(patch_size[-dim:],
+ *rotation_for_DA.values(),
+ (0.85, 1.25))
+ if do_dummy_2d_data_aug:
+ initial_patch_size[0] = patch_size[0]
+
+ self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
+ self.inference_allowed_mirroring_axes = mirror_axes
+
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+ def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
+ if self.local_rank == 0:
+ timestamp = time()
+ dt_object = datetime.fromtimestamp(timestamp)
+
+ if add_timestamp:
+ args = ("%s:" % dt_object, *args)
+
+ successful = False
+ max_attempts = 5
+ ctr = 0
+ while not successful and ctr < max_attempts:
+ try:
+ with open(self.log_file, 'a+') as f:
+ for a in args:
+ f.write(str(a))
+ f.write(" ")
+ f.write("\n")
+ successful = True
+ except IOError:
+ print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
+ sleep(0.5)
+ ctr += 1
+ if also_print_to_console:
+ print(*args)
+ elif also_print_to_console:
+ print(*args)
+
+ def print_plans(self):
+ if self.local_rank == 0:
+ dct = deepcopy(self.plans_manager.plans)
+ del dct['configurations']
+ self.print_to_log_file(f"\nThis is the configuration used by this "
+ f"training:\nConfiguration name: {self.configuration_name}\n",
+ self.configuration_manager, '\n', add_timestamp=False)
+ self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False)
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=3.)
+ return optimizer, lr_scheduler
+
+ def plot_network_architecture(self):
+ if self.local_rank == 0:
+ try:
+ # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(')
+ # pip install git+https://github.com/saugatkandel/hiddenlayer.git
+
+ # from torchviz import make_dot
+ # # not viable.
+ # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels,
+ # *self.configuration_manager.patch_size),
+ # device=self.device)))).render(
+ # join(self.output_folder, "network_architecture.pdf"), format='pdf')
+ # self.optimizer.zero_grad()
+
+ # broken.
+
+ import hiddenlayer as hl
+ g = hl.build_graph(self.network,
+ torch.rand((1, self.num_input_channels,
+ *self.configuration_manager.patch_size),
+ device=self.device),
+ transforms=None)
+ g.save(join(self.output_folder, "network_architecture.pdf"))
+ del g
+ except Exception as e:
+ self.print_to_log_file("Unable to plot network architecture:")
+ self.print_to_log_file(e)
+
+ # self.print_to_log_file("\nprinting the network instead:\n")
+ # self.print_to_log_file(self.network)
+ # self.print_to_log_file("\n")
+ finally:
+ empty_cache(self.device)
+
+ def do_split(self):
+ """
+ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
+        so always the same) and save it as a splits_final.json file in the preprocessed data directory.
+        Sometimes you may want to create your own split for various reasons. For this you will need to create your own
+        splits_final.json file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
+ it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
+ and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
+ use a random 80:20 data split.
+ :return:
+ """
+ if self.fold == "all":
+ # if fold==all then we use all images for training and validation
+ case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder)
+ tr_keys = case_identifiers
+ val_keys = tr_keys
+ else:
+ splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json")
+ dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None,
+ num_images_properties_loading_threshold=0,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage)
+ # if the split file does not exist we need to create it
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new 5-fold cross-validation split...")
+ splits = []
+ all_keys_sorted = np.sort(list(dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append({})
+ splits[-1]['train'] = list(train_keys)
+ splits[-1]['val'] = list(test_keys)
+ save_json(splits, splits_file)
+
+ else:
+ self.print_to_log_file("Using splits from existing split file:", splits_file)
+ splits = load_json(splits_file)
+ self.print_to_log_file("The split file contains %d splits." % len(splits))
+
+ self.print_to_log_file("Desired fold for training: %d" % self.fold)
+ if self.fold < len(splits):
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+ self.print_to_log_file("This split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ else:
+ self.print_to_log_file("INFO: You requested fold %d for training but splits "
+ "contain only %d folds. I am now creating a "
+ "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
+ # if we request a fold that is not in the split file, create a random 80:20 split
+ rnd = np.random.RandomState(seed=12345 + self.fold)
+ keys = np.sort(list(dataset.keys()))
+ idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
+ idx_val = [i for i in range(len(keys)) if i not in idx_tr]
+ tr_keys = [keys[i] for i in idx_tr]
+ val_keys = [keys[i] for i in idx_val]
+ self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ if any([i in val_keys for i in tr_keys]):
+ self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the '
+ 'splits.json or ignore if this is intentional.')
+ return tr_keys, val_keys
+
+ def get_tr_and_val_datasets(self):
+ # create dataset split
+ tr_keys, val_keys = self.do_split()
+
+ # load the datasets for training and validation. Note that we always draw random samples so we really don't
+ # care about distributing training cases across GPUs.
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+ return dataset_tr, dataset_val
+
+ def get_dataloaders(self):
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=1,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.foreground_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
+ num_processes=allowed_num_processes, num_cached=6, seeds=None,
+ pin_memory=self.device.type == 'cuda', wait_time=0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
+ transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
+ num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
+ wait_time=0.02)
+ return mt_gen_train, mt_gen_val
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ dataset_tr, dataset_val = self.get_tr_and_val_datasets()
+
+ if dim == 2:
+ dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ else:
+ dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None)
+ return dl_tr, dl_val
+
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 3,
+ order_resampling_seg: int = 1,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ tr_transforms = []
+ if do_dummy_2d_data_aug:
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
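+            # Dummy 2D augmentation: Convert3DTo2DTransform folds the first spatial axis into the
+            # channel dimension so that the SpatialTransform below operates purely in-plane;
+            # Convert2DTo3DTransform further down restores the original 3D layout.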
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None,
+ do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0),
+ do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'],
+ p_rot_per_axis=1, # todo experiment with this
+ do_scale=True, scale=(0.7, 1.4),
+ border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg,
+ random_crop=False, # random cropping is part of our dataloaders
+ p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2,
+ independent_scale_for_each_axis=False # todo experiment with this
+ ))
+
+ if do_dummy_2d_data_aug:
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
+ tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1))
+ tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3))
+
+ if mirror_axes is not None and len(mirror_axes) > 0:
+ tr_transforms.append(MirrorTransform(mirror_axes))
+
+ if use_mask_for_norm is not None and any(use_mask_for_norm):
+ tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
+ mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations'
+ tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ p_per_sample=0.4,
+ key="data",
+ strel_size=(1, 8),
+ p_per_label=1))
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(foreground_labels), 0)),
+ key="data",
+ p_per_sample=0.2,
+ fill_with_other_class_p=0,
+ dont_do_if_covers_more_than_x_percent=0.15))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+ return tr_transforms
+
+ @staticmethod
+ def get_validation_transforms(deep_supervision_scales: Union[List, Tuple],
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+ return val_transforms
+
+ def set_deep_supervision_enabled(self, enabled: bool):
+ """
+ This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
+ chances you need to change this as well!
+ """
+ if self.is_ddp:
+ self.network.module.decoder.deep_supervision = enabled
+ else:
+ self.network.decoder.deep_supervision = enabled
+
+ def on_train_start(self):
+ if not self.was_initialized:
+ self.initialize()
+
+ maybe_mkdir_p(self.output_folder)
+
+ # make sure deep supervision is on in the network
+ self.set_deep_supervision_enabled(True)
+
+ self.print_plans()
+ empty_cache(self.device)
+
+ # maybe unpack
+ if self.unpack_dataset and self.local_rank == 0:
+ self.print_to_log_file('unpacking dataset...')
+ unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
+ num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
+ self.print_to_log_file('unpacking done...')
+
+ if self.is_ddp:
+ dist.barrier()
+
+ # dataloaders must be instantiated here because they need access to the training data which may not be present
+ # when doing inference
+ self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
+ # copy plans and dataset.json so that they can be used for restoring everything we need for inference
+ save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
+ save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
+
+        # we don't really need the fingerprint but it's still handy to have it with the others
+ shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
+ join(self.output_folder_base, 'dataset_fingerprint.json'))
+
+ # produces a pdf in output folder
+ self.plot_network_architecture()
+
+ self._save_debug_information()
+
+ # print(f"batch size: {self.batch_size}")
+ # print(f"oversample: {self.oversample_foreground_percent}")
+
+ def on_train_end(self):
+ self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
+ # now we can delete latest
+ if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
+ os.remove(join(self.output_folder, "checkpoint_latest.pth"))
+
+ # shut down dataloaders
+ old_stdout = sys.stdout
+ with open(os.devnull, 'w') as f:
+ sys.stdout = f
+ if self.dataloader_train is not None:
+ self.dataloader_train._finish()
+ if self.dataloader_val is not None:
+ self.dataloader_val._finish()
+ sys.stdout = old_stdout
+
+ empty_cache(self.device)
+ self.print_to_log_file("Training done.")
+
+ def on_train_epoch_start(self):
+ self.network.train()
+ self.lr_scheduler.step(self.current_epoch)
+ self.print_to_log_file('')
+ self.print_to_log_file(f'Epoch {self.current_epoch}')
+ self.print_to_log_file(
+ f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}")
+ # lrs are the same for all workers so we don't need to gather them in case of DDP training
+ self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch)
+
+ def train_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ self.optimizer.zero_grad()
+ # Autocast is a little bitch.
+ # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
+ # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
+ # So autocast will only be active if we have a cuda device.
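+        # In practice this means: on CUDA the forward pass and loss run under mixed precision,
+        # while on CPU/MPS dummy_context() is a no-op and everything stays in full precision.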
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ # del data
+ l = self.loss(output, target)
+
+ if self.grad_scaler is not None:
+ self.grad_scaler.scale(l).backward()
+ self.grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+ return {'loss': l.detach().cpu().numpy()}
+
+ def on_train_epoch_end(self, train_outputs: List[dict]):
+ outputs = collate_outputs(train_outputs)
+
+ if self.is_ddp:
+ losses_tr = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(losses_tr, outputs['loss'])
+ loss_here = np.vstack(losses_tr).mean()
+ else:
+ loss_here = np.mean(outputs['loss'])
+
+ self.logger.log('train_losses', loss_here, self.current_epoch)
+
+ def on_validation_epoch_start(self):
+ self.network.eval()
+
+ def validation_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ # Autocast is a little bitch.
+ # If the device_type is 'cpu' then it's slow as heck and needs to be disabled.
+ # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
+ # So autocast will only be active if we have a cuda device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ # we only need the output with the highest output resolution
+ output = output[0]
+ target = target[0]
+
+ # the following is needed for online evaluation. Fake dice (green line)
+ axes = [0] + list(range(2, len(output.shape)))
+
+ if self.label_manager.has_regions:
+ predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
+ else:
+ # no need for softmax
+ output_seg = output.argmax(1)[:, None]
+ predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
+ predicted_segmentation_onehot.scatter_(1, output_seg, 1)
+ del output_seg
+
+ if self.label_manager.has_ignore_label:
+ if not self.label_manager.has_regions:
+ mask = (target != self.label_manager.ignore_label).float()
+ # CAREFUL that you don't rely on target after this line!
+ target[target == self.label_manager.ignore_label] = 0
+ else:
+ mask = 1 - target[:, -1:]
+ # CAREFUL that you don't rely on target after this line!
+ target = target[:, :-1]
+ else:
+ mask = None
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+ if not self.label_manager.has_regions:
+ # if we train with regions all segmentation heads predict some kind of foreground. In conventional
+            # (softmax) training there needs to be one output for the background. We are not interested in the
+ # background Dice
+ # [1:] in order to remove background
+ tp_hard = tp_hard[1:]
+ fp_hard = fp_hard[1:]
+ fn_hard = fn_hard[1:]
+
+ return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
+
+ def on_validation_epoch_end(self, val_outputs: List[dict]):
+ outputs_collated = collate_outputs(val_outputs)
+ tp = np.sum(outputs_collated['tp_hard'], 0)
+ fp = np.sum(outputs_collated['fp_hard'], 0)
+ fn = np.sum(outputs_collated['fn_hard'], 0)
+
+ if self.is_ddp:
+ world_size = dist.get_world_size()
+
+ tps = [None for _ in range(world_size)]
+ dist.all_gather_object(tps, tp)
+ tp = np.vstack([i[None] for i in tps]).sum(0)
+
+ fps = [None for _ in range(world_size)]
+ dist.all_gather_object(fps, fp)
+ fp = np.vstack([i[None] for i in fps]).sum(0)
+
+ fns = [None for _ in range(world_size)]
+ dist.all_gather_object(fns, fn)
+ fn = np.vstack([i[None] for i in fns]).sum(0)
+
+ losses_val = [None for _ in range(world_size)]
+ dist.all_gather_object(losses_val, outputs_collated['loss'])
+ loss_here = np.vstack(losses_val).mean()
+ else:
+ loss_here = np.mean(outputs_collated['loss'])
+
+ global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
+ zip(tp, fp, fn)]]
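+        # Worked example (illustrative counts): tp=80, fp=10, fn=20 -> 2*80 / (2*80 + 10 + 20) = 160/190 ~ 0.842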
+ mean_fg_dice = np.nanmean(global_dc_per_class)
+ self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch)
+ self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch)
+ self.logger.log('val_losses', loss_here, self.current_epoch)
+
+ def on_epoch_start(self):
+ self.logger.log('epoch_start_timestamps', time(), self.current_epoch)
+
+ def on_epoch_end(self):
+ self.logger.log('epoch_end_timestamps', time(), self.current_epoch)
+
+ # todo find a solution for this stupid shit
+ self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4))
+ self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4))
+ self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in
+ self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]])
+ self.print_to_log_file(
+ f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s")
+
+ # handling periodic checkpointing
+ current_epoch = self.current_epoch
+ if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1):
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth'))
+
+ # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this
+ if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema:
+ self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1]
+ self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}")
+ self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth'))
+
+ if self.local_rank == 0:
+ self.logger.plot_progress_png(self.output_folder)
+
+ self.current_epoch += 1
+
+ def save_checkpoint(self, filename: str) -> None:
+ if self.local_rank == 0:
+ if not self.disable_checkpointing:
+ if self.is_ddp:
+ mod = self.network.module
+ else:
+ mod = self.network
+ if isinstance(mod, OptimizedModule):
+ mod = mod._orig_mod
+
+ checkpoint = {
+ 'network_weights': mod.state_dict(),
+ 'optimizer_state': self.optimizer.state_dict(),
+ 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None,
+ 'logging': self.logger.get_checkpoint(),
+ '_best_ema': self._best_ema,
+ 'current_epoch': self.current_epoch + 1,
+ 'init_args': self.my_init_kwargs,
+ 'trainer_name': self.__class__.__name__,
+ 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes,
+ }
+ torch.save(checkpoint, filename)
+ else:
+ self.print_to_log_file('No checkpoint written, checkpointing is disabled')
+
+ def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None:
+ if not self.was_initialized:
+ self.initialize()
+
+ if isinstance(filename_or_checkpoint, str):
+ checkpoint = torch.load(filename_or_checkpoint, map_location=self.device)
+            # if the state dict comes from nn.DataParallel but we use a non-parallel model here, the state dict keys do not
+            # match. Use a heuristic to make them match
+ new_state_dict = {}
+ for k, value in checkpoint['network_weights'].items():
+ key = k
+ if key not in self.network.state_dict().keys() and key.startswith('module.'):
+ key = key[7:]
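+                    # e.g. (hypothetical key) 'module.encoder.stages.0.0.conv.weight' -> 'encoder.stages.0.0.conv.weight'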
+ new_state_dict[key] = value
+
+ self.my_init_kwargs = checkpoint['init_args']
+ self.current_epoch = checkpoint['current_epoch']
+ self.logger.load_checkpoint(checkpoint['logging'])
+ self._best_ema = checkpoint['_best_ema']
+ self.inference_allowed_mirroring_axes = checkpoint[
+ 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes
+
+ # messing with state dict naming schemes. Facepalm.
+ if self.is_ddp:
+ if isinstance(self.network.module, OptimizedModule):
+ self.network.module._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.module.load_state_dict(new_state_dict)
+ else:
+ if isinstance(self.network, OptimizedModule):
+ self.network._orig_mod.load_state_dict(new_state_dict)
+ else:
+ self.network.load_state_dict(new_state_dict)
+ self.optimizer.load_state_dict(checkpoint['optimizer_state'])
+ if self.grad_scaler is not None:
+ if checkpoint['grad_scaler_state'] is not None:
+ self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state'])
+
+ def perform_actual_validation(self, save_probabilities: bool = False):
+ self.set_deep_supervision_enabled(False)
+ self.network.eval()
+
+ predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
+ perform_everything_on_gpu=True, device=self.device, verbose=False,
+ verbose_preprocessing=False, allow_tqdm=False)
+ predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None,
+ self.dataset_json, self.__class__.__name__,
+ self.inference_allowed_mirroring_axes)
+
+ with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool:
+ worker_list = [i for i in segmentation_export_pool._pool]
+ validation_output_folder = join(self.output_folder, 'validation')
+ maybe_mkdir_p(validation_output_folder)
+
+ # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute
+ # the validation keys across the workers.
+ _, val_keys = self.do_split()
+ if self.is_ddp:
+ val_keys = val_keys[self.local_rank:: dist.get_world_size()]
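+                # e.g. with 10 validation cases and world_size 2, rank 0 handles indices 0, 2, 4, 6, 8
+                # and rank 1 handles indices 1, 3, 5, 7, 9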
+
+ dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
+ folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
+ num_images_properties_loading_threshold=0)
+
+ next_stages = self.configuration_manager.next_stage_names
+
+ if next_stages is not None:
+ _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages]
+
+ results = []
+
+ for k in dataset_val.keys():
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+ while not proceed:
+ sleep(0.1)
+ proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results,
+ allowed_num_queued=2)
+
+ self.print_to_log_file(f"predicting {k}")
+ data, seg, properties = dataset_val.load_case(k)
+
+ if self.is_cascaded:
+ data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels,
+ output_dtype=data.dtype)))
+ with warnings.catch_warnings():
+ # ignore 'The given NumPy array is not writable' warning
+ warnings.simplefilter("ignore")
+ data = torch.from_numpy(data)
+
+ output_filename_truncated = join(validation_output_folder, k)
+
+ try:
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ except RuntimeError:
+ predictor.perform_everything_on_gpu = False
+ prediction = predictor.predict_sliding_window_return_logits(data)
+ predictor.perform_everything_on_gpu = True
+
+ prediction = prediction.cpu()
+
+ # this needs to go into background processes
+ results.append(
+ segmentation_export_pool.starmap_async(
+ export_prediction_from_logits, (
+ (prediction, properties, self.configuration_manager, self.plans_manager,
+ self.dataset_json, output_filename_truncated, save_probabilities),
+ )
+ )
+ )
+ # for debug purposes
+ # export_prediction(prediction_for_export, properties, self.configuration, self.plans, self.dataset_json,
+ # output_filename_truncated, save_probabilities)
+
+ # if needed, export the softmax prediction for the next stage
+ if next_stages is not None:
+ for n in next_stages:
+ next_stage_config_manager = self.plans_manager.get_configuration(n)
+ expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name,
+ next_stage_config_manager.data_identifier)
+
+ try:
+ # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented
+ tmp = nnUNetDataset(expected_preprocessed_folder, [k],
+ num_images_properties_loading_threshold=0)
+ d, s, p = tmp.load_case(k)
+ except FileNotFoundError:
+ self.print_to_log_file(
+ f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! "
+ f"Run the preprocessing for this configuration first!")
+ continue
+
+ target_shape = d.shape[1:]
+ output_folder = join(self.output_folder_base, 'predicted_next_stage', n)
+ output_file = join(output_folder, k + '.npz')
+
+ # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties,
+ # self.dataset_json)
+ results.append(segmentation_export_pool.starmap_async(
+ resample_and_save, (
+ (prediction, target_shape, output_file, self.plans_manager,
+ self.configuration_manager,
+ properties,
+ self.dataset_json),
+ )
+ ))
+
+ _ = [r.get() for r in results]
+
+ if self.is_ddp:
+ dist.barrier()
+
+ if self.local_rank == 0:
+ metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'),
+ validation_output_folder,
+ join(validation_output_folder, 'summary.json'),
+ self.plans_manager.image_reader_writer_class(),
+ self.dataset_json["file_ending"],
+ self.label_manager.foreground_regions if self.label_manager.has_regions else
+ self.label_manager.foreground_labels,
+ self.label_manager.ignore_label, chill=True)
+ self.print_to_log_file("Validation complete", also_print_to_console=True)
+ self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True)
+
+ self.set_deep_supervision_enabled(True)
+ compute_gaussian.cache_clear()
+
+ def run_training(self):
+ self.on_train_start()
+
+ for epoch in range(self.current_epoch, self.num_epochs):
+ self.on_epoch_start()
+
+ self.on_train_epoch_start()
+ train_outputs = []
+ for batch_id in range(self.num_iterations_per_epoch):
+ train_outputs.append(self.train_step(next(self.dataloader_train)))
+ self.on_train_epoch_end(train_outputs)
+
+ with torch.no_grad():
+ self.on_validation_epoch_start()
+ val_outputs = []
+ for batch_id in range(self.num_val_iterations_per_epoch):
+ val_outputs.append(self.validation_step(next(self.dataloader_val)))
+ self.on_validation_epoch_end(val_outputs)
+
+ self.on_epoch_end()
+
+ self.on_train_end()
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py
new file mode 100644
index 0000000..fad1fff
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py
@@ -0,0 +1,65 @@
+import torch
+from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json
+
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from torch import distributed as dist
+
+
+class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results."
+ self.disable_checkpointing = True
+ self.num_epochs = 5
+ assert torch.cuda.is_available(), "This only works on GPU"
+ self.crashed_with_runtime_error = False
+
+ def perform_actual_validation(self, save_probabilities: bool = False):
+ pass
+
+ def save_checkpoint(self, filename: str) -> None:
+ # do not trust people to remember that self.disable_checkpointing must be True for this trainer
+ pass
+
+ def run_training(self):
+ try:
+ super().run_training()
+ except RuntimeError:
+ self.crashed_with_runtime_error = True
+
+ def on_train_end(self):
+ super().on_train_end()
+
+ if not self.is_ddp or self.local_rank == 0:
+ torch_version = torch.__version__
+ cudnn_version = torch.backends.cudnn.version()
+ gpu_name = torch.cuda.get_device_name()
+ if self.crashed_with_runtime_error:
+ fastest_epoch = 'Not enough VRAM!'
+ else:
+ epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'],
+ self.logger.my_fantastic_logging['epoch_start_timestamps'])]
+ fastest_epoch = min(epoch_times)
+
+ if self.is_ddp:
+ num_gpus = dist.get_world_size()
+ else:
+ num_gpus = 1
+
+ benchmark_result_file = join(self.output_folder, 'benchmark_result.json')
+ if isfile(benchmark_result_file):
+ old_results = load_json(benchmark_result_file)
+ else:
+ old_results = {}
+ # generate some unique key
+ my_key = f"{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__gpus_{num_gpus}"
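+            # hypothetical example key: '8902__2.0.1__NVIDIAGeForceRTX3090__gpus_1'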
+ old_results[my_key] = {
+ 'torch_version': torch_version,
+ 'cudnn_version': cudnn_version,
+ 'gpu_name': gpu_name,
+ 'fastest_epoch': fastest_epoch,
+ 'num_gpus': num_gpus,
+ }
+ save_json(old_results,
+ join(self.output_folder, 'benchmark_result.json'))
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py
new file mode 100644
index 0000000..6c12ecc
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py
@@ -0,0 +1,51 @@
+import torch
+
+from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import \
+ nnUNetTrainerBenchmark_5epochs
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+
+
+class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self._set_batch_size_and_oversample()
+ num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+ patch_size = self.configuration_manager.patch_size
+ dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device)
+ dummy_target = [
+ torch.round(
+ torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) *
+ max(self.label_manager.all_labels)
+ ) for k in self._get_deep_supervision_scales()]
+ self.dummy_batch = {'data': dummy_data, 'target': dummy_target}
+
+ def get_dataloaders(self):
+ return None, None
+
+ def run_training(self):
+ try:
+ self.on_train_start()
+
+ for epoch in range(self.current_epoch, self.num_epochs):
+ self.on_epoch_start()
+
+ self.on_train_epoch_start()
+ train_outputs = []
+ for batch_id in range(self.num_iterations_per_epoch):
+ train_outputs.append(self.train_step(self.dummy_batch))
+ self.on_train_epoch_end(train_outputs)
+
+ with torch.no_grad():
+ self.on_validation_epoch_start()
+ val_outputs = []
+ for batch_id in range(self.num_val_iterations_per_epoch):
+ val_outputs.append(self.validation_step(self.dummy_batch))
+ self.on_validation_epoch_end(val_outputs)
+
+ self.on_epoch_end()
+
+ self.on_train_end()
+ except RuntimeError:
+ self.crashed_with_runtime_error = True
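The no-data-loading benchmark above trains on a single fixed random batch whose targets are downscaled once per deep-supervision output. A minimal sketch of that construction, assuming an illustrative patch size, label count and set of deep-supervision scales rather than values read from a plans file:

import torch

# Illustrative shapes; the real trainer takes these from the plans/configuration manager.
batch_size, num_input_channels = 2, 1
patch_size = (128, 128, 128)
deep_supervision_scales = [(1.0, 1.0, 1.0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)]
max_label = 3

dummy_data = torch.rand((batch_size, num_input_channels, *patch_size))
dummy_target = [
    torch.round(
        torch.rand((batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)])) * max_label
    )
    for k in deep_supervision_scales
]
dummy_batch = {'data': dummy_data, 'target': dummy_target}

# One target tensor per deep-supervision output, each downscaled by the corresponding scale.
print([tuple(t.shape) for t in dummy_target])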
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py
new file mode 100644
index 0000000..bd9c31c
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py
@@ -0,0 +1,410 @@
+from typing import List, Union, Tuple
+
+import numpy as np
+import torch
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \
+ GammaTransform
+from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform
+from batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \
+ GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \
+ MirrorTransform
+from batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \
+ NumpyToTensor
+
+from nnunetv2.configuration import ANISO_THRESHOLD
+from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size
+from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
+ DownsampleSegForDSTransform2
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
+from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
+ ConvertSegmentationToRegionsTransform
+from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
+ Convert2DTo3DTransform
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+
+
+class nnUNetTrainerDA5(nnUNetTrainer):
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ """
+ This function is one of the weaker spots of this implementation; it is not yet clear how best to fix it.
+ """
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation)
+ if dim == 2:
+ do_dummy_2d_data_aug = False
+ # todo revisit this parametrization
+ if max(patch_size) / min(patch_size) > 1.5:
+ rotation_for_DA = {
+ 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ mirror_axes = (0, 1)
+ elif dim == 3:
+ # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad
+ # order of the axes is determined by spacing, not image size
+ do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD
+ if do_dummy_2d_data_aug:
+ # why do we rotate 180 deg here all the time? We should also restrict it
+ rotation_for_DA = {
+ 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi),
+ 'y': (0, 0),
+ 'z': (0, 0)
+ }
+ else:
+ rotation_for_DA = {
+ 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi),
+ }
+ mirror_axes = (0, 1, 2)
+ else:
+ raise RuntimeError()
+
+ # todo this function is not ideal: it doesn't even use the correct scale range (we keep things as they were in the
+ # old nnunet for now)
+ initial_patch_size = get_patch_size(patch_size[-dim:],
+ *rotation_for_DA.values(),
+ (0.7, 1.43))
+ if do_dummy_2d_data_aug:
+ initial_patch_size[0] = patch_size[0]
+
+ self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}')
+ self.inference_allowed_mirroring_axes = mirror_axes
+
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 3,
+ order_resampling_seg: int = 1,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size])
+ valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
+
+ tr_transforms = []
+
+ if do_dummy_2d_data_aug:
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(
+ SpatialTransform(
+ patch_size_spatial,
+ patch_center_dist_from_border=None,
+ do_elastic_deform=False,
+ do_rotation=True,
+ angle_x=rotation_for_DA['x'],
+ angle_y=rotation_for_DA['y'],
+ angle_z=rotation_for_DA['z'],
+ p_rot_per_axis=0.5,
+ do_scale=True,
+ scale=(0.7, 1.43),
+ border_mode_data="constant",
+ border_cval_data=0,
+ order_data=order_resampling_data,
+ border_mode_seg="constant",
+ border_cval_seg=-1,
+ order_seg=order_resampling_seg,
+ random_crop=False,
+ p_el_per_sample=0.2,
+ p_scale_per_sample=0.2,
+ p_rot_per_sample=0.4,
+ independent_scale_for_each_axis=True,
+ )
+ )
+
+ if do_dummy_2d_data_aug:
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ if np.any(matching_axes > 1):
+ tr_transforms.append(
+ Rot90Transform(
+ (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5
+ ),
+ )
+
+ if np.any(matching_axes > 1):
+ tr_transforms.append(
+ TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)
+ )
+
+ tr_transforms.append(OneOfTransform([
+ MedianFilterTransform(
+ (2, 8),
+ same_for_each_channel=False,
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ GaussianBlurTransform((0.3, 1.5),
+ different_sigma_per_channel=True,
+ p_per_sample=0.2,
+ p_per_channel=0.5)
+ ]))
+
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+
+ tr_transforms.append(BrightnessTransform(0,
+ 0.5,
+ per_channel=True,
+ p_per_sample=0.1,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(OneOfTransform(
+ [
+ ContrastAugmentationTransform(
+ contrast_range=(0.5, 2),
+ preserve_range=True,
+ per_channel=True,
+ data_key='data',
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ ContrastAugmentationTransform(
+ contrast_range=(0.5, 2),
+ preserve_range=False,
+ per_channel=True,
+ data_key='data',
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ ]
+ ))
+
+ tr_transforms.append(
+ SimulateLowResolutionTransform(zoom_range=(0.25, 1),
+ per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0,
+ order_upsample=3,
+ p_per_sample=0.15,
+ ignore_axes=ignore_axes
+ )
+ )
+
+ tr_transforms.append(
+ GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
+ tr_transforms.append(
+ GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
+
+ if mirror_axes is not None and len(mirror_axes) > 0:
+ tr_transforms.append(MirrorTransform(mirror_axes))
+
+ tr_transforms.append(
+ BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size],
+ rectangle_value=np.mean,
+ num_rectangles=(1, 5),
+ force_square=False,
+ p_per_sample=0.4,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(
+ BrightnessGradientAdditiveTransform(
+ lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
+ (-0.5, 1.5),
+ max_strength=lambda x, y: np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
+ mean_centered=False,
+ same_for_all_channels=False,
+ p_per_sample=0.3,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(
+ LocalGammaTransform(
+ lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
+ (-0.5, 1.5),
+ lambda: np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
+ same_for_all_channels=False,
+ p_per_sample=0.3,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(
+ SharpeningTransform(
+ strength=(0.1, 1),
+ same_for_each_channel=False,
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ )
+ )
+
+ if use_mask_for_norm is not None and any(use_mask_for_norm):
+ tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]],
+ mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if is_cascaded:
+ if ignore_label is not None:
+ raise NotImplementedError('ignore label not yet supported in cascade')
+ assert foreground_labels is not None, 'We need all_labels for cascade augmentations'
+ use_labels = [i for i in foreground_labels if i != 0]
+ tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data'))
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(use_labels), 0)),
+ p_per_sample=0.4,
+ key="data",
+ strel_size=(1, 8),
+ p_per_label=1))
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(use_labels), 0)),
+ key="data",
+ p_per_sample=0.2,
+ fill_with_other_class_p=0,
+ dont_do_if_covers_more_than_x_percent=0.15))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ # the ignore label must also be converted
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
+ if ignore_label is not None else regions,
+ 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+ return tr_transforms
+
+
+class nnUNetTrainerDA5ord0(nnUNetTrainerDA5):
+ def get_dataloaders(self):
+ """
+ changed order_resampling_data, order_resampling_seg
+ """
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=0, order_resampling_seg=0,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
+ allowed_num_processes, 6, None, True, 0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
+ max(1, allowed_num_processes // 2), 3, None, True, 0.02)
+
+ return mt_gen_train, mt_gen_val
+
+
+class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5):
+ def get_dataloaders(self):
+ """
+ changed order_resampling_data, order_resampling_seg
+ """
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=0,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
+ allowed_num_processes, 6, None, True, 0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
+ max(1, allowed_num_processes // 2), 3, None, True, 0.02)
+
+ return mt_gen_train, mt_gen_val
+
+
+class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 10
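Two small computations in nnUNetTrainerDA5 above are easy to misread: the rotation ranges are plain degree-to-radian conversions, and Rot90/transpose augmentation is only allowed on axes whose patch extents are equal. A standalone sketch, assuming an anisotropic example patch size:

import numpy as np

# Degree-to-radian conversion used for the rotation ranges above: x / 360 * 2 * pi.
def rot_range(degrees):
    return (-degrees / 360 * 2 * np.pi, degrees / 360 * 2 * np.pi)

print(rot_range(30))   # ~(-0.52, 0.52) rad: the isotropic 3D case
print(rot_range(180))  # full in-plane rotation for the dummy-2D case

# Rot90/transpose augmentation is only valid on axes whose patch extents match.
patch_size = (40, 192, 192)  # assumed anisotropic patch
matching_axes = np.array([sum(i == j for j in patch_size) for i in patch_size])
valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
print(valid_axes)  # [1, 2] -> only the two in-plane axes may be 90deg-rotated or transposed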
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py
new file mode 100644
index 0000000..e87ff8f
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py
@@ -0,0 +1,104 @@
+from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
+
+from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
+ LimitedLenWrapper
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
+
+
+class nnUNetTrainerDAOrd0(nnUNetTrainer):
+ def get_dataloaders(self):
+ """
+ changed order_resampling_data, order_resampling_seg
+ """
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=0, order_resampling_seg=0,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
+ allowed_num_processes, 6, None, True, 0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
+ max(1, allowed_num_processes // 2), 3, None, True, 0.02)
+
+ return mt_gen_train, mt_gen_val
+
+
+class nnUNetTrainer_DASegOrd0(nnUNetTrainer):
+ def get_dataloaders(self):
+ """
+ changed order_resampling_data, order_resampling_seg
+ """
+ # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
+ # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+
+ # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
+ # outputs?
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+
+ # training pipeline
+ tr_transforms = self.get_training_transforms(
+ patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
+ order_resampling_data=3, order_resampling_seg=0,
+ use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
+ is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ # validation pipeline
+ val_transforms = self.get_validation_transforms(deep_supervision_scales,
+ is_cascaded=self.is_cascaded,
+ foreground_labels=self.label_manager.all_labels,
+ regions=self.label_manager.foreground_regions if
+ self.label_manager.has_regions else None,
+ ignore_label=self.label_manager.ignore_label)
+
+ dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)
+
+ allowed_num_processes = get_allowed_n_proc_DA()
+ if allowed_num_processes == 0:
+ mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
+ mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
+ else:
+ mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
+ allowed_num_processes, 6, None, True, 0.02)
+ mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
+ max(1, allowed_num_processes // 2), 3, None, True, 0.02)
+
+ return mt_gen_train, mt_gen_val
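The ...Ord0 trainers above set order_resampling_seg=0 (and partly order_resampling_data=0). Nearest-neighbour is used for segmentations because higher-order interpolation produces values that are not valid labels; a small illustration of that effect, assuming SciPy is available:

import numpy as np
from scipy.ndimage import zoom

seg = np.zeros((8, 8), dtype=float)
seg[2:6, 2:6] = 2  # a small blob of label 2 on background 0

nearest = zoom(seg, 2, order=0)  # order 0: values stay exactly in {0, 2}
cubic = zoom(seg, 2, order=3)    # order 3: smooth interpolation creates in-between values

print(sorted(np.unique(nearest)))                       # [0.0, 2.0]
print(bool(np.all(np.isin(np.unique(cubic), [0, 2]))))  # False -> no longer valid labels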
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py
new file mode 100644
index 0000000..527e262
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py
@@ -0,0 +1,40 @@
+from typing import Union, Tuple, List
+
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+import numpy as np
+
+
+class nnUNetTrainerNoDA(nnUNetTrainer):
+ @staticmethod
+ def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]],
+ rotation_for_DA: dict,
+ deep_supervision_scales: Union[List, Tuple],
+ mirror_axes: Tuple[int, ...],
+ do_dummy_2d_data_aug: bool,
+ order_resampling_data: int = 1,
+ order_resampling_seg: int = 0,
+ border_val_seg: int = -1,
+ use_mask_for_norm: List[bool] = None,
+ is_cascaded: bool = False,
+ foreground_labels: Union[Tuple[int, ...], List[int]] = None,
+ regions: List[Union[List[int], Tuple[int, ...], int]] = None,
+ ignore_label: int = None) -> AbstractTransform:
+ return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels,
+ regions, ignore_label)
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ return super().get_plain_dataloaders(
+ initial_patch_size=self.configuration_manager.patch_size,
+ dim=dim
+ )
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ # we need to disable mirroring here so that no mirroring will be applied in inference!
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
new file mode 100644
index 0000000..18ea1ea
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py
@@ -0,0 +1,28 @@
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainerNoMirroring(nnUNetTrainer):
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+
+class nnUNetTrainer_onlyMirror01(nnUNetTrainer):
+ """
+ Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D
+ """
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ patch_size = self.configuration_manager.patch_size
+ dim = len(patch_size)
+ if dim == 2:
+ mirror_axes = (0, )
+ else:
+ mirror_axes = (0, 1)
+ self.inference_allowed_mirroring_axes = mirror_axes
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py
new file mode 100644
index 0000000..1a363cc
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py
@@ -0,0 +1,23 @@
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss
+import numpy as np
+
+
+class nnUNetTrainerCELoss(nnUNetTrainer):
+ def _build_loss(self):
+ assert not self.label_manager.has_regions, 'regions not supported by this trainer'
+ loss = RobustCrossEntropyLoss(weight=None,
+ ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
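The loss variants in this directory all build the same deep-supervision weighting: each lower-resolution output gets half the weight of the one above it, and the weights are then normalized to sum to 1. A quick numeric sketch, assuming five deep-supervision outputs (highest resolution first):

import numpy as np

num_outputs = 5  # assumed number of deep-supervision outputs
weights = np.array([1 / (2 ** i) for i in range(num_outputs)])
weights = weights / weights.sum()

print(np.round(weights, 3))  # [0.516 0.258 0.129 0.065 0.032]
print(weights.sum())         # 1.0 (up to floating-point rounding)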
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py
new file mode 100644
index 0000000..82114c0
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py
@@ -0,0 +1,56 @@
+import numpy as np
+import torch
+
+from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.helpers import softmax_helper_dim1
+
+
+class nnUNetTrainerDiceLoss(nnUNetTrainer):
+ def _build_loss(self):
+ loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+
+class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer):
+ def _build_loss(self):
+ # set smooth to 0
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label,
+ dice_class=MemoryEfficientSoftDiceLoss)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py
new file mode 100644
index 0000000..4ebfd41
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py
@@ -0,0 +1,66 @@
+from nnunetv2.training.loss.compound_losses import DC_and_topk_loss
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+import numpy as np
+from nnunetv2.training.loss.robust_ce_loss import TopKLoss
+
+
+class nnUNetTrainerTopk10Loss(nnUNetTrainer):
+ def _build_loss(self):
+ assert not self.label_manager.has_regions, 'regions not supported by this trainer'
+ loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,
+ k=10)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+
+class nnUNetTrainerTopk10LossLS01(nnUNetTrainer):
+ def _build_loss(self):
+ assert not self.label_manager.has_regions, 'regions not supported by this trainer'
+ loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100,
+ k=10, label_smoothing=0.1)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
+
+
+class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer):
+ def _build_loss(self):
+ assert not self.label_manager.has_regions, 'regions not supported by this trainer'
+ loss = DC_and_topk_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp},
+ {'k': 10,
+ 'label_smoothing': 0.0},
+ weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label)
+
+ deep_supervision_scales = self._get_deep_supervision_scales()
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
+
+ # normalize weights so that they sum to 1 (no outputs are excluded in this variant)
+ weights = weights / weights.sum()
+ # now wrap the loss
+ loss = DeepSupervisionWrapper(loss, weights)
+ return loss
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py
new file mode 100644
index 0000000..60455f2
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py
@@ -0,0 +1,13 @@
+import torch
+from torch.optim.lr_scheduler import CosineAnnealingLR
+
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainerCosAnneal(nnUNetTrainer):
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
+ return optimizer, lr_scheduler
+
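For reference, this is roughly how the cosine-annealed learning rate evolves over num_epochs steps; the toy parameter, the SGD settings and the 10-epoch horizon below are illustrative only:

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# A toy parameter/optimizer, just to show the shape of the schedule over an assumed 10 epochs.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-2, momentum=0.99, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

lrs = []
for _ in range(10):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()

print([round(lr, 5) for lr in lrs])  # starts at 1e-2 and decays along a cosine curve towards 0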
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py
new file mode 100644
index 0000000..b2f26e2
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py
@@ -0,0 +1,73 @@
+from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet
+from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm
+from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager
+from torch import nn
+
+
+class nnUNetTrainerBN(nnUNetTrainer):
+ @staticmethod
+ def build_network_architecture(plans_manager: PlansManager,
+ dataset_json,
+ configuration_manager: ConfigurationManager,
+ num_input_channels,
+ enable_deep_supervision: bool = True) -> nn.Module:
+ num_stages = len(configuration_manager.conv_kernel_sizes)
+
+ dim = len(configuration_manager.conv_kernel_sizes[0])
+ conv_op = convert_dim_to_conv_op(dim)
+
+ label_manager = plans_manager.get_label_manager(dataset_json)
+
+ segmentation_network_class_name = configuration_manager.UNet_class_name
+ mapping = {
+ 'PlainConvUNet': PlainConvUNet,
+ 'ResidualEncoderUNet': ResidualEncoderUNet
+ }
+ kwargs = {
+ 'PlainConvUNet': {
+ 'conv_bias': True,
+ 'norm_op': get_matching_batchnorm(conv_op),
+ 'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+ 'dropout_op': None, 'dropout_op_kwargs': None,
+ 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+ },
+ 'ResidualEncoderUNet': {
+ 'conv_bias': True,
+ 'norm_op': get_matching_batchnorm(conv_op),
+ 'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+ 'dropout_op': None, 'dropout_op_kwargs': None,
+ 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+ }
+ }
+ assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \
+ 'is non-standard (maybe your own?). You\'ll have to dive ' \
+ 'into either this ' \
+ 'function (get_network_from_plans) or ' \
+ 'the init of your nnUNetModule to accommodate that.'
+ network_class = mapping[segmentation_network_class_name]
+
+ conv_or_blocks_per_stage = {
+ 'n_conv_per_stage'
+ if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder,
+ 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder
+ }
+ # instantiate the selected network class
+ model = network_class(
+ input_channels=num_input_channels,
+ n_stages=num_stages,
+ features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i,
+ configuration_manager.unet_max_num_features) for i in range(num_stages)],
+ conv_op=conv_op,
+ kernel_sizes=configuration_manager.conv_kernel_sizes,
+ strides=configuration_manager.pool_op_kernel_sizes,
+ num_classes=label_manager.num_segmentation_heads,
+ deep_supervision=enable_deep_supervision,
+ **conv_or_blocks_per_stage,
+ **kwargs[segmentation_network_class_name]
+ )
+ model.apply(InitWeights_He(1e-2))
+ if network_class == ResidualEncoderUNet:
+ model.apply(init_last_bn_before_add_to_0)
+ return model
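The features_per_stage expression above doubles the base feature count at every stage and caps it at the configured maximum. A short sketch with typical (assumed) nnU-Net values:

# Assumed typical nnU-Net values; the real ones come from the configuration manager.
UNet_base_num_features = 32
unet_max_num_features = 320
num_stages = 6

features_per_stage = [min(UNet_base_num_features * 2 ** i, unet_max_num_features)
                      for i in range(num_stages)]
print(features_per_stage)  # [32, 64, 128, 256, 320, 320] -> doubled per stage, capped at the maximum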
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py
new file mode 100644
index 0000000..a07ff8a
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py
@@ -0,0 +1,114 @@
+import torch
+from torch import autocast
+
+from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
+from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from nnunetv2.utilities.helpers import dummy_context
+from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class nnUNetTrainerNoDeepSupervision(nnUNetTrainer):
+ def _build_loss(self):
+ if self.label_manager.has_regions:
+ loss = DC_and_BCE_loss({},
+ {'batch_dice': self.configuration_manager.batch_dice,
+ 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp},
+ use_ignore_label=self.label_manager.ignore_label is not None,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ else:
+ loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice,
+ 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1,
+ ignore_label=self.label_manager.ignore_label,
+ dice_class=MemoryEfficientSoftDiceLoss)
+ return loss
+
+ def _get_deep_supervision_scales(self):
+ return None
+
+ def initialize(self):
+ if not self.was_initialized:
+ self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
+ self.dataset_json)
+
+ self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
+ self.configuration_manager,
+ self.num_input_channels,
+ enable_deep_supervision=False).to(self.device)
+
+ self.optimizer, self.lr_scheduler = self.configure_optimizers()
+ # if ddp, wrap in DDP wrapper
+ if self.is_ddp:
+ self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ self.loss = self._build_loss()
+ self.was_initialized = True
+ else:
+ raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
+ "That should not happen.")
+
+ def set_deep_supervision_enabled(self, enabled: bool):
+ pass
+
+ def validation_step(self, batch: dict) -> dict:
+ data = batch['data']
+ target = batch['target']
+
+ data = data.to(self.device, non_blocking=True)
+ if isinstance(target, list):
+ target = [i.to(self.device, non_blocking=True) for i in target]
+ else:
+ target = target.to(self.device, non_blocking=True)
+
+ self.optimizer.zero_grad()
+
+ # Autocast behaves inconsistently across devices:
+ # on 'cpu' it is very slow and needs to be disabled, and on 'mps' it raises a not-implemented error even when
+ # enabled=False is passed (which is why we don't rely on enabled=False).
+ # So autocast is only active when running on a CUDA device.
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ # the following is needed for online evaluation. Fake dice (green line)
+ axes = [0] + list(range(2, len(output.shape)))
+
+ if self.label_manager.has_regions:
+ predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long()
+ else:
+ # no need for softmax
+ output_seg = output.argmax(1)[:, None]
+ predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32)
+ predicted_segmentation_onehot.scatter_(1, output_seg, 1)
+ del output_seg
+
+ if self.label_manager.has_ignore_label:
+ if not self.label_manager.has_regions:
+ mask = (target != self.label_manager.ignore_label).float()
+ # CAREFUL that you don't rely on target after this line!
+ target[target == self.label_manager.ignore_label] = 0
+ else:
+ mask = 1 - target[:, -1:]
+ # CAREFUL that you don't rely on target after this line!
+ target = target[:, :-1]
+ else:
+ mask = None
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+ if not self.label_manager.has_regions:
+ # if we train with regions, all segmentation heads predict some kind of foreground. In conventional
+ # (softmax) training there needs to be one output for the background. We are not interested in the
+ # background Dice
+ # [1:] in order to remove background
+ tp_hard = tp_hard[1:]
+ fp_hard = fp_hard[1:]
+ fn_hard = fn_hard[1:]
+
+ return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard}
\ No newline at end of file
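validation_step above converts logits into a hard one-hot segmentation with argmax followed by scatter_. A self-contained sketch of that conversion on a tiny assumed tensor:

import torch

# Assumed tiny logits: batch of 1, 3 classes, 2x2 spatial grid.
output = torch.randn(1, 3, 2, 2)

output_seg = output.argmax(1)[:, None]  # (1, 1, 2, 2) hard class indices, no softmax needed
predicted_onehot = torch.zeros(output.shape, dtype=torch.float32)
predicted_onehot.scatter_(1, output_seg, 1)  # write a 1 at the winning class along the channel axis

print(predicted_onehot.sum(1))  # all ones: exactly one active class per pixel
print(torch.equal(predicted_onehot.argmax(1, keepdim=True), output_seg))  # True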
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py
new file mode 100644
index 0000000..be5a7f4
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py
@@ -0,0 +1,58 @@
+import torch
+from torch.optim import Adam, AdamW
+
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainerAdam(nnUNetTrainer):
+ def configure_optimizers(self):
+ optimizer = AdamW(self.network.parameters(),
+ lr=self.initial_lr,
+ weight_decay=self.weight_decay,
+ amsgrad=True)
+ # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ # momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
+ return optimizer, lr_scheduler
+
+
+class nnUNetTrainerVanillaAdam(nnUNetTrainer):
+ def configure_optimizers(self):
+ optimizer = Adam(self.network.parameters(),
+ lr=self.initial_lr,
+ weight_decay=self.weight_decay)
+ # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ # momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
+ return optimizer, lr_scheduler
+
+
+class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+
+
+class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam):
+ # https://twitter.com/karpathy/status/801621764144971776?lang=en
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 3e-4
+
+
+class nnUNetTrainerAdam1en3(nnUNetTrainerAdam):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+
+
+class nnUNetTrainerAdam3en4(nnUNetTrainerAdam):
+ # https://twitter.com/karpathy/status/801621764144971776?lang=en
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 3e-4
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py
new file mode 100644
index 0000000..8747f47
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py
@@ -0,0 +1,66 @@
+import torch
+
+from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+from torch.optim.lr_scheduler import CosineAnnealingLR
+try:
+ from adan_pytorch import Adan
+except ImportError:
+ Adan = None
+
+
+class nnUNetTrainerAdan(nnUNetTrainer):
+ def configure_optimizers(self):
+ if Adan is None:
+ raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
+ optimizer = Adan(self.network.parameters(),
+ lr=self.initial_lr,
+ # betas=(0.02, 0.08, 0.01), defaults
+ weight_decay=self.weight_decay)
+ # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ # momentum=0.99, nesterov=True)
+ lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
+ return optimizer, lr_scheduler
+
+
+class nnUNetTrainerAdan1en3(nnUNetTrainerAdan):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-3
+
+
+class nnUNetTrainerAdan3en4(nnUNetTrainerAdan):
+ # https://twitter.com/karpathy/status/801621764144971776?lang=en
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 3e-4
+
+
+class nnUNetTrainerAdan1en1(nnUNetTrainerAdan):
+ # this trainer makes no sense -> nan!
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.initial_lr = 1e-1
+
+
+class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan):
+ # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ # device: torch.device = torch.device('cuda')):
+ # super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ # self.num_epochs = 15
+
+ def configure_optimizers(self):
+ if Adan is None:
+ raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"')
+ optimizer = Adan(self.network.parameters(),
+ lr=self.initial_lr,
+ # betas=(0.02, 0.08, 0.01), defaults
+ weight_decay=self.weight_decay)
+ # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ # momentum=0.99, nesterov=True)
+ lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
+ return optimizer, lr_scheduler
+
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py
new file mode 100644
index 0000000..89fef48
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py
@@ -0,0 +1,76 @@
+from typing import Tuple
+
+import torch
+
+from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D
+from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+import numpy as np
+
+
+class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer):
+ """
+ Foreground oversampling is applied probabilistically per sample instead of being forced for the last 33% of
+ samples in a batch. Since most trainings use batch size 2 and nnU-Net guarantees at least one foreground sample,
+ the effective rate can be 50%.
+ Here we compute the actual oversampling percentage used by nnUNetTrainer in order to stay as consistent as possible.
+ If we switch to this oversampling scheme, we can keep it at a constant value such as 0.33.
+ """
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.oversample_foreground_percent = float(np.mean(
+ [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))
+ for sample_idx in range(self.configuration_manager.batch_size)]))
+ self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}")
+
+ def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
+ dataset_tr, dataset_val = self.get_tr_and_val_datasets()
+
+ if dim == 2:
+ dl_tr = nnUNetDataLoader2D(dataset_tr,
+ self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True)
+ dl_val = nnUNetDataLoader2D(dataset_val,
+ self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True)
+ else:
+ dl_tr = nnUNetDataLoader3D(dataset_tr,
+ self.batch_size,
+ initial_patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True)
+ dl_val = nnUNetDataLoader3D(dataset_val,
+ self.batch_size,
+ self.configuration_manager.patch_size,
+ self.configuration_manager.patch_size,
+ self.label_manager,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True)
+ return dl_tr, dl_val
+
+
+class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.oversample_foreground_percent = 0.33
+
+
+class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.oversample_foreground_percent = 0.1
+
+
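The np.mean expression in nnUNetTrainer_probabilisticOversampling.__init__ reproduces the effective foreground rate implied by the default "force the last round(batch_size * percent) samples" rule. A standalone sketch of how that works out for different batch sizes (the batch sizes below are assumed for illustration):

import numpy as np

def effective_oversample_percent(batch_size, oversample_foreground_percent):
    # Default nnU-Net rule: the last round(batch_size * percent) samples of each batch are forced
    # to contain foreground; this returns the fraction of samples that rule actually affects.
    return float(np.mean(
        [not sample_idx < round(batch_size * (1 - oversample_foreground_percent))
         for sample_idx in range(batch_size)]))

print(effective_oversample_percent(2, 0.33))   # 0.5 -> with batch size 2, one of the two samples is forced fg
print(effective_oversample_percent(12, 0.33))  # ~0.33 for larger batches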
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
new file mode 100644
index 0000000..e3a71a0
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py
@@ -0,0 +1,76 @@
+import torch
+
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainer_5epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ """used for debugging plans etc"""
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 5
+
+
+class nnUNetTrainer_1epoch(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ """used for debugging plans etc"""
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 1
+
+
+class nnUNetTrainer_10epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ """used for debugging plans etc"""
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 10
+
+
+class nnUNetTrainer_20epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 20
+
+
+class nnUNetTrainer_50epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 50
+
+
+class nnUNetTrainer_100epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 100
+
+
+class nnUNetTrainer_250epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 250
+
+
+class nnUNetTrainer_2000epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 2000
+
+
+class nnUNetTrainer_4000epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 4000
+
+
+class nnUNetTrainer_8000epochs(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 8000
diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py
new file mode 100644
index 0000000..c16b885
--- /dev/null
+++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py
@@ -0,0 +1,60 @@
+import torch
+
+from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 250
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+
+class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 2000
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+
+class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 4000
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
+
+class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer):
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
+ device: torch.device = torch.device('cuda')):
+ super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+ self.num_epochs = 8000
+
+ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
+ super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
+ mirror_axes = None
+ self.inference_allowed_mirroring_axes = None
+ return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
+
diff --git a/nnUNet/nnunetv2/unet.py b/nnUNet/nnunetv2/unet.py
new file mode 100644
index 0000000..df14b63
--- /dev/null
+++ b/nnUNet/nnunetv2/unet.py
@@ -0,0 +1,344 @@
+from typing import TypeVar
+
+import torch
+from torch.nn import Conv3d
+from einops import rearrange, repeat
+from torch import nn, einsum
+from inspect import isfunction
+
+# generic type variable for the eval()/train() return annotations below
+# (the previous `from typing import T` relied on a private name that is not part of typing's public API)
+T = TypeVar('T')
+
+
+class UNetEncoderS(nn.Module):
+
+ def __init__(self, channels):
+ super(UNetEncoderS, self).__init__()
+ self.inc = (DoubleConv(channels, 16))
+ self.down1 = (Down(16, 32, pooling=(1, 2, 2)))
+ self.down2 = (Down(32, 64))
+ self.down3 = (Down(64, 128))
+ self.down4 = (Down(128, 256))
+
+ def forward(self, x):
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ skips = [x1, x2, x3, x4]
+ return x5, skips
+
+
+class UNetEncoderL(nn.Module):
+ def __init__(self, channels):
+ super(UNetEncoderL, self).__init__()
+ self.inc = (DoubleConv(channels, 32))
+ self.down1 = (Down(32, 64, pooling=(1, 2, 2)))
+ self.down2 = (Down(64, 128))
+ self.down3 = (Down(128, 256))
+ self.down4 = (Down(256, 512))
+
+ def forward(self, x):
+
+ x1 = self.inc(x)
+ x2 = self.down1(x1)
+ x3 = self.down2(x2)
+ x4 = self.down3(x3)
+ x5 = self.down4(x4)
+ skips = [x1, x2, x3, x4]
+ return x5, skips
+
+
+class OutConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(OutConv, self).__init__()
+ self.conv1 = DoubleConv(in_channels, out_channels)
+ self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=1)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ return self.conv2(x)
+
+
+class SegmentationHeadS(nn.Module):
+ def __init__(self, in_features, segmentation_classes, do_ds):
+ super(SegmentationHeadS, self).__init__()
+ self.up_segmentation1 = (Up(in_features, 128,
+ bilinear=False)
+ )
+ self.up_segmentation2 = (Up(128, 64, bilinear=False))
+ self.up_segmentation3 = (Up(64, 32, bilinear=False))
+ self.up_segmentation4 = (Up(32, 16, bilinear=False,
+ pooling=(1, 2, 2)))
+ self.outc_segmentation = (OutConv(16, segmentation_classes))
+
+ self.do_ds = do_ds
+ self.proj1 = Conv3d(256, segmentation_classes, 1)
+ self.proj2 = Conv3d(128, segmentation_classes, 1)
+ self.proj3 = Conv3d(64, segmentation_classes, 1)
+ self.proj4 = Conv3d(32, segmentation_classes, 1)
+ self.non_lin = lambda x: x # torch.softmax(x, 1)
+
+ def forward(self, x, skips):
+ x1, x2, x3, x4 = skips
+ if self.do_ds:
+ outputs = []
+ outputs.append(self.non_lin(self.proj1(x)))
+ x = self.up_segmentation1(x, x4)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj2(x)))
+ x = self.up_segmentation2(x, x3)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj3(x)))
+ x = self.up_segmentation3(x, x2)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj4(x)))
+ x = self.up_segmentation4(x, x1)
+ x = self.outc_segmentation(x)
+ if self.do_ds:
+ outputs.append(self.non_lin(x))
+ if self.do_ds:
+ return outputs[::-1]
+ return self.non_lin(x)
+
+ def eval(self: T) -> T:
+ a = super(SegmentationHeadS, self).eval()
+ self.do_ds = False
+ return a
+
+    def train(self: T, mode: bool = True) -> T:
+        # forward `mode` so that eval() (which calls train(False) under the hood) really puts the head in eval mode
+        a = super(SegmentationHeadS, self).train(mode)
+        self.do_ds = mode
+        return a
+
+
+class DoubleConv(nn.Module):
+
+ """(convolution => [BN] => ReLU) * 2"""
+ def __init__(self, in_channels, out_channels,
+ mid_channels=None, dropout_rate: float = 0.):
+ super().__init__()
+ if not mid_channels:
+ mid_channels = out_channels
+ self.double_conv = nn.Sequential(
+ nn.Conv3d(in_channels, mid_channels, kernel_size=3,
+ padding=1, bias=True),
+ nn.Dropout3d(dropout_rate),
+ nn.BatchNorm3d(mid_channels),
+ nn.LeakyReLU(inplace=True),
+ nn.Conv3d(mid_channels, out_channels, kernel_size=3,
+ padding=1, bias=True),
+ nn.Dropout3d(dropout_rate),
+ nn.BatchNorm3d(out_channels),
+ nn.LeakyReLU(inplace=True)
+ )
+
+ def forward(self, x):
+ return self.double_conv(x)
+
+
+class Down(nn.Module):
+ """Downscaling with maxpool then double conv"""
+
+ def __init__(self, in_channels, out_channels, pooling=(2,2,2)):
+ super().__init__()
+ self.maxpool_conv = nn.Sequential(
+ nn.MaxPool3d(pooling),
+ DoubleConv(in_channels, out_channels, dropout_rate=0.5)
+ )
+
+ def forward(self, x):
+ return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+ """Upscaling then double conv"""
+
+ def __init__(self, in_channels, out_channels, pooling=(2, 2, 2), bilinear=True):
+ super().__init__()
+
+ # if bilinear, use the normal convolutions to reduce the number of channels
+ if bilinear:
+            # 'trilinear' because inputs are 5D volumes; nn.Upsample's 'bilinear' mode only supports 4D input
+            self.up = nn.Upsample(scale_factor=pooling, mode='trilinear', align_corners=True)
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+ else:
+ self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=pooling, stride=pooling, bias=False)
+ self.conv = DoubleConv(in_channels, out_channels)
+
+ def forward(self, x1, x2):
+ x1 = self.up(x1)
+ # input is CHW
+ # diffY = x2.size()[2] - x1.size()[2]
+ # diffX = x2.size()[3] - x1.size()[3]
+
+ # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+ # diffY // 2, diffY - diffY // 2])
+ # if you have padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+ x = torch.cat([x2, x1], dim=1)
+ return self.conv(x)
+
+
+class SegmentationHeadL(nn.Module):
+ def __init__(self, in_features, segmentation_classes, do_ds):
+ super(SegmentationHeadL, self).__init__()
+ self.up_segmentation1 = (Up(in_features, 256, bilinear=False))
+ self.up_segmentation2 = (Up(256, 128, bilinear=False))
+ self.up_segmentation3 = (Up(128, 64, bilinear=False))
+ self.up_segmentation4 = (Up(64, 32, bilinear=False,
+ pooling=(1, 2, 2)))
+ self.outc_segmentation = (OutConv(32, segmentation_classes))
+
+ self.do_ds = do_ds
+ self.proj1 = Conv3d(512, segmentation_classes, 1)
+ self.proj2 = Conv3d(256, segmentation_classes, 1)
+ self.proj3 = Conv3d(128, segmentation_classes, 1)
+ self.proj4 = Conv3d(64, segmentation_classes, 1)
+ self.non_lin = lambda x: x # torch.softmax(x, 1)
+
+ def forward(self, x, skips):
+ x1, x2, x3, x4 = skips
+ if self.do_ds:
+ outputs = []
+ outputs.append(self.non_lin(self.proj1(x)))
+ x = self.up_segmentation1(x, x4)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj2(x)))
+ x = self.up_segmentation2(x, x3)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj3(x)))
+ x = self.up_segmentation3(x, x2)
+ if self.do_ds:
+ outputs.append(self.non_lin(self.proj4(x)))
+ x = self.up_segmentation4(x, x1)
+ x = self.outc_segmentation(x)
+ if self.do_ds:
+ outputs.append(self.non_lin(x))
+ if self.do_ds:
+ return outputs[::-1]
+ return self.non_lin(x)
+
+ def eval(self: T) -> T:
+ a = super(SegmentationHeadL, self).eval()
+ self.do_ds = False
+ return a
+
+    def train(self: T, mode: bool = True) -> T:
+        # forward `mode` so that eval() (which calls train(False) under the hood) really puts the head in eval mode
+        a = super(SegmentationHeadL, self).train(mode)
+        self.do_ds = mode
+        return a
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(self, query_dim, context_dim=None,
+ heads=8, dim_head=64, dropout=0.):
+
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Conv3d(query_dim, inner_dim,
+ kernel_size=1, bias=False)
+ self.to_k = nn.Conv3d(context_dim, inner_dim,
+ kernel_size=1, bias=False)
+ self.to_v = nn.Conv3d(context_dim, inner_dim,
+ kernel_size=1, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv3d(inner_dim, query_dim, kernel_size=1),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ n = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, c, h, w, d = x.shape
+ q, k, v = map(lambda t:
+ rearrange(t, 'b (n c) h w l -> (b n) c (h w l)',
+ n=n), (q, k, v))
+
+ # force cast to fp32 to avoid overflowing
+ with torch.autocast(enabled=False, device_type='cuda'):
+ q, k = q.float(), k.float()
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+            # repeat the mask once per attention head (n = number of heads; h is the spatial height here)
+            mask = repeat(mask, 'b j -> (b n) () j', n=n)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b n) c (h w l) -> b (n c) h w l', n=n, h=h, w=w, l=d)
+ out = self.to_out(out)
+ return out
+
+
+class UNetDeepSupervisionDoubleEncoder(nn.Module):
+ def __init__(self, n_channels_1, n_channels_2,
+ n_classes_segmentation, deep_supervision=True,
+ encoder=UNetEncoderL,
+ segmentation_head=SegmentationHeadL):
+ super(UNetDeepSupervisionDoubleEncoder, self).__init__()
+ self.n_channels_1 = n_channels_1
+ self.n_channels_2 = n_channels_2
+ self.n_classes_segmentation = n_classes_segmentation
+ self.nb_decoders = 4
+ self.do_ds = deep_supervision
+ self.deep_supervision = deep_supervision
+
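+        # Two encoders, one per input channel (for this challenge presumably the CT and the MR volume);
+        # their bottleneck features and skip connections are concatenated before the shared segmentation head.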
+ self.encoder1 = encoder(self.n_channels_1)
+ self.encoder2 = encoder(self.n_channels_2)
+ feature_size = 512 if segmentation_head == SegmentationHeadL else 256
+ print(feature_size)
+ self.segmentation_head = segmentation_head(feature_size,
+ self.n_classes_segmentation,
+ self.do_ds)
+ # self.CA = CrossAttention(query_dim=feature_size)
+
+ def forward(self, x_in):
+ x, y = x_in[:, 0:1, :, :, :], x_in[:, 1:2, :, :, :]
+ features1, skips_1 = self.encoder1(x)
+ features2, skips_2 = self.encoder2(y)
+ x = torch.cat([features1, features2], dim=1)
+ skip = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(skips_1, skips_2)]
+
+ # self.CA(features1, context=features2)
+ return self.segmentation_head(x, skip)
+
+    def eval(self: T) -> T:
+        super(UNetDeepSupervisionDoubleEncoder, self).eval()
+        self.do_ds = False
+        self.segmentation_head.eval()
+        self.encoder1.eval()
+        self.encoder2.eval()
+        return self
+
+    def train(self: T, mode: bool = True) -> T:
+        super(UNetDeepSupervisionDoubleEncoder, self).train(mode)
+        self.do_ds = mode
+        self.encoder1.train(mode)
+        self.encoder2.train(mode)
+        self.segmentation_head.train(mode)
+        return self
+
diff --git a/nnUNet/nnunetv2/utilities/__init__.py b/nnUNet/nnunetv2/utilities/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/utilities/collate_outputs.py b/nnUNet/nnunetv2/utilities/collate_outputs.py
new file mode 100644
index 0000000..c9d6798
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/collate_outputs.py
@@ -0,0 +1,24 @@
+from typing import List
+
+import numpy as np
+
+
+def collate_outputs(outputs: List[dict]):
+ """
+ used to collate default train_step and validation_step outputs. If you want something different then you gotta
+ extend this
+
+ we expect outputs to be a list of dictionaries where each of the dict has the same set of keys
+ """
+ collated = {}
+ for k in outputs[0].keys():
+ if np.isscalar(outputs[0][k]):
+ collated[k] = [o[k] for o in outputs]
+ elif isinstance(outputs[0][k], np.ndarray):
+ collated[k] = np.vstack([o[k][None] for o in outputs])
+ elif isinstance(outputs[0][k], list):
+ collated[k] = [item for o in outputs for item in o[k]]
+ else:
+ raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. '
+ f'Modify collate_outputs to add this functionality')
+ return collated
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py b/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py
new file mode 100644
index 0000000..1f2c350
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py
@@ -0,0 +1,74 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Union
+
+from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+
+
+def find_candidate_datasets(dataset_id: int):
+ startswith = "Dataset%03.0d" % dataset_id
+ if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed):
+ candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False)
+ else:
+ candidates_preprocessed = []
+
+ if nnUNet_raw is not None and isdir(nnUNet_raw):
+ candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False)
+ else:
+ candidates_raw = []
+
+ candidates_trained_models = []
+ if nnUNet_results is not None and isdir(nnUNet_results):
+ candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False)
+
+ all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models
+ unique_candidates = np.unique(all_candidates)
+ return unique_candidates
+
+
+def convert_id_to_dataset_name(dataset_id: int):
+ unique_candidates = find_candidate_datasets(dataset_id)
+ if len(unique_candidates) > 1:
+ raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the "
+ "following folders:\n%s\n%s\n%s" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results))
+ if len(unique_candidates) == 0:
+ raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID "
+ f"exists and that nnU-Net knows where raw and preprocessed data are located "
+ f"(see Documentation - Installation). Here are your currently defined folders:\n"
+ f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\n"
+ f"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\n"
+ f"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\n"
+ f"If something is not right, adapt your environment variables.")
+ return unique_candidates[0]
+
+
+def convert_dataset_name_to_id(dataset_name: str):
+ assert dataset_name.startswith("Dataset")
+ dataset_id = int(dataset_name[7:10])
+ return dataset_id
+
+
+def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str:
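+    # e.g. (sketch): maybe_convert_to_dataset_name(2) -> 'Dataset002_Heart', provided such a dataset exists locally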
+ if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"):
+ return dataset_name_or_id
+ if isinstance(dataset_name_or_id, str):
+ try:
+ dataset_name_or_id = int(dataset_name_or_id)
+ except ValueError:
+ raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to "
+ "convert it to a dataset ID (int). That failed, however. Please give an integer number "
+                             "('1', '2', etc) or a correct task name. Your input: %s" % dataset_name_or_id)
+ return convert_id_to_dataset_name(dataset_name_or_id)
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/utilities/ddp_allgather.py b/nnUNet/nnunetv2/utilities/ddp_allgather.py
new file mode 100644
index 0000000..c42b3ef
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/ddp_allgather.py
@@ -0,0 +1,49 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Tuple
+
+import torch
+from torch import distributed
+
+
+def print_if_rank0(*args):
+ if distributed.get_rank() == 0:
+ print(*args)
+
+
+class AllGatherGrad(torch.autograd.Function):
+ # stolen from pytorch lightning
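+    # Usage sketch: AllGatherGrad.apply(t) stacks `t` from all DDP ranks into shape (world_size, *t.shape);
+    # the backward pass all-reduces (sums) the incoming gradient and hands each rank back its own slice,
+    # so gradients keep flowing through the gather.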
+ @staticmethod
+ def forward(
+ ctx: Any,
+ tensor: torch.Tensor,
+ group: Optional["torch.distributed.ProcessGroup"] = None,
+ ) -> torch.Tensor:
+ ctx.group = group
+
+ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
+
+ torch.distributed.all_gather(gathered_tensor, tensor, group=group)
+ gathered_tensor = torch.stack(gathered_tensor, dim=0)
+
+ return gathered_tensor
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
+ grad_output = torch.cat(grad_output)
+
+ torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
+
+ return grad_output[torch.distributed.get_rank()], None
+
diff --git a/nnUNet/nnunetv2/utilities/default_n_proc_DA.py b/nnUNet/nnunetv2/utilities/default_n_proc_DA.py
new file mode 100644
index 0000000..3ecc922
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/default_n_proc_DA.py
@@ -0,0 +1,44 @@
+import subprocess
+import os
+
+
+def get_allowed_n_proc_DA():
+ """
+ This function is used to set the number of processes used on different Systems. It is specific to our cluster
+ infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed.
+
+ IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script
+ (see first line).
+
+ Interpret the output as the number of processes used for data augmentation PER GPU.
+
+ The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our
+ systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per
+ GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use.
+ """
+
+ if 'nnUNet_n_proc_DA' in os.environ.keys():
+ use_this = int(os.environ['nnUNet_n_proc_DA'])
+ else:
+ hostname = subprocess.getoutput(['hostname'])
+ if hostname in ['Fabian', ]:
+ use_this = 12
+ elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']:
+ use_this = 16
+ elif hostname.startswith('e230-dgx1'):
+ use_this = 10
+ elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'):
+ use_this = 16
+ elif hostname.startswith('e230-dgx2'):
+ use_this = 6
+ elif hostname.startswith('e230-dgxa100-'):
+ use_this = 28
+ elif hostname.startswith('lsf22-gpu'):
+ use_this = 28
+ elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'):
+ use_this = 12
+ else:
+ use_this = 12 # default value
+
+ use_this = min(use_this, os.cpu_count())
+ return use_this
diff --git a/nnUNet/nnunetv2/utilities/file_path_utilities.py b/nnUNet/nnunetv2/utilities/file_path_utilities.py
new file mode 100644
index 0000000..611f6e2
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/file_path_utilities.py
@@ -0,0 +1,123 @@
+from multiprocessing import Pool
+from typing import Union, Tuple
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.paths import nnUNet_results
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+
+def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration):
+ return f'{trainer_name}__{plans_identifier}__{configuration}'
+
+
+def convert_identifier_to_trainer_plans_config(identifier: str):
+ return os.path.basename(identifier).split('__')
+
+
+def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer',
+ plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres',
+ fold: Union[str, int] = None) -> str:
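+    # e.g. (sketch): get_output_folder('Dataset002_Heart', fold=0)
+    #   -> <nnUNet_results>/Dataset002_Heart/nnUNetTrainer__nnUNetPlans__3d_fullres/fold_0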
+ tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id),
+ convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration))
+ if fold is not None:
+ tmp = join(tmp, f'fold_{fold}')
+ return tmp
+
+
+def parse_dataset_trainer_plans_configuration_from_path(path: str):
+ folders = split_path(path)
+ # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol
+
+ # safer to make this depend on two conditions, the fold_x and the DatasetXXX
+ # first let's see if some fold_X is present
+ fold_x_present = [i.startswith('fold_') for i in folders]
+ if any(fold_x_present):
+ idx = fold_x_present.index(True)
+ # OK now two entries before that there should be DatasetXXX
+ assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
+ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
+ if folders[idx - 2].startswith('Dataset'):
+ splitted = folders[idx - 1].split('__')
+ assert len(splitted) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
+ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
+ return folders[idx - 2], *splitted
+ else:
+ # we can only check for dataset followed by a string that is separable into three strings by splitting with '__'
+ # look for DatasetXXX
+ dataset_folder = [i.startswith('Dataset') for i in folders]
+ if any(dataset_folder):
+ idx = dataset_folder.index(True)
+ assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. Your path needs to be at least ' \
+ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
+ splitted = folders[idx + 1].split('__')
+ assert len(splitted) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \
+ 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work'
+ return folders[idx], *splitted
+
+
+def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]):
+ identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \
+ os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds)
+ return identifier
+
+
+def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]):
+ model1_folder = get_output_folder(dataset, tr1, p1, c1)
+ model2_folder = get_output_folder(dataset, tr2, p2, c2)
+
+    return get_ensemble_name(model1_folder, model2_folder, folds)
+
+
+def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str):
+ prefix, *models, folds = os.path.basename(ensemble_folder).split('___')
+ return models, folds
+
+
+def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]):
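+    # e.g. folds_tuple_to_string((0, 1, 2)) == '0_1_2'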
+ s = str(folds[0])
+ for f in folds[1:]:
+ s += f"_{f}"
+ return s
+
+
+def folds_string_to_tuple(folds_string: str):
+ folds = folds_string.split('_')
+ res = []
+ for f in folds:
+ try:
+ res.append(int(f))
+ except ValueError:
+ res.append(f)
+ return res
+
+
+def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0):
+ """
+
+ returns True if the number of results that are not ready is greater than the number of available workers + allowed_num_queued
+ """
+ alive = [i.is_alive() for i in worker_list]
+ if not all(alive):
+ raise RuntimeError('Some background workers are no longer alive')
+
+ not_ready = [not i.ready() for i in results_list]
+ if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued):
+ return True
+ return False
+
+
+if __name__ == '__main__':
+ ### well at this point I could just write tests...
+ path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'
+ print(parse_dataset_trainer_plans_configuration_from_path(path))
+ path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres'
+ print(parse_dataset_trainer_plans_configuration_from_path(path))
+ path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all'
+ print(parse_dataset_trainer_plans_configuration_from_path(path))
+ try:
+ path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/'
+ print(parse_dataset_trainer_plans_configuration_from_path(path))
+ except AssertionError:
+ print('yayy, assertion works')
diff --git a/nnUNet/nnunetv2/utilities/find_class_by_name.py b/nnUNet/nnunetv2/utilities/find_class_by_name.py
new file mode 100644
index 0000000..a345d99
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/find_class_by_name.py
@@ -0,0 +1,24 @@
+import importlib
+import pkgutil
+
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def recursive_find_python_class(folder: str, class_name: str, current_module: str):
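+    # Walks every module (and subpackage) under `folder` and returns the first class matching `class_name`, else None.
+    # e.g. (as used elsewhere in this repo): recursive_find_python_class(join(nnunetv2.__path__[0], 'utilities',
+    #      'label_handling'), 'LabelManager', current_module='nnunetv2.utilities.label_handling')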
+ tr = None
+ for importer, modname, ispkg in pkgutil.iter_modules([folder]):
+ # print(modname, ispkg)
+ if not ispkg:
+ m = importlib.import_module(current_module + "." + modname)
+ if hasattr(m, class_name):
+ tr = getattr(m, class_name)
+ break
+
+ if tr is None:
+ for importer, modname, ispkg in pkgutil.iter_modules([folder]):
+ if ispkg:
+ next_current_module = current_module + "." + modname
+ tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module)
+ if tr is not None:
+ break
+ return tr
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/utilities/get_network_from_plans.py b/nnUNet/nnunetv2/utilities/get_network_from_plans.py
new file mode 100644
index 0000000..447d1d5
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/get_network_from_plans.py
@@ -0,0 +1,77 @@
+from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet
+from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
+from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
+from nnunetv2.utilities.network_initialization import InitWeights_He
+from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager
+from torch import nn
+
+
+def get_network_from_plans(plans_manager: PlansManager,
+ dataset_json: dict,
+ configuration_manager: ConfigurationManager,
+ num_input_channels: int,
+ deep_supervision: bool = True):
+ """
+ we may have to change this in the future to accommodate other plans -> network mappings
+
+    num_input_channels can differ depending on whether we do cascade. It's best to make this info available in the
+ trainer rather than inferring it again from the plans here.
+ """
+ num_stages = len(configuration_manager.conv_kernel_sizes)
+
+ dim = len(configuration_manager.conv_kernel_sizes[0])
+ conv_op = convert_dim_to_conv_op(dim)
+
+ label_manager = plans_manager.get_label_manager(dataset_json)
+
+ segmentation_network_class_name = configuration_manager.UNet_class_name
+ mapping = {
+ 'PlainConvUNet': PlainConvUNet,
+ 'ResidualEncoderUNet': ResidualEncoderUNet
+ }
+ kwargs = {
+ 'PlainConvUNet': {
+ 'conv_bias': True,
+ 'norm_op': get_matching_instancenorm(conv_op),
+ 'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+ 'dropout_op': None, 'dropout_op_kwargs': None,
+ 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+ },
+ 'ResidualEncoderUNet': {
+ 'conv_bias': True,
+ 'norm_op': get_matching_instancenorm(conv_op),
+ 'norm_op_kwargs': {'eps': 1e-5, 'affine': True},
+ 'dropout_op': None, 'dropout_op_kwargs': None,
+ 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True},
+ }
+ }
+ assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \
+                                                               'is non-standard (maybe your own?). You\'ll have to dive ' \
+ 'into either this ' \
+ 'function (get_network_from_plans) or ' \
+                                                               'the init of your nnUNetModule to accommodate that.'
+ network_class = mapping[segmentation_network_class_name]
+
+ conv_or_blocks_per_stage = {
+ 'n_conv_per_stage'
+ if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder,
+ 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder
+ }
+ # network class name!!
+ model = network_class(
+ input_channels=num_input_channels,
+ n_stages=num_stages,
+ features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i,
+ configuration_manager.unet_max_num_features) for i in range(num_stages)],
+ conv_op=conv_op,
+ kernel_sizes=configuration_manager.conv_kernel_sizes,
+ strides=configuration_manager.pool_op_kernel_sizes,
+ num_classes=label_manager.num_segmentation_heads,
+ deep_supervision=deep_supervision,
+ **conv_or_blocks_per_stage,
+ **kwargs[segmentation_network_class_name]
+ )
+ model.apply(InitWeights_He(1e-2))
+ if network_class == ResidualEncoderUNet:
+ model.apply(init_last_bn_before_add_to_0)
+ return model
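+
+
+# Hedged usage sketch (file names assumed, not shipped with this repo):
+# plans_manager = PlansManager(load_json('nnUNetPlans.json'))
+# net = get_network_from_plans(plans_manager, load_json('dataset.json'),
+#                              plans_manager.get_configuration('3d_fullres'),
+#                              num_input_channels=2, deep_supervision=True)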
diff --git a/nnUNet/nnunetv2/utilities/helpers.py b/nnUNet/nnunetv2/utilities/helpers.py
new file mode 100644
index 0000000..42448e3
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/helpers.py
@@ -0,0 +1,27 @@
+import torch
+
+
+def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor:
+ return torch.softmax(x, 0)
+
+
+def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor:
+ return torch.softmax(x, 1)
+
+
+def empty_cache(device: torch.device):
+ if device.type == 'cuda':
+ torch.cuda.empty_cache()
+ elif device.type == 'mps':
+ from torch import mps
+ mps.empty_cache()
+ else:
+ pass
+
+
+class dummy_context(object):
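+    # No-op context manager. Presumably a drop-in stand-in wherever a `with` block is syntactically required,
+    # e.g. in place of torch.autocast when mixed precision is disabled (assumption, not documented here).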
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
diff --git a/nnUNet/nnunetv2/utilities/json_export.py b/nnUNet/nnunetv2/utilities/json_export.py
new file mode 100644
index 0000000..faed954
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/json_export.py
@@ -0,0 +1,59 @@
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+
+
+def recursive_fix_for_json_export(my_dict: dict):
+ # json is stupid. 'cannot serialize object of type bool_/int64/float64'. Come on bro.
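+    # e.g. (in place): {np.int64(1): np.float32(0.5), 'flag': np.bool_(True)} becomes {1: 0.5, 'flag': True}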
+ keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys....
+ for k in keys:
+ if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)):
+ tmp = my_dict[k]
+ del my_dict[k]
+ my_dict[int(k)] = tmp
+ del tmp
+ k = int(k)
+
+ if isinstance(my_dict[k], dict):
+ recursive_fix_for_json_export(my_dict[k])
+ elif isinstance(my_dict[k], np.ndarray):
+ assert len(my_dict[k].shape) == 1, 'only 1d arrays are supported'
+ my_dict[k] = fix_types_iterable(my_dict[k], output_type=list)
+ elif isinstance(my_dict[k], (np.bool_,)):
+ my_dict[k] = bool(my_dict[k])
+ elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)):
+ my_dict[k] = int(my_dict[k])
+ elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)):
+ my_dict[k] = float(my_dict[k])
+ elif isinstance(my_dict[k], list):
+ my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k]))
+ elif isinstance(my_dict[k], tuple):
+ my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple)
+ elif isinstance(my_dict[k], torch.device):
+ my_dict[k] = str(my_dict[k])
+ else:
+ pass # pray it can be serialized
+
+
+def fix_types_iterable(iterable, output_type):
+    # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep your hands off of this.
+ out = []
+ for i in iterable:
+ if type(i) in (np.int64, np.int32, np.int8, np.uint8):
+ out.append(int(i))
+ elif isinstance(i, dict):
+ recursive_fix_for_json_export(i)
+ out.append(i)
+ elif type(i) in (np.float32, np.float64, np.float16):
+ out.append(float(i))
+ elif type(i) in (np.bool_,):
+ out.append(bool(i))
+ elif isinstance(i, str):
+ out.append(i)
+ elif isinstance(i, Iterable):
+ # print('recursive call on', i, type(i))
+ out.append(fix_types_iterable(i, type(i)))
+ else:
+ out.append(i)
+ return output_type(out)
diff --git a/nnUNet/nnunetv2/utilities/label_handling/__init__.py b/nnUNet/nnunetv2/utilities/label_handling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/utilities/label_handling/label_handling.py b/nnUNet/nnunetv2/utilities/label_handling/label_handling.py
new file mode 100644
index 0000000..333296d
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/label_handling/label_handling.py
@@ -0,0 +1,322 @@
+from __future__ import annotations
+from time import time
+from typing import Union, List, Tuple, Type
+
+import numpy as np
+import torch
+from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
+from batchgenerators.utilities.file_and_folder_operations import join
+
+import nnunetv2
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.helpers import softmax_helper_dim0
+
+from typing import TYPE_CHECKING
+
+# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
+if TYPE_CHECKING:
+ from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
+
+
+class LabelManager(object):
+ def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False,
+ inference_nonlin=None):
+ self._sanity_check(label_dict)
+ self.label_dict = label_dict
+ self.regions_class_order = regions_class_order
+ self._force_use_labels = force_use_labels
+
+ if force_use_labels:
+ self._has_regions = False
+ else:
+ self._has_regions: bool = any(
+ [isinstance(i, (tuple, list)) and len(i) > 1 for i in self.label_dict.values()])
+
+ self._ignore_label: Union[None, int] = self._determine_ignore_label()
+ self._all_labels: List[int] = self._get_all_labels()
+
+ self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions()
+
+ if self.has_ignore_label:
+ assert self.ignore_label == max(
+ self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \
+ 'label value! It cannot be 0 or in between other labels. ' \
+ 'Sorry bro.'
+
+ if inference_nonlin is None:
+ self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0
+ else:
+ self.inference_nonlin = inference_nonlin
+
+ def _sanity_check(self, label_dict: dict):
+ if not 'background' in label_dict.keys():
+            raise RuntimeError('Background label not declared (remember that this should be label 0!)')
+ bg_label = label_dict['background']
+ if isinstance(bg_label, (tuple, list)):
+ raise RuntimeError(f"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}")
+ assert int(bg_label) == 0, f"Background label must be 0. Your background label: {bg_label}"
+ # not sure if we want to allow regions that contain background. I don't immediately see how this could cause
+ # problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this
+ # just crashes.
+
+ def _get_all_labels(self) -> List[int]:
+ all_labels = []
+ for k, r in self.label_dict.items():
+ # ignore label is not going to be used, hence the name. Duh.
+ if k == 'ignore':
+ continue
+ if isinstance(r, (tuple, list)):
+ for ri in r:
+ all_labels.append(int(ri))
+ else:
+ all_labels.append(int(r))
+ all_labels = list(np.unique(all_labels))
+ all_labels.sort()
+ return all_labels
+
+ def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:
+ if not self._has_regions or self._force_use_labels:
+ return None
+ else:
+ assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \
+ 'define regions_class_order!'
+ regions = []
+ for k, r in self.label_dict.items():
+ # ignore ignore label
+ if k == 'ignore':
+ continue
+ # ignore regions that are background
+ if (np.isscalar(r) and r == 0) \
+ or \
+ (isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0):
+ continue
+ if isinstance(r, list):
+ r = tuple(r)
+ regions.append(r)
+ assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \
+ 'many entries as there are ' \
+ 'regions'
+ return regions
+
+ def _determine_ignore_label(self) -> Union[None, int]:
+ ignore_label = self.label_dict.get('ignore')
+ if ignore_label is not None:
+ assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \
+ f'(list/tuple). Got {type(ignore_label)}.'
+ return ignore_label
+
+ @property
+ def has_regions(self) -> bool:
+ return self._has_regions
+
+ @property
+ def has_ignore_label(self) -> bool:
+ return self.ignore_label is not None
+
+ @property
+ def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]:
+ return self._regions
+
+ @property
+ def all_labels(self) -> List[int]:
+ return self._all_labels
+
+ @property
+ def ignore_label(self) -> Union[None, int]:
+ return self._ignore_label
+
+ def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \
+ Union[np.ndarray, torch.Tensor]:
+ """
+ logits has to have shape (c, x, y(, z)) where c is the number of classes/regions
+ """
+ if isinstance(logits, np.ndarray):
+ logits = torch.from_numpy(logits)
+
+ with torch.no_grad():
+ # softmax etc is not implemented for half
+ logits = logits.float()
+ probabilities = self.inference_nonlin(logits)
+
+ return probabilities
+
+ def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \
+ Union[np.ndarray, torch.Tensor]:
+ """
+ assumes that inference_nonlinearity was already applied!
+
+ predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions
+ """
+ if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)):
+ raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor,"
+ f" got {type(predicted_probabilities)}")
+
+ if self.has_regions:
+ assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \
+ 'define regions_class_order!'
+ # check correct number of outputs
+ assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \
+ f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \
+                f'got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape ' \
+ f'(c, x, y(, z)).'
+
+ if self.has_regions:
+ if isinstance(predicted_probabilities, np.ndarray):
+ segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16)
+ else:
+ # no uint16 in torch
+ segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16,
+ device=predicted_probabilities.device)
+ for i, c in enumerate(self.regions_class_order):
+ segmentation[predicted_probabilities[i] > 0.5] = c
+ else:
+ segmentation = predicted_probabilities.argmax(0)
+
+ return segmentation
+
+ def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \
+ Union[np.ndarray, torch.Tensor]:
+ input_is_numpy = isinstance(predicted_logits, np.ndarray)
+ probabilities = self.apply_inference_nonlin(predicted_logits)
+ if input_is_numpy and isinstance(probabilities, torch.Tensor):
+ probabilities = probabilities.cpu().numpy()
+ return self.convert_probabilities_to_segmentation(probabilities)
+
+ def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray],
+ bbox: List[List[int]],
+ original_shape: Union[List[int], Tuple[int, ...]]):
+ """
+ ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!!
+
+ predicted_probabilities must be (c, x, y(, z))
+
+ Why do we do this here? Well if we pad probabilities we need to make sure that convert_logits_to_segmentation
+        correctly returns background in the padded areas. Also we want to be able to look at the padded probabilities
+ and not have strange artifacts.
+ Only LabelManager knows how this needs to be done. So let's let him/her do it, ok?
+ """
+ # revert cropping
+ probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape),
+ dtype=predicted_probabilities.dtype) \
+ if isinstance(predicted_probabilities, np.ndarray) else \
+ torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype)
+
+ if not self.has_regions:
+ probs_reverted_cropping[0] = 1
+
+ slicer = bounding_box_to_slice(bbox)
+ probs_reverted_cropping[tuple([slice(None)] + list(slicer))] = predicted_probabilities
+ return probs_reverted_cropping
+
+ @staticmethod
+ def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]):
+ # heck yeah
+ # This is definitely taking list comprehension too far. Enjoy.
+ return [i for i in classes_or_regions if
+ ((not isinstance(i, (tuple, list))) and i != 0)
+ or
+ (isinstance(i, (tuple, list)) and not (
+ len(np.unique(i)) == 1 and np.unique(i)[0] == 0))]
+
+ @property
+ def foreground_regions(self):
+ return self.filter_background(self.all_regions)
+
+ @property
+ def foreground_labels(self):
+ return self.filter_background(self.all_labels)
+
+ @property
+ def num_segmentation_heads(self):
+ if self.has_regions:
+ return len(self.foreground_regions)
+ else:
+ return len(self.all_labels)
+
+
+def get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]:
+ if 'label_manager' not in plans.keys():
+ print('No label manager specified in plans. Using default: LabelManager')
+ return LabelManager
+ else:
+ labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], "utilities", "label_handling"),
+ plans['label_manager'],
+ current_module="nnunetv2.utilities.label_handling")
+ return labelmanager_class
+
+
+def convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor],
+ all_labels: Union[List, torch.Tensor, np.ndarray, tuple],
+ output_dtype=None) -> Union[np.ndarray, torch.Tensor]:
+ """
+ if output_dtype is None then we use np.uint8/torch.uint8
+ if input is torch.Tensor then output will be on the same device
+
+ np.ndarray is faster than torch.Tensor
+
+    if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is something else we have
+ to cast which takes time.
+
+ IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ...
+ DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo)
+ """
+ if isinstance(segmentation, torch.Tensor):
+ result = torch.zeros((len(all_labels), *segmentation.shape),
+ dtype=output_dtype if output_dtype is not None else torch.uint8,
+ device=segmentation.device)
+ # variant 1, 2x faster than 2
+ result.scatter_(0, segmentation[None].long(), 1) # why does this have to be long!?
+ # variant 2, slower than 1
+ # for i, l in enumerate(all_labels):
+ # result[i] = segmentation == l
+ else:
+ result = np.zeros((len(all_labels), *segmentation.shape),
+ dtype=output_dtype if output_dtype is not None else np.uint8)
+ # variant 1, fastest in my testing
+ for i, l in enumerate(all_labels):
+ result[i] = segmentation == l
+ # variant 2. Takes about twice as long so nah
+ # result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2))
+ return result
+
+
+def determine_num_input_channels(plans_manager: PlansManager,
+ configuration_or_config_manager: Union[str, ConfigurationManager],
+ dataset_json: dict) -> int:
+ if isinstance(configuration_or_config_manager, str):
+ config_manager = plans_manager.get_configuration(configuration_or_config_manager)
+ else:
+ config_manager = configuration_or_config_manager
+
+ label_manager = plans_manager.get_label_manager(dataset_json)
+ num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names'])
+
+ # cascade has different number of input channels
+ if config_manager.previous_stage_name is not None:
+ num_label_inputs = len(label_manager.foreground_labels)
+ num_input_channels = num_modalities + num_label_inputs
+ else:
+ num_input_channels = num_modalities
+ return num_input_channels
+
+
+if __name__ == '__main__':
+ # this code used to be able to differentiate variant 1 and 2 to measure time.
+ num_labels = 7
+ seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8)
+ seg_torch = torch.from_numpy(seg)
+ st = time()
+ onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
+ time_1 = time()
+ onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels))
+ time_2 = time()
+ onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
+ time_torch = time()
+ onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels))
+ time_torch2 = time()
+ print(
+ f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}')
+ onehot_torch = onehot_torch.numpy()
+ onehot_torch2 = onehot_torch2.numpy()
+ print(np.all(onehot_torch == onehot_npy))
+ print(np.all(onehot_torch2 == onehot_npy))
diff --git a/nnUNet/nnunetv2/utilities/network_initialization.py b/nnUNet/nnunetv2/utilities/network_initialization.py
new file mode 100644
index 0000000..1ead271
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/network_initialization.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+
+class InitWeights_He(object):
+ def __init__(self, neg_slope=1e-2):
+ self.neg_slope = neg_slope
+
+ def __call__(self, module):
+ if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
+ module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
+ if module.bias is not None:
+ module.bias = nn.init.constant_(module.bias, 0)
diff --git a/nnUNet/nnunetv2/utilities/overlay_plots.py b/nnUNet/nnunetv2/utilities/overlay_plots.py
new file mode 100644
index 0000000..8b0b9d1
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/overlay_plots.py
@@ -0,0 +1,275 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import multiprocessing
+from multiprocessing.pool import Pool
+from typing import Tuple, Union
+
+import numpy as np
+import pandas as pd
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunetv2.configuration import default_num_processes
+from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json
+from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
+ get_filenames_of_train_images_and_targets
+
+color_cycle = (
+ "000000",
+ "4363d8",
+ "f58231",
+ "3cb44b",
+ "e6194B",
+ "911eb4",
+ "ffe119",
+ "bfef45",
+ "42d4f4",
+ "f032e6",
+ "000075",
+ "9A6324",
+ "808000",
+ "800000",
+ "469990",
+)
+
+
+def hex_to_rgb(hex: str):
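+    # e.g. hex_to_rgb("4363d8") == (67, 99, 216)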
+ assert len(hex) == 6
+ return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4))
+
+
+def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None,
+ color_cycle: Tuple[str, ...] = color_cycle,
+ overlay_intensity: float = 0.6):
+ """
+ image can be 2d greyscale or 2d RGB (color channel in last dimension!)
+
+ Segmentation must be label map of same shape as image (w/o color channels)
+
+ mapping can be label_id -> idx_in_cycle or None
+
+ returned image is scaled to [0, 255] (uint8)!!!
+ """
+ # create a copy of image
+ image = np.copy(input_image)
+
+ if len(image.shape) == 2:
+ image = np.tile(image[:, :, None], (1, 1, 3))
+ elif len(image.shape) == 3:
+        if image.shape[2] == 1:
+            image = np.tile(image, (1, 1, 3))
+        elif image.shape[2] != 3:
+            # RGB input (3 channels in the last dimension) is allowed per the docstring; anything else is rejected
+            raise RuntimeError(f'if a 3d array is given the last dimension must be the color channels (1 or 3). '
+                               f'Only 2D images are supported. Your image shape: {image.shape}')
+ else:
+ raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
+ "last dimension) are supported")
+
+ # rescale image to [0, 255]
+ image = image - image.min()
+ image = image / image.max() * 255
+
+ # create output
+ if mapping is None:
+ uniques = np.sort(pd.unique(segmentation.ravel())) # np.unique(segmentation)
+ mapping = {i: c for c, i in enumerate(uniques)}
+
+ for l in mapping.keys():
+ image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]]))
+
+ # rescale result to [0, 255]
+ image = image / image.max() * 255
+ return image.astype(np.uint8)
+
+
+def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int:
+ """
+ image and segmentation are expected to be 3D
+
+ selects the slice with the largest amount of fg (regardless of label)
+
+ we give image so that we can easily replace this function if needed
+ """
+ fg_mask = segmentation != 0
+ fg_per_slice = fg_mask.sum((1, 2))
+ selected_slice = int(np.argmax(fg_per_slice))
+ return selected_slice
+
+
+def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int:
+ """
+ image and segmentation are expected to be 3D (or 1, x, y)
+
+ selects the slice with the largest amount of fg (how much percent of each class are in each slice? pick slice
+ with highest avg percent)
+
+ we give image so that we can easily replace this function if needed
+ """
+ classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i != 0]
+ fg_per_slice = np.zeros((image.shape[0], len(classes)))
+ for i, c in enumerate(classes):
+ fg_mask = segmentation == c
+ fg_per_slice[:, i] = fg_mask.sum((1, 2))
+ fg_per_slice[:, i] /= fg_per_slice.sum()
+ fg_per_slice = fg_per_slice.mean(1)
+ return int(np.argmax(fg_per_slice))
+
+
+def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str,
+ overlay_intensity: float = 0.6):
+ import matplotlib.pyplot as plt
+
+ image, props = image_reader_writer.read_images((image_file, ))
+ image = image[0]
+ seg, props_seg = image_reader_writer.read_seg(segmentation_file)
+ seg = seg[0]
+
+ assert all([i == j for i, j in zip(image.shape, seg.shape)]), "image and seg do not have the same shape: %s, %s" % (
+ image_file, segmentation_file)
+
+ assert len(image.shape) == 3, 'only 3D images/segs are supported'
+
+ selected_slice = select_slice_to_plot2(image, seg)
+ # print(image.shape, selected_slice)
+
+ overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
+
+ plt.imsave(output_file, overlay)
+
+
+def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, channel_idx=0):
+ import matplotlib.pyplot as plt
+ data = np.load(case_file)['data']
+ seg = np.load(case_file)['seg'][0]
+
+ assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1)
+
+ image = data[channel_idx]
+ seg[seg < 0] = 0
+
+ selected_slice = select_slice_to_plot2(image, seg)
+
+ overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
+
+ plt.imsave(output_file, overlay)
+
+
+def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer,
+ list_of_output_files, overlay_intensity,
+ num_processes=8):
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ r = p.starmap_async(plot_overlay, zip(
+ list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files),
+ list_of_output_files, [overlay_intensity] * len(list_of_output_files)
+ ))
+ r.get()
+
+
+def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity,
+ num_processes=8, channel_idx=0):
+ with multiprocessing.get_context("spawn").Pool(num_processes) as p:
+ r = p.starmap_async(plot_overlay_preprocessed, zip(
+ list_of_case_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files),
+ [channel_idx] * len(list_of_output_files)
+ ))
+ r.get()
+
+
+def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str,
+ num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6):
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+ folder = join(nnUNet_raw, dataset_name)
+ dataset_json = load_json(join(folder, 'dataset.json'))
+ dataset = get_filenames_of_train_images_and_targets(folder, dataset_json)
+
+ image_files = [v['images'][channel_idx] for v in dataset.values()]
+ seg_files = [v['label'] for v in dataset.values()]
+
+ assert all([isfile(i) for i in image_files])
+ assert all([isfile(i) for i in seg_files])
+
+ maybe_mkdir_p(output_folder)
+ output_files = [join(output_folder, i + '.png') for i in dataset.keys()]
+
+ image_reader_writer = determine_reader_writer_from_dataset_json(dataset_json, image_files[0])()
+ multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity, num_processes)
+
+
+def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str,
+ num_processes: int = 8, channel_idx: int = 0,
+ configuration: str = None,
+ plans_identifier: str = 'nnUNetPlans',
+ overlay_intensity: float = 0.6):
+ dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id)
+ folder = join(nnUNet_preprocessed, dataset_name)
+    if not isdir(folder):
+        raise RuntimeError("run preprocessing for that task first")
+
+ plans = load_json(join(folder, plans_identifier + '.json'))
+ if configuration is None:
+ if '3d_fullres' in plans['configurations'].keys():
+ configuration = '3d_fullres'
+ else:
+ configuration = '2d'
+ data_identifier = plans['configurations'][configuration]["data_identifier"]
+ preprocessed_folder = join(folder, data_identifier)
+
+ if not isdir(preprocessed_folder):
+ raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier "
+ f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this "
+ f"configuration first!")
+
+ identifiers = [i[:-4] for i in subfiles(preprocessed_folder, suffix='.npz', join=False)]
+
+ output_files = [join(output_folder, i + '.png') for i in identifiers]
+ image_files = [join(preprocessed_folder, i + ".npz") for i in identifiers]
+
+ maybe_mkdir_p(output_folder)
+ multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=overlay_intensity,
+ num_processes=num_processes, channel_idx=channel_idx)
+
+
+def entry_point_generate_overlay():
+ import argparse
+ parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this "
+ "disregards spacing information!")
+ parser.add_argument('-d', type=str, help="Dataset name or id", required=True)
+ parser.add_argument('-o', type=str, help="output folder", required=True)
+ parser.add_argument('-np', type=int, default=default_num_processes, required=False,
+ help=f"number of processes used. Default: {default_num_processes}")
+ parser.add_argument('-channel_idx', type=int, default=0, required=False,
+ help="channel index used (0 = _0000). Default: 0")
+ parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else "
+ "we use preprocessed")
+ parser.add_argument('-p', type=str, required=False, default='nnUNetPlans',
+ help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans')
+ parser.add_argument('-c', type=str, required=False, default=None,
+ help='configuration name. Only used if --use_raw is not set! Default: None = '
+ '3d_fullres if available, else 2d')
+ parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6,
+ help='overlay intensity. Higher = brighter/less transparent')
+
+ args = parser.parse_args()
+
+ if args.use_raw:
+ generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx,
+ overlay_intensity=args.overlay_intensity)
+ else:
+ generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p,
+ overlay_intensity=args.overlay_intensity)
+
+
+if __name__ == '__main__':
+ entry_point_generate_overlay()
\ No newline at end of file
diff --git a/nnUNet/nnunetv2/utilities/plans_handling/__init__.py b/nnUNet/nnunetv2/utilities/plans_handling/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py b/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py
new file mode 100644
index 0000000..6c39fd1
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+import dynamic_network_architectures
+from copy import deepcopy
+from functools import lru_cache, partial
+from typing import Union, Tuple, List, Type, Callable
+
+import numpy as np
+import torch
+
+from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name
+from torch import nn
+
+import nnunetv2
+from batchgenerators.utilities.file_and_folder_operations import load_json, join
+
+from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name
+from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
+from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans
+
+
+# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from nnunetv2.utilities.label_handling.label_handling import LabelManager
+ from nnunetv2.imageio.base_reader_writer import BaseReaderWriter
+ from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor
+ from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner
+
+
+class ConfigurationManager(object):
+ def __init__(self, configuration_dict: dict):
+ self.configuration = configuration_dict
+
+ def __repr__(self):
+ return self.configuration.__repr__()
+
+ @property
+ def data_identifier(self) -> str:
+ return self.configuration['data_identifier']
+
+ @property
+ def preprocessor_name(self) -> str:
+ return self.configuration['preprocessor_name']
+
+ @property
+ @lru_cache(maxsize=1)
+ def preprocessor_class(self) -> Type[DefaultPreprocessor]:
+ preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"),
+ self.preprocessor_name,
+ current_module="nnunetv2.preprocessing")
+ return preprocessor_class
+
+ @property
+ def batch_size(self) -> int:
+ return self.configuration['batch_size']
+
+ @property
+ def patch_size(self) -> List[int]:
+ return self.configuration['patch_size']
+
+ @property
+ def median_image_size_in_voxels(self) -> List[int]:
+ return self.configuration['median_image_size_in_voxels']
+
+ @property
+ def spacing(self) -> List[float]:
+ return self.configuration['spacing']
+
+ @property
+ def normalization_schemes(self) -> List[str]:
+ return self.configuration['normalization_schemes']
+
+ @property
+ def use_mask_for_norm(self) -> List[bool]:
+ return self.configuration['use_mask_for_norm']
+
+ @property
+ def UNet_class_name(self) -> str:
+ return self.configuration['UNet_class_name']
+
+ @property
+ @lru_cache(maxsize=1)
+ def UNet_class(self) -> Type[nn.Module]:
+ unet_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"),
+ self.UNet_class_name,
+ current_module="dynamic_network_architectures.architectures")
+ if unet_class is None:
+ raise RuntimeError('The network architecture specified by the plans file '
+ 'is non-standard (maybe your own?). Fix this by not using '
+ 'ConfigurationManager.UNet_class to instantiate '
+                               'it (probably just overwrite build_network_architecture of your trainer).')
+ return unet_class
+
+ @property
+ def UNet_base_num_features(self) -> int:
+ return self.configuration['UNet_base_num_features']
+
+ @property
+ def n_conv_per_stage_encoder(self) -> List[int]:
+ return self.configuration['n_conv_per_stage_encoder']
+
+ @property
+ def n_conv_per_stage_decoder(self) -> List[int]:
+ return self.configuration['n_conv_per_stage_decoder']
+
+ @property
+ def num_pool_per_axis(self) -> List[int]:
+ return self.configuration['num_pool_per_axis']
+
+ @property
+ def pool_op_kernel_sizes(self) -> List[List[int]]:
+ return self.configuration['pool_op_kernel_sizes']
+
+ @property
+ def conv_kernel_sizes(self) -> List[List[int]]:
+ return self.configuration['conv_kernel_sizes']
+
+ @property
+ def unet_max_num_features(self) -> int:
+ return self.configuration['unet_max_num_features']
+
+ @property
+ @lru_cache(maxsize=1)
+ def resampling_fn_data(self) -> Callable[
+ [Union[torch.Tensor, np.ndarray],
+ Union[Tuple[int, ...], List[int], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray]
+ ],
+ Union[torch.Tensor, np.ndarray]]:
+ fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data'])
+ fn = partial(fn, **self.configuration['resampling_fn_data_kwargs'])
+ return fn
+
+ @property
+ @lru_cache(maxsize=1)
+ def resampling_fn_probabilities(self) -> Callable[
+ [Union[torch.Tensor, np.ndarray],
+ Union[Tuple[int, ...], List[int], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray]
+ ],
+ Union[torch.Tensor, np.ndarray]]:
+ fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities'])
+ fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs'])
+ return fn
+
+ @property
+ @lru_cache(maxsize=1)
+ def resampling_fn_seg(self) -> Callable[
+ [Union[torch.Tensor, np.ndarray],
+ Union[Tuple[int, ...], List[int], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray],
+ Union[Tuple[float, ...], List[float], np.ndarray]
+ ],
+ Union[torch.Tensor, np.ndarray]]:
+ fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg'])
+ fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs'])
+ return fn
+
+ @property
+ def batch_dice(self) -> bool:
+ return self.configuration['batch_dice']
+
+ @property
+ def next_stage_names(self) -> Union[List[str], None]:
+ ret = self.configuration.get('next_stage')
+ if ret is not None:
+ if isinstance(ret, str):
+ ret = [ret]
+ return ret
+
+ @property
+ def previous_stage_name(self) -> Union[str, None]:
+ return self.configuration.get('previous_stage')
+
+
+class PlansManager(object):
+ def __init__(self, plans_file_or_dict: Union[str, dict]):
+ """
+ Why do we need this?
+ 1) resolve inheritance in configurations
+ 2) expose otherwise annoying stuff like getting the label manager or IO class from a string
+ 3) clearly expose the things that are in the plans instead of hiding them in a dict
+ 4) cache shit
+
+ This class does not prevent you from going wild. You can still use the plans directly if you prefer
+        (PlansManager.plans['key'])
+ """
+ self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict)
+
+ def __repr__(self):
+ return self.plans.__repr__()
+
+ def _internal_resolve_configuration_inheritance(self, configuration_name: str,
+ visited: Tuple[str, ...] = None) -> dict:
+ if configuration_name not in self.plans['configurations'].keys():
+ raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid '
+ f'configuration names are {list(self.plans["configurations"].keys())}.')
+ configuration = deepcopy(self.plans['configurations'][configuration_name])
+ if 'inherits_from' in configuration:
+ parent_config_name = configuration['inherits_from']
+
+ if visited is None:
+ visited = (configuration_name,)
+ else:
+ if parent_config_name in visited:
+ raise RuntimeError(f"Circular dependency detected. The following configurations were visited "
+ f"while solving inheritance (in that order!): {visited}. "
+ f"Current configuration: {configuration_name}. Its parent configuration "
+ f"is {parent_config_name}.")
+ visited = (*visited, configuration_name)
+
+ base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited)
+ base_config.update(configuration)
+ configuration = base_config
+ return configuration
+
+ @lru_cache(maxsize=10)
+ def get_configuration(self, configuration_name: str):
+ if configuration_name not in self.plans['configurations'].keys():
+ raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. "
+ f"Available configurations: {list(self.plans['configurations'].keys())}")
+
+ configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name)
+ return ConfigurationManager(configuration_dict)
+
+ @property
+ def dataset_name(self) -> str:
+ return self.plans['dataset_name']
+
+ @property
+ def plans_name(self) -> str:
+ return self.plans['plans_name']
+
+ @property
+ def original_median_spacing_after_transp(self) -> List[float]:
+ return self.plans['original_median_spacing_after_transp']
+
+ @property
+ def original_median_shape_after_transp(self) -> List[float]:
+ return self.plans['original_median_shape_after_transp']
+
+ @property
+ @lru_cache(maxsize=1)
+ def image_reader_writer_class(self) -> Type[BaseReaderWriter]:
+ return recursive_find_reader_writer_by_name(self.plans['image_reader_writer'])
+
+ @property
+ def transpose_forward(self) -> List[int]:
+ return self.plans['transpose_forward']
+
+ @property
+ def transpose_backward(self) -> List[int]:
+ return self.plans['transpose_backward']
+
+ @property
+ def available_configurations(self) -> List[str]:
+ return list(self.plans['configurations'].keys())
+
+ @property
+ @lru_cache(maxsize=1)
+ def experiment_planner_class(self) -> Type[ExperimentPlanner]:
+ planner_name = self.experiment_planner_name
+ experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"),
+ planner_name,
+ current_module="nnunetv2.experiment_planning")
+ return experiment_planner
+
+ @property
+ def experiment_planner_name(self) -> str:
+ return self.plans['experiment_planner_used']
+
+ @property
+ @lru_cache(maxsize=1)
+ def label_manager_class(self) -> Type[LabelManager]:
+ return get_labelmanager_class_from_plans(self.plans)
+
+ def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager:
+ return self.label_manager_class(label_dict=dataset_json['labels'],
+ regions_class_order=dataset_json.get('regions_class_order'),
+ **kwargs)
+
+ @property
+ def foreground_intensity_properties_per_channel(self) -> dict:
+ if 'foreground_intensity_properties_per_channel' not in self.plans.keys():
+ if 'foreground_intensity_properties_by_modality' in self.plans.keys():
+ return self.plans['foreground_intensity_properties_by_modality']
+ return self.plans['foreground_intensity_properties_per_channel']
+
+
+if __name__ == '__main__':
+ from nnunetv2.paths import nnUNet_preprocessed
+ from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+
+ plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json'))
+ # build new configuration that inherits from 3d_fullres
+ plans['configurations']['3d_fullres_bs4'] = {
+ 'batch_size': 4,
+ 'inherits_from': '3d_fullres'
+ }
+ # now get plans and configuration managers
+ plans_manager = PlansManager(plans)
+ configuration_manager = plans_manager.get_configuration('3d_fullres_bs4')
+ print(configuration_manager) # look for batch size 4
diff --git a/nnUNet/nnunetv2/utilities/tensor_utilities.py b/nnUNet/nnunetv2/utilities/tensor_utilities.py
new file mode 100644
index 0000000..b16ffca
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/tensor_utilities.py
@@ -0,0 +1,15 @@
+from typing import Union, List, Tuple
+
+import numpy as np
+import torch
+
+
+def sum_tensor(inp: torch.Tensor, axes: Union[np.ndarray, Tuple, List], keepdim: bool = False) -> torch.Tensor:
+ axes = np.unique(axes).astype(int)
+ if keepdim:
+ for ax in axes:
+ inp = inp.sum(int(ax), keepdim=True)
+ else:
+ for ax in sorted(axes, reverse=True):
+ inp = inp.sum(int(ax))
+ return inp
diff --git a/nnUNet/nnunetv2/utilities/utils.py b/nnUNet/nnunetv2/utilities/utils.py
new file mode 100644
index 0000000..b0c16a2
--- /dev/null
+++ b/nnUNet/nnunetv2/utilities/utils.py
@@ -0,0 +1,69 @@
+# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center
+# (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os.path
+from functools import lru_cache
+from typing import Union
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+import re
+
+from nnunetv2.paths import nnUNet_raw
+
+
+def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str):
+ files = subfiles(folder, suffix=file_ending, join=False)
+ # all files have a 4 digit channel index (_XXXX)
+ crop = len(file_ending) + 5
+ files = [i[:-crop] for i in files]
+ # only unique image ids
+ files = np.unique(files)
+ return files
+
+
+def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None) -> List[
+ List[str]]:
+ """
+ does not rely on dataset.json
+ """
+ if identifiers is None:
+ identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending)
+ files = subfiles(folder, suffix=file_ending, join=False, sort=True)
+ list_of_lists = []
+ for f in identifiers:
+ p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending))
+ list_of_lists.append([join(folder, i) for i in files if p.fullmatch(i)])
+ return list_of_lists
+
+
+def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None):
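+    """
+    Returns a dict mapping case identifier -> {'images': [image files], 'label': label file}.
+    If dataset.json provides an explicit 'dataset' mapping, relative paths in it are resolved to absolute paths;
+    otherwise the lists are inferred from the imagesTr/labelsTr folders via the dataset's file_ending.
+    """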
+ if dataset_json is None:
+ dataset_json = load_json(join(raw_dataset_folder, 'dataset.json'))
+
+ if 'dataset' in dataset_json.keys():
+ dataset = dataset_json['dataset']
+ for k in dataset.keys():
+ dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, dataset[k]['label'])) if not os.path.isabs(dataset[k]['label']) else dataset[k]['label']
+ dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, i)) if not os.path.isabs(i) else i for i in dataset[k]['images']]
+ else:
+ identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'])
+ images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers)
+ segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers]
+ dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)}
+ return dataset
+
+
+if __name__ == '__main__':
+ print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart')))
diff --git a/nnUNet/readme.md b/nnUNet/readme.md
new file mode 100644
index 0000000..1b6ba5f
--- /dev/null
+++ b/nnUNet/readme.md
@@ -0,0 +1,133 @@
+# Welcome to the new nnU-Net!
+
+Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.
+
+Coming from V1? Check out the [TLDR Migration Guide](documentation/tldr_migration_guide_from_v1.md). Reading the rest of the documentation is still strongly recommended ;-)
+
+# What is nnU-Net?
+Image datasets are enormously diverse: image dimensionality (2D, 3D), modalities/input channels (RGB image, CT, MRI, microscopy, ...),
+image sizes, voxel sizes, class ratio, target structure properties and more change substantially between datasets.
+Traditionally, given a new problem, a tailored solution needs to be manually designed and optimized - a process that
+is prone to errors, not scalable and where success is overwhelmingly determined by the skill of the experimenter. Even
+for experts, this process is anything but simple: there are not only many design choices and data properties that need to
+be considered, but they are also tightly interconnected, rendering reliable manual pipeline optimization all but impossible!
+
+![nnU-Net overview](documentation/assets/nnU-Net_overview.png)
+
+**nnU-Net is a semantic segmentation method that automatically adapts to a given dataset. It will analyze the provided
+training cases and automatically configure a matching U-Net-based segmentation pipeline. No expertise required on your
+end! You can simply train the models and use them for your application**.
+
+Upon release, nnU-Net was evaluated on 23 datasets belonging to competitions from the biomedical domain. Despite competing
+with handcrafted solutions for each respective dataset, nnU-Net's fully automated pipeline scored several first places on
+open leaderboards! Since then nnU-Net has stood the test of time: it continues to be used as a baseline and method
+development framework ([9 out of 10 challenge winners at MICCAI 2020](https://arxiv.org/abs/2101.00232) and 5 out of 7
+in MICCAI 2021 built their methods on top of nnU-Net,
+ [we won AMOS2022 with nnU-Net](https://amos22.grand-challenge.org/final-ranking/))!
+
+Please cite the [following paper](https://www.google.com/url?q=https://www.nature.com/articles/s41592-020-01008-z&sa=D&source=docs&ust=1677235958581755&usg=AOvVaw3dWL0SrITLhCJUBiNIHCQO) when using nnU-Net:
+
+ Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring
+ method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211.
+
+
+## What can nnU-Net do for you?
+If you are a **domain scientist** (biologist, radiologist, ...) looking to analyze your own images, nnU-Net provides
+an out-of-the-box solution that is all but guaranteed to provide excellent results on your individual dataset. Simply
+convert your dataset into the nnU-Net format and enjoy the power of AI - no expertise required!
+
+If you are an **AI researcher** developing segmentation methods, nnU-Net:
+- offers a fantastic out-of-the-box applicable baseline algorithm to compete against
+- can act as a method development framework to test your contribution on a large number of datasets without having to
+tune individual pipelines (for example evaluating a new loss function)
+- provides a strong starting point for further dataset-specific optimizations. This is particularly used when competing
+in segmentation challenges
+- provides a new perspective on the design of segmentation methods: maybe you can find better connections between
+dataset properties and best-fitting segmentation pipelines?
+
+## What is the scope of nnU-Net?
+nnU-Net is built for semantic segmentation. It can handle 2D and 3D images with arbitrary
+input modalities/channels. It can understand voxel spacings, anisotropies and is robust even when classes are highly
+imbalanced.
+
+nnU-Net relies on supervised learning, which means that you need to provide training cases for your application. The number of
+required training cases varies heavily depending on the complexity of the segmentation problem. No
+one-fits-all number can be provided here! nnU-Net does not require more training cases than other solutions - maybe
+even fewer, thanks to our extensive use of data augmentation.
+
+nnU-Net expects to be able to process entire images at once during preprocessing and postprocessing, so it cannot
+handle enormous images. As a reference: we tested images from 40x40x40 pixels all the way up to 1500x1500x1500 in 3D
+and 40x40 up to ~30000x30000 in 2D! If your RAM allows it, larger is always possible.
+
+## How does nnU-Net work?
+Given a new dataset, nnU-Net will systematically analyze the provided training cases and create a 'dataset fingerprint'.
+nnU-Net then creates several U-Net configurations for each dataset:
+- `2d`: a 2D U-Net (for 2D and 3D datasets)
+- `3d_fullres`: a 3D U-Net that operates on a high image resolution (for 3D datasets only)
+- `3d_lowres` → `3d_cascade_fullres`: a 3D U-Net cascade where first a 3D U-Net operates on low resolution images and
+then a second high-resolution 3D U-Net refines the predictions of the former (for 3D datasets with large image sizes only)
+
+**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the
+U-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full
+resolution U-Net already covers a large part of the input images.**
+
+nnU-Net configures its segmentation pipelines based on a three-step recipe:
+- **Fixed parameters** are not adapted. During development of nnU-Net we identified a robust configuration (that is, certain architecture and training properties) that can
+simply be used all the time. This includes, for example, nnU-Net's loss function, (most of the) data augmentation strategy and learning rate.
+- **Rule-based parameters** use the dataset fingerprint to adapt certain segmentation pipeline properties by following
+hard-coded heuristic rules. For example, the network topology (pooling behavior and depth of the network architecture)
+are adapted to the patch size; the patch size, network topology and batch size are optimized jointly given some GPU
+memory constraint.
+- **Empirical parameters** are essentially trial-and-error. For example, the selection of the best U-Net configuration
+for the given dataset (2D, 3D full resolution, 3D low resolution, 3D cascade) and the optimization of the postprocessing strategy.
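+
+As a minimal sketch (assuming the dataset has already been planned and preprocessed; the dataset id 777 is only an
+example), the generated configurations can be inspected with the `PlansManager` helper from
+`nnunetv2.utilities.plans_handling.plans_handler`:
+
+```python
+from batchgenerators.utilities.file_and_folder_operations import join
+from nnunetv2.paths import nnUNet_preprocessed
+from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
+from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
+
+# load the plans file written by nnUNetv2_plan_and_preprocess (replace 777 with your dataset id)
+dataset_name = maybe_convert_to_dataset_name(777)
+plans_manager = PlansManager(join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'))
+
+print(plans_manager.available_configurations)       # e.g. ['2d', '3d_fullres', ...]
+cfg = plans_manager.get_configuration('3d_fullres')
+print(cfg.patch_size, cfg.spacing, cfg.batch_size)   # rule-based parameters chosen for this dataset
+```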
+
+## How to get started?
+Read these:
+- [Installation instructions](documentation/installation_instructions.md)
+- [Dataset conversion](documentation/dataset_format.md)
+- [Usage instructions](documentation/how_to_use_nnunet.md)
+
+Additional information:
+- [Region-based training](documentation/region_based_training.md)
+- [Manual data splits](documentation/manual_data_splits.md)
+- [Pretraining and finetuning](documentation/pretraining_and_finetuning.md)
+- [Intensity Normalization in nnU-Net](documentation/explanation_normalization.md)
+- [Manually editing nnU-Net configurations](documentation/explanation_plans_files.md)
+- [Extending nnU-Net](documentation/extending_nnunet.md)
+- [What is different in V2?](documentation/changelog.md)
+
+[//]: # (- [Ignore label](documentation/ignore_label.md))
+
+## Where does nnU-Net perform well and where does it not?
+nnU-Net excels in segmentation problems that need to be solved by training from scratch,
+for example: research applications that feature non-standard image modalities and input channels,
+challenge datasets from the biomedical domain, the majority of 3D segmentation problems, etc. We have yet to find a
+dataset for which nnU-Net's working principle fails!
+
+Note: On standard segmentation
+problems, such as 2D RGB images in ADE20k and Cityscapes, fine-tuning a foundation model (that was pretrained on a large corpus of
+similar images, e.g. Imagenet 22k, JFT-300M) will provide better performance than nnU-Net! That is simply because these
+models allow much better initialization. Foundation models are not supported by nnU-Net as
+they 1) are not useful for segmentation problems that deviate from the standard setting (see the above-mentioned
+datasets), 2) would typically only support 2D architectures and 3) conflict with our core design principle of carefully adapting
+the network topology for each dataset (if the topology is changed, one can no longer transfer pretrained weights!).
+
+## What happened to the old nnU-Net?
+The core of the old nnU-Net was hacked together in a short time period while participating in the Medical Segmentation
+Decathlon challenge in 2018. Consequently, code structure and quality were not the best. Many features
+were added later on and didn't quite fit into the nnU-Net design principles. Overall quite messy, really. And annoying to work with.
+
+nnU-Net V2 is a complete overhaul. The "delete everything and start again" kind. So everything is better
+(in the author's opinion haha). While the segmentation performance [remains the same](https://docs.google.com/spreadsheets/d/13gqjIKEMPFPyMMMwA1EML57IyoBjfC3-QCTn4zRN_Mg/edit?usp=sharing), a lot of cool stuff has been added.
+It is now also much easier to use it as a development framework and to manually fine-tune its configuration to new
+datasets. A big driver for the reimplementation was also the emergence of [Helmholtz Imaging](http://helmholtz-imaging.de),
+prompting us to extend nnU-Net to more image formats and domains. Take a look [here](documentation/changelog.md) for some highlights.
+
+# Acknowledgements
+
+
+
+
+nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of [Helmholtz Imaging](http://helmholtz-imaging.de)
+and the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) at the
+[German Cancer Research Center (DKFZ)](https://www.dkfz.de/en/index.html).
diff --git a/nnUNet/setup.cfg b/nnUNet/setup.cfg
new file mode 100644
index 0000000..002a15d
--- /dev/null
+++ b/nnUNet/setup.cfg
@@ -0,0 +1,2 @@
+[metadata]
+description-file = readme.md
\ No newline at end of file
diff --git a/nnUNet/setup.py b/nnUNet/setup.py
new file mode 100755
index 0000000..5b9271f
--- /dev/null
+++ b/nnUNet/setup.py
@@ -0,0 +1,62 @@
+from setuptools import setup, find_namespace_packages
+
+setup(name='nnunetv2',
+ packages=find_namespace_packages(include=["nnunetv2", "nnunetv2.*"]),
+ version='2.1.1',
+ description='nnU-Net. Framework for out-of-the box biomedical image segmentation.',
+ url='https://github.com/MIC-DKFZ/nnUNet',
+ author='Helmholtz Imaging Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center',
+ author_email='f.isensee@dkfz-heidelberg.de',
+ license='Apache License Version 2.0, January 2004',
+ python_requires=">=3.9",
+ install_requires=[
+ "torch>=2.0.0",
+ "acvl-utils>=0.2",
+ "dynamic-network-architectures>=0.2",
+ "tqdm",
+ "dicom2nifti",
+ "scipy",
+ "batchgenerators>=0.25",
+ "numpy",
+ "scikit-learn",
+ "scikit-image>=0.19.3",
+ "SimpleITK>=2.2.1",
+ "pandas",
+ "graphviz",
+ 'tifffile',
+ 'requests',
+ "nibabel",
+ "matplotlib",
+ "seaborn",
+ "imagecodecs",
+ "yacs"
+ ],
+ entry_points={
+ 'console_scripts': [
+ 'nnUNetv2_plan_and_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry', # api available
+ 'nnUNetv2_extract_fingerprint = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry', # api available
+ 'nnUNetv2_plan_experiment = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry', # api available
+ 'nnUNetv2_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry', # api available
+ 'nnUNetv2_train = nnunetv2.run.run_training:run_training_entry', # api available
+ 'nnUNetv2_predict_from_modelfolder = nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder', # api available
+ 'nnUNetv2_predict = nnunetv2.inference.predict_from_raw_data:predict_entry_point', # api available
+ 'nnUNetv2_convert_old_nnUNet_dataset = nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point', # api available
+ 'nnUNetv2_find_best_configuration = nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point', # api available
+ 'nnUNetv2_determine_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder', # api available
+ 'nnUNetv2_apply_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing', # api available
+ 'nnUNetv2_ensemble = nnunetv2.ensembling.ensemble:entry_point_ensemble_folders', # api available
+ 'nnUNetv2_accumulate_crossval_results = nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point', # api available
+ 'nnUNetv2_plot_overlay_pngs = nnunetv2.utilities.overlay_plots:entry_point_generate_overlay', # api available
+ 'nnUNetv2_download_pretrained_model_by_url = nnunetv2.model_sharing.entry_points:download_by_url', # api available
+ 'nnUNetv2_install_pretrained_model_from_zip = nnunetv2.model_sharing.entry_points:install_from_zip_entry_point', # api available
+ 'nnUNetv2_export_model_to_zip = nnunetv2.model_sharing.entry_points:export_pretrained_model_entry', # api available
+ 'nnUNetv2_move_plans_between_datasets = nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets', # api available
+ 'nnUNetv2_evaluate_folder = nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point', # api available
+ 'nnUNetv2_evaluate_simple = nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point', # api available
+ 'nnUNetv2_convert_MSD_dataset = nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point' # api available
+ ],
+ },
+ keywords=['deep learning', 'image segmentation', 'medical image analysis',
+ 'medical image segmentation', 'nnU-Net', 'nnunet']
+ )
diff --git a/process.py b/process.py
new file mode 100644
index 0000000..bbabfdd
--- /dev/null
+++ b/process.py
@@ -0,0 +1,158 @@
+import time
+import SimpleITK as sitk
+import numpy as np
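+# Monkey-patch: expose an integer type as numpy.lib.index_tricks.int. Presumably a compatibility workaround for
+# downstream code that still looks up an integer alias there (np.int was removed in NumPy 1.24).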
+np.lib.index_tricks.int = np.uint16
+import ants
+from os.path import join
+from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
+import json
+from custom_algorithm import Hanseg2023Algorithm
+
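+# Label encoding used by this algorithm for the HaN-Seg structures: background plus the 30 organs-at-risk.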
+LABEL_dict = {
+ "background": 0,
+ "A_Carotid_L": 1,
+ "A_Carotid_R": 2,
+ "Arytenoid": 3,
+ "Bone_Mandible": 4,
+ "Brainstem": 5,
+ "BuccalMucosa": 6,
+ "Cavity_Oral": 7,
+ "Cochlea_L": 8,
+ "Cochlea_R": 9,
+ "Cricopharyngeus": 10,
+ "Esophagus_S": 11,
+ "Eye_AL": 12,
+ "Eye_AR": 13,
+ "Eye_PL": 14,
+ "Eye_PR": 15,
+ "Glnd_Lacrimal_L": 16,
+ "Glnd_Lacrimal_R": 17,
+ "Glnd_Submand_L": 18,
+ "Glnd_Submand_R": 19,
+ "Glnd_Thyroid": 20,
+ "Glottis": 21,
+ "Larynx_SG": 22,
+ "Lips": 23,
+ "OpticChiasm": 24,
+ "OpticNrv_L": 25,
+ "OpticNrv_R": 26,
+ "Parotid_L": 27,
+ "Parotid_R": 28,
+ "Pituitary": 29,
+ "SpinalCord": 30,
+}
+
+
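+# Converters between ANTsPy and SimpleITK images. The voxel array is transposed because ANTsPy exposes arrays in
+# x,y,z index order while SimpleITK's GetImageFromArray/GetArrayFromImage use z,y,x; origin, spacing and direction
+# are copied across explicitly.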
+def ants_2_itk(image):
+ imageITK = sitk.GetImageFromArray(image.numpy().T)
+ imageITK.SetOrigin(image.origin)
+ imageITK.SetSpacing(image.spacing)
+ imageITK.SetDirection(image.direction.reshape(9))
+ return imageITK
+
+def itk_2_ants(image):
+ image_ants = ants.from_numpy(sitk.GetArrayFromImage(image).T,
+ origin=image.GetOrigin(),
+ spacing=image.GetSpacing(),
+ direction=np.array(image.GetDirection()).reshape(3, 3))
+ return image_ants
+
+
+class MyHanseg2023Algorithm(Hanseg2023Algorithm):
+ def __init__(self):
+ super().__init__()
+
+ def predict(self, *, image_ct: ants.ANTsImage, image_mrt1: ants.ANTsImage) -> sitk.Image:
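+        # Overall pipeline: 1) affinely register the T1 MR to the CT with ANTs, 2) stack CT and the warped MR as a
+        # two-channel input, 3) run the nnU-Net v2 predictor (folds/settings chosen by image size below) and
+        # 4) copy the CT geometry (origin, spacing, direction) onto the predicted segmentation.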
+ print("Computing registration", flush=True)
+        time0reg = time.time_ns()
+ mytx = ants.registration(fixed=image_ct, moving=image_mrt1, type_of_transform='Affine') #, aff_iterations=(150, 150, 150, 150))
+ print(f"Time reg: {(time.time_ns()-time0reg)/1000000000}")
+ warped_MR = ants.apply_transforms(fixed=image_ct, moving=image_mrt1,
+ transformlist=mytx['fwdtransforms'], defaultvalue=image_mrt1.min())
+ trained_model_path = join("/opt", "algorithm", "checkpoint", "nnUNet", "Dataset777_HaNSeg2023", "nnUNetTrainer__nnUNetPlans__3d_fullres")
+ # trained_model_path = join("/usr/DATA/backup_home_dir/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/src/HanSeg_2023/nnUNet/dataset/nnUNet_results",
+ # "Dataset777_HaNSeg2023", "nnUNetTrainer__nnUNetPlans__3d_fullres")
+
+ spacing = tuple(map(float,json.load(open(join(trained_model_path, "plans.json"), "r"))["configurations"]["3d_fullres"]["spacing"]))
+ ct_image = ants_2_itk(image_ct)
+ mr_image = ants_2_itk(warped_MR)
+ del image_mrt1
+ del warped_MR
+
+
+ properties = {
+ 'sitk_stuff':
+ {'spacing': ct_image.GetSpacing(),
+ 'origin': ct_image.GetOrigin(),
+ 'direction': ct_image.GetDirection()
+ },
+            # the spacing is inverted with [::-1] because SimpleITK returns image arrays in z,y,x order while
+            # GetSpacing() reports x,y,z, so the spacing has to be reversed to match the array axes
+ 'spacing': ct_image.GetSpacing()[::-1]
+ }
+ images = np.vstack([sitk.GetArrayFromImage(ct_image)[None], sitk.GetArrayFromImage(mr_image)[None]]).astype(np.float32)
+ fin_origin = ct_image.GetOrigin()
+ fin_spacing = ct_image.GetSpacing()
+ fin_direction = ct_image.GetDirection()
+ fin_size = ct_image.GetSize()
+ print(fin_spacing)
+ print(spacing)
+ print(fin_size)
+
+ old_shape = np.shape(sitk.GetArrayFromImage(ct_image))
+ del mr_image
+ del ct_image
+ # Shamelessly copied from nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py
+ new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(fin_spacing, spacing[::-1], fin_size)])
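+        # Heuristic: choose folds, tile step size and GPU usage from the estimated voxel count after resampling to
+        # the training spacing, presumably to stay within memory and runtime limits.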
+        if new_shape.prod() < 1e8:
+            print(f"Image is not too large ({new_shape.prod()}), using folds (0, 1, 2, 3) with mirroring")
+ predictor = nnUNetPredictor(tile_step_size=0.4, use_mirroring=True, perform_everything_on_gpu=True,
+ verbose=True, verbose_preprocessing=True,
+ allow_tqdm=True)
+ predictor.initialize_from_trained_model_folder(trained_model_path, use_folds=(0,1,2,3),
+ checkpoint_name="checkpoint_best.pth")
+ # predictor.allowed_mirroring_axes = (0, 2)
+        elif new_shape.prod() < 1.3e8:
+            print(f"Image is moderately large ({new_shape.prod()}), using folds (0, 1, 2, 3)")
+
+ predictor = nnUNetPredictor(tile_step_size=0.6, use_mirroring=True, perform_everything_on_gpu=False,
+ verbose=True, verbose_preprocessing=True,
+ allow_tqdm=True)
+ predictor.initialize_from_trained_model_folder(trained_model_path, use_folds=(0,1,2,3), #(0,1,2,3,4)
+ checkpoint_name="checkpoint_best.pth")
+        elif new_shape.prod() < 1.7e8:
+            print(f"Image is large ({new_shape.prod()}), using fold 0 with mirroring")
+
+ predictor = nnUNetPredictor(tile_step_size=0.4, use_mirroring=True, perform_everything_on_gpu=False,
+ verbose=True, verbose_preprocessing=True,
+ allow_tqdm=True)
+ predictor.initialize_from_trained_model_folder(trained_model_path, use_folds="0",
+ checkpoint_name="checkpoint_best.pth")
+ # predictor.allowed_mirroring_axes = (0, 2)
+
+ else:
+ predictor = nnUNetPredictor(tile_step_size=0.6, use_mirroring=True, perform_everything_on_gpu=False,
+ verbose=True, verbose_preprocessing=True,
+ allow_tqdm=True)
+            print(f"Image is too large ({new_shape.prod()}), using fold 0 only")
+ predictor.initialize_from_trained_model_folder(trained_model_path, use_folds="0",
+ checkpoint_name="checkpoint_best.pth")
+
+ img_temp = predictor.predict_single_npy_array(images, properties, None, None, False).astype(np.uint8)
+ del images
+ print("Prediction Done", flush=True)
+ output_seg = sitk.GetImageFromArray(img_temp)
+ print(f"Seg: {output_seg.GetSize()}, CT: {fin_size}")
+ # output_seg.CopyInformation(ct_image)
+ output_seg.SetOrigin(fin_origin)
+ output_seg.SetSpacing(fin_spacing)
+ output_seg.SetDirection(fin_direction)
+ print("Got Image", flush=True)
+ # save the simpleITK image
+ # sitk.WriteImage(output_seg, str("output_seg.seg.nrrd"), True)
+ return output_seg
+
+if __name__ == "__main__":
+ time0 = time.time_ns()
+ MyHanseg2023Algorithm().process()
+    print(f"Total processing time: {(time.time_ns() - time0) / 1e9:.1f} s")
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..1e174c7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,104 @@
+#
+# This file is autogenerated by pip-compile with Python 3.8
+# by the following command:
+#
+# pip-compile --resolver=backtracking
+#
+antspyx==0.4.2
+arrow==1.2.3
+ # via jinja2-time
+binaryornot==0.4.4
+ # via cookiecutter
+build==0.10.0
+ # via pip-tools
+certifi==2022.12.7
+ # via requests
+chardet==5.1.0
+ # via binaryornot
+charset-normalizer==3.1.0
+ # via requests
+click==8.1.3
+ # via
+ # cookiecutter
+ # evalutils
+ # pip-tools
+cookiecutter==2.1.1
+ # via evalutils
+evalutils==0.4.0
+ # via -r requirements.in
+idna==3.4
+ # via requests
+imageio[tifffile]==2.26.0
+ # via evalutils
+jinja2==3.1.2
+ # via
+ # cookiecutter
+ # jinja2-time
+jinja2-time==0.2.0
+ # via cookiecutter
+joblib==1.2.0
+ # via scikit-learn
+markupsafe==2.1.2
+ # via jinja2
+# nnunetv2==2.2
+multiprocess==0.70.14
+numba==0.58.1
+numpy==1.24.2
+ # via
+ # evalutils
+ # imageio
+ # pandas
+ # scikit-learn
+ # scipy
+ # tifffile
+packaging==23.0
+ # via build
+pandas==1.5.3
+ # via evalutils
+pillow==9.4.0
+ # via imageio
+pip-tools==6.12.3
+ # via evalutils
+pyproject-hooks==1.0.0
+ # via build
+python-dateutil==2.8.2
+ # via
+ # arrow
+ # pandas
+python-slugify==8.0.1
+ # via cookiecutter
+pytz==2022.7.1
+ # via pandas
+pyyaml==6.0
+ # via cookiecutter
+requests==2.28.2
+ # via cookiecutter
+scikit-learn==1.2.2
+ # via evalutils
+scipy==1.10.1
+ # via
+ # evalutils
+ # scikit-learn
+simpleitk==2.2.1
+ # via evalutils
+six==1.16.0
+ # via python-dateutil
+text-unidecode==1.3
+ # via python-slugify
+threadpoolctl==3.1.0
+ # via scikit-learn
+tifffile==2023.3.15
+ # via imageio
+tomli==2.0.1
+ # via
+ # build
+ # pyproject-hooks
+torchio==0.18.86
+urllib3==1.26.15
+ # via requests
+wheel==0.40.0
+ # via pip-tools
+
+# The following packages are considered to be unsafe in a requirements file:
+# pip
+# setuptools
diff --git a/test.sh b/test.sh
new file mode 100755
index 0000000..aa2ece8
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,37 @@
+#!/usr/bin/env bash
+
+SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
+echo $SCRIPTPATH
+OUTDIR=$SCRIPTPATH/output
+
+./build.sh
+
+# Maximum is currently 30g, configurable in your algorithm image settings on grand challenge
+MEM_LIMIT="12g"
+
+# create output dir if it does not exist
+if [ ! -d "$OUTDIR" ]; then
+ mkdir $OUTDIR;
+fi
+echo "starting docker"
+# Do not change any of the parameters to docker run, these are fixed
+docker run --rm \
+ --memory="${MEM_LIMIT}" \
+ --memory-swap="${MEM_LIMIT}" \
+ --network="none" \
+ --cap-drop="ALL" \
+ --security-opt="no-new-privileges" \
+ --shm-size="128m" \
+ --pids-limit="256" \
+ --gpus="0" \
+        -v "$SCRIPTPATH/test/":/input/ \
+        -v "$SCRIPTPATH/output/":/output \
+ hanseg2023algorithm_dmx
+echo "docker done"
+
+echo
+echo
+echo "Compare files in $OUTDIR with the expected results to see if test is successful"
+docker run --rm \
+        -v "$OUTDIR":/output/ \
+ python:3.8-slim ls -al /output/images/head_neck_oar