diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..1068f81
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+test/
+.git/
+*.tar.gz
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..d1bf892
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,42 @@
+FROM pytorch/pytorch
+
+RUN groupadd -r user && useradd -m --no-log-init -r -g user user
+
+RUN mkdir -p /opt/app /input /output \
+ && chown user:user /opt/app /input /output
+
+USER user
+WORKDIR /opt/app
+
+ENV PATH="/home/user/.local/bin:${PATH}"
+
+RUN python -m pip install --user -U pip && python -m pip install --user pip-tools && python -m pip install --upgrade pip
+COPY --chown=user:user nnUNet/ /opt/app/nnUNet/
+RUN python -m pip install -e nnUNet
+#RUN python -m pip uninstall -y scipy
+#RUN python -m pip install --user --upgrade scipy
+
+COPY --chown=user:user requirements.txt /opt/app/
+RUN python -m pip install --user -r requirements.txt
+
+
+# This is the checkpoint file; adjust the filename below to your needs
+COPY --chown=user:user nnUNetTrainer__nnUNetPlans__3d_fullres.zip /opt/algorithm/checkpoint/nnUNet/
+RUN python -c "import zipfile; import os; zipfile.ZipFile('/opt/algorithm/checkpoint/nnUNet/nnUNetTrainer__nnUNetPlans__3d_fullres.zip').extractall('/opt/algorithm/checkpoint/nnUNet/')"
+
+COPY --chown=user:user custom_algorithm.py /opt/app/
+COPY --chown=user:user process.py /opt/app/
+
+# COPY --chown=user:user weights /opt/algorithm/checkpoint
+ENV nnUNet_results="/opt/algorithm/checkpoint/"
+ENV nnUNet_raw="/opt/algorithm/nnUNet_raw_data_base"
+ENV nnUNet_preprocessed="/opt/algorithm/preproc"
+# ENV ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS=64(nope!)
+# ENV nnUNet_def_n_proc=1
+
+#ENTRYPOINT [ "python3", "-m", "process" ]
+
+ENV MKL_SERVICE_FORCE_INTEL=1
+
+# Launches the script
+ENTRYPOINT python -m process $0 $@
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e68b7fb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+# DMX Solution to HaNSeg Challenge
+
+The Head and Neck organ-at-risk CT & MR segmentation challenge. 
Contribution to the Grand Challenge (MICCAI 2023) + +Challenge URL: **[HaN-Seg 2023 challenge](https://han-seg2023.grand-challenge.org/)** + +This solution is based on: + + - [ANTsPY](https://antspy.readthedocs.io/en/latest/) + - [nnUNetv2](https://github.com/MIC-DKFZ/nnUNet/) + - [Zhack47](https://github.com/Zhack47/HaNSeg-QuantIF) + + diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..b9c52ff --- /dev/null +++ b/build.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +# docker build --no-cache -t hanseg2023algorithm "$SCRIPTPATH" +docker build -t hanseg2023algorithm_dmx "$SCRIPTPATH" diff --git a/custom_algorithm.py b/custom_algorithm.py new file mode 100644 index 0000000..fb688f4 --- /dev/null +++ b/custom_algorithm.py @@ -0,0 +1,195 @@ +import ants +import SimpleITK as sitk + +from evalutils import SegmentationAlgorithm + +import logging +from pathlib import Path +from typing import ( + Optional, + Pattern, + Tuple, +) + +from pandas import DataFrame +from evalutils.exceptions import FileLoaderError, ValidationError +from evalutils.validators import DataFrameValidator +from evalutils.io import ( + ImageLoader, +) + +logger = logging.getLogger(__name__) + + +class HanSegUniquePathIndicesValidator(DataFrameValidator): + """ + Validates that the indicies from the filenames are unique + """ + + def validate(self, *, df: DataFrame): + try: + paths_ct = df["path_ct"].tolist() + paths_mrt1 = df["path_mrt1"].tolist() + paths = paths_ct + paths_mrt1 + except KeyError: + raise ValidationError( + "Column `path_ct` or `path_mrt1` not found in DataFrame." + ) + + assert len(paths_ct) == len( + paths_mrt1 + ), "The number of CT and MR images is not equal." + + +class HanSegUniqueImagesValidator(DataFrameValidator): + """ + Validates that each image in the set is unique + """ + + def validate(self, *, df: DataFrame): + try: + hashes_ct = df["hash_ct"].tolist() + hashes_mrt1 = df["hash_mrt1"].tolist() + hashes = hashes_ct + hashes_mrt1 + except KeyError: + raise ValidationError( + "Column `hash_ct` or `hash_mrt1` not found in DataFrame." + ) + + if len(set(hashes)) != len(hashes): + raise ValidationError( + "The images are not unique, please submit a unique image for " + "each case." + ) + + +class Hanseg2023Algorithm(SegmentationAlgorithm): + def __init__( + self, + input_path=Path("/input/images/"), + output_path=Path("/output/images/head_neck_oar/"), + **kwargs, + ): + super().__init__( + validators=dict( + input_image=( + HanSegUniqueImagesValidator(), + HanSegUniquePathIndicesValidator(), + ) + ), + input_path=input_path, + output_path=output_path, + **kwargs, + ) + + def _load_input_image(self, *, case) -> Tuple[sitk.Image, Path]: + input_image_file_path_ct = case["path_ct"] + input_image_file_path_mrt1 = case["path_mrt1"] + + input_image_file_loader = self._file_loaders["input_image"] + if not isinstance(input_image_file_loader, ImageLoader): + raise RuntimeError("The used FileLoader was not of subclass ImageLoader") + + # Load the image for this case + + #input_image_ct = input_image_file_loader.load_image(input_image_file_path_ct) + #input_image_mrt1 = input_image_file_loader.load_image( + # input_image_file_path_mrt1 + #) + # Ok so hear me out... 
+ # Instead of loading the nrrd files as SimpleITK images, shifting to ants Image, + # doing the registration, then back to SITK Image, we load them directly with ants + # I did this back when time limit was 5 min, to win seconds + input_image_ct = ants.image_read(input_image_file_path_ct.__str__()) + input_image_mrt1 = ants.image_read(input_image_file_path_mrt1.__str__()) + + # Check that it is the expected image + #if input_image_file_loader.hash_image(input_image_ct) != case["hash_ct"]: + # raise RuntimeError("CT image hashes do not match") + #if input_image_file_loader.hash_image(input_image_mrt1) != case["hash_mrt1"]: + # raise RuntimeError("MR image hashes do not match") + + return ( + input_image_ct, + input_image_file_path_ct, + input_image_mrt1, + input_image_file_path_mrt1, + ) + + def process_case(self, *, idx, case): + # Load and test the image for this case + ( + input_image_ct, + input_image_file_path_ct, + input_image_mrt1, + input_image_file_path_mrt1, + ) = self._load_input_image(case=case) + + # Segment nodule candidates + segmented_nodules = self.predict( + image_ct=input_image_ct, image_mrt1=input_image_mrt1 + ) + + # Write resulting segmentation to output location + segmentation_path = self._output_path / input_image_file_path_ct.name.replace( + "_CT", "_seg" + ) + self._output_path.mkdir(parents=True, exist_ok=True) + sitk.WriteImage(segmented_nodules, str(segmentation_path), True) + + # Write segmentation file path to result.json for this case + return { + "outputs": [dict(type="metaio_image", filename=segmentation_path.name)], + "inputs": [ + dict(type="metaio_ct_image", filename=input_image_file_path_ct.name), + dict( + type="metaio_mrt1_image", filename=input_image_file_path_mrt1.name + ), + ], + "error_messages": [], + } + + def _load_cases( + self, + *, + folder: Path, + file_loader: ImageLoader, + file_filter: Optional[Pattern[str]] = None, + ) -> DataFrame: + cases = [] + + paths_ct = sorted(folder.glob("ct/*"), key=self._file_sorter_key) + paths_mrt1 = sorted(folder.glob("t1-mri/*"), key=self._file_sorter_key) + + for pth_ct, pth_mr in zip(paths_ct, paths_mrt1): + if file_filter is None or ( + file_filter.match(str(pth_ct)) and file_filter.match(str(pth_mr)) + ): + try: + case_ct = file_loader.load(fname=pth_ct)[0] + case_mrt1 = file_loader.load(fname=pth_mr)[0] + new_cases = [ + { + "hash_ct": case_ct["hash"], + "path_ct": case_ct["path"], + "hash_mrt1": case_mrt1["hash"], + "path_mrt1": case_mrt1["path"], + } + ] + except FileLoaderError: + logger.warning( + f"Could not load {pth_ct.name} or {pth_mr.name} using {file_loader}." + ) + else: + cases += new_cases + else: + logger.info( + f"Skip loading {pth_ct.name} and {pth_mr.name} because it doesn't match {file_filter}." + ) + + if len(cases) == 0: + raise FileLoaderError( + f"Could not load any files in {folder} with " f"{file_loader}." + ) + + return DataFrame(cases) diff --git a/export.sh b/export.sh new file mode 100755 index 0000000..54abcec --- /dev/null +++ b/export.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +./build.sh + +docker save hanseg2023algorithm_dmx | gzip -c > HanSeg2023AlgorithmDMX.tar.gz diff --git a/nnUNet/LICENSE b/nnUNet/LICENSE new file mode 100644 index 0000000..8bbe09c --- /dev/null +++ b/nnUNet/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2019] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
\ No newline at end of file diff --git a/nnUNet/documentation/__init__.py b/nnUNet/documentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/documentation/benchmarking.md b/nnUNet/documentation/benchmarking.md new file mode 100644 index 0000000..108a7f2 --- /dev/null +++ b/nnUNet/documentation/benchmarking.md @@ -0,0 +1,115 @@ +# nnU-Netv2 benchmarks + +Does your system run like it should? Is your epoch time longer than expected? What epoch times should you expect? + +Look no further for we have the solution here! + +## What does the nnU-netv2 benchmark do? + +nnU-Net's benchmark trains models for 5 epochs. At the end, the fastest epoch will +be noted down, along with the GPU name, torch version and cudnn version. You can find the benchmark output in the +corresponding nnUNet_results subfolder (see example below). Don't worry, we also provide scripts to collect your +results. Or you just start a benchmark and look at the console output. Everything is possible. Nothing is forbidden. + +The benchmark implementation revolves around two trainers: +- `nnUNetTrainerBenchmark_5epochs` runs a regular training for 5 epochs. When completed, writes a .json file with the fastest +epoch time as well as the GPU used and the torch and cudnn versions. Useful for speed testing the entire pipeline +(data loading, augmentation, GPU training) +- `nnUNetTrainerBenchmark_5epochs_noDataLoading` is the same, but it doesn't do any data loading or augmentation. It +just presents dummy arrays to the GPU. Useful for checking pure GPU speed. + +## How to run the nnU-Netv2 benchmark? +It's quite simple, actually. It looks just like a regular nnU-Net training. + +We provide reference numbers for some of the Medical Segmentation Decathlon datasets because they are easily +accessible: [download here](https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2). If it needs to be +quick and dirty, focus on Tasks 2 and 4. Download and extract the data and convert them to the nnU-Net format with +`nnUNetv2_convert_MSD_dataset`. +Run `nnUNetv2_plan_and_preprocess` for them. + +Then, for each dataset, run the following commands (only one per GPU! Or one after the other): + +```bash +nnUNetv2_train DATSET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs +nnUNetv2_train DATSET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs +nnUNetv2_train DATSET_ID 2d 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading +nnUNetv2_train DATSET_ID 3d_fullres 0 -tr nnUNetTrainerBenchmark_5epochs_noDataLoading +``` + +If you want to inspect the outcome manually, check (for example!) your +`nnUNet_results/DATASET_NAME/nnUNetTrainerBenchmark_5epochs__nnUNetPlans__3d_fullres/fold_0/` folder for the `benchmark_result.json` file. + +Note that there can be multiple entries in this file if the benchmark was run on different GPU types, torch versions or cudnn versions! + +If you want to summarize your results like we did in our [results](#results), check the +[summary script](../nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py). Here you need to change the +torch version, cudnn version and dataset you want to summarize, then execute the script. You can find the exact +values you need to put there in one of your `benchmark_result.json` files. + +## Results +We have tested a variety of GPUs and summarized the results in a +[spreadsheet](https://docs.google.com/spreadsheets/d/12Cvt_gr8XU2qWaE0XJk5jJlxMEESPxyqW0CWbQhTNNY/edit?usp=sharing). 
+Note that you can select the torch and cudnn versions at the bottom! There may be comments in this spreadsheet. Read them! + +## Result interpretation + +Results are shown as epoch time in seconds. Lower is better (duh). Epoch times can fluctuate between runs, so as +long as you are within like 5-10% of the numbers we report, everything should be dandy. + +If not, here is how you can try to find the culprit! + +The first thing to do is to compare the performance between the `nnUNetTrainerBenchmark_5epochs_noDataLoading` and +`nnUNetTrainerBenchmark_5epochs` trainers. If the difference is about the same as we report in our spreadsheet, but +both your numbers are worse, the problem is with your GPU: + +- Are you certain you compare the correct GPU? (duh) +- If yes, then you might want to install PyTorch in a different way. Never `pip install torch`! Go to the +[PyTorch installation](https://pytorch.org/get-started/locally/) page, select the most recent cuda version your +system supports and only then copy and execute the correct command! Either pip or conda should work +- If the problem is still not fixed, we recommend you try +[compiling pytorch from source](https://github.com/pytorch/pytorch#from-source). It's more difficult but that's +how we roll here at the DKFZ (at least the cool kids here). +- Another thing to consider is to try exactly the same torch + cudnn version as we did in our spreadsheet. +Sometimes newer versions can actually degrade performance and there might be bugs from time to time. Older versions +are also often a lot slower! +- Finally, some very basic things that could impact your GPU performance: + - Is the GPU cooled adequately? Check the temperature with `nvidia-smi`. Hot GPUs throttle performance in order to not self-destruct + - Is your OS using the GPU for displaying your desktop at the same time? If so then you can expect a performance + penalty (I dunno like 10% !?). That's expected and OK. + - Are other users using the GPU as well? + + +If you see a large performance difference between `nnUNetTrainerBenchmark_5epochs_noDataLoading` (fast) and +`nnUNetTrainerBenchmark_5epochs` (slow) then the problem might be related to data loading and augmentation. As a +reminder, nnU-net does not use pre-augmented images (offline augmentation) but instead generates augmented training +samples on the fly during training (no, you cannot switch it to offline). This requires that your system can do partial +reads of the image files fast enough (SSD storage required!) and that your CPU is powerful enough to run the augmentations. + +Check the following: + +- [CPU bottleneck] How many CPU threads are running during the training? nnU-Net uses 12 processes for data augmentation by default. +If you see those 12 running constantly during training, consider increasing the number of processes used for data +augmentation (provided there is headroom on your CPU!). Increase the number until you see less active workers than +you configured (or just set the number to 32 and forget about it). You can do so by setting the `nnUNet_n_proc_DA` +environment variable (Linux: `export nnUNet_n_proc_DA=24`). Read [here](set_environment_variables.md) on how to do this. +If your CPU does not support more processes (setting more processes than your CPU has threads makes +no sense!) you are out of luck and in desperate need of a system upgrade! 
+- [I/O bottleneck] If you don't see 12 (or nnUNet_n_proc_DA if you set it) processes running but your training times +are still slow then open up `top` (sorry, Windows users. I don't know how to do this on Windows) and look at the value +left of 'wa' in the row that begins +with '%Cpu (s)'. If this is >1.0 (arbitrarily set threshold here, essentially look for unusually high 'wa'. In a +healthy training 'wa' will be almost 0) then your storage cannot keep up with data loading. Make sure to set +nnUNet_preprocessed to a folder that is located on an SSD. nvme is preferred over SATA. PCIe3 is enough. 3000MB/s +sequential read recommended. +- [funky stuff] Sometimes there is funky stuff going on, especially when batch sizes are large, files are small and +patch sizes are small as well. As part of the data loading process, nnU-Net needs to open and close a file for each +training sample. Now imagine a dataset like Dataset004_Hippocampus where for the 2d config we have a batch size of +366 and we run 250 iterations in <10s on an A100. That's a lotta files per second (366 * 250 / 10 = 9150 files per second). +Oof. If the files are on some network drive (even if it's nvme) then (probably) good night. The good news: nnU-Net +has got you covered: add `export nnUNet_keep_files_open=True` to your .bashrc and the problem goes away. The neat +part: it causes new problems if you are not allowed to have enough open files. You may have to increase the number +of allowed open files. `ulimit -n` gives your current limit (Linux only). It should not be something like 1024. +Increasing that to 65535 works well for me. See here for how to change these limits: +[Link](https://kupczynski.info/posts/ubuntu-18-10-ulimits/) +(works for Ubuntu 18, google for your OS!). + diff --git a/nnUNet/documentation/changelog.md b/nnUNet/documentation/changelog.md new file mode 100644 index 0000000..0b56e44 --- /dev/null +++ b/nnUNet/documentation/changelog.md @@ -0,0 +1,51 @@ +# What is different in v2? + +- We now support **hierarchical labels** (named regions in nnU-Net). For example, instead of training BraTS with the +'edema', 'necrosis' and 'enhancing tumor' labels you can directly train it on the target areas 'whole tumor', +'tumor core' and 'enhancing tumor'. See [here](region_based_training.md) for a detailed description + also have a look at the +[BraTS 2021 conversion script](../nnunetv2/dataset_conversion/Dataset137_BraTS21.py). +- Cross-platform support. Cuda, mps (Apple M1/M2) and of course CPU support! Simply select the device with +`-device` in `nnUNetv2_train` and `nnUNetv2_predict`. +- Unified trainer class: nnUNetTrainer. No messing around with cascaded trainer, DDP trainer, region-based trainer, +ignore trainer etc. All default functionality is in there! +- Supports more input/output data formats through ImageIO classes. +- I/O formats can be extended by implementing new Adapters based on `BaseReaderWriter`. +- The nnUNet_raw_cropped folder no longer exists -> saves disk space at no performance penalty. magic! (no jk the +saving of cropped npz files was really slow, so it's actually faster to crop on the fly). +- Preprocessed data and segmentation are stored in different files when unpacked. Seg is stored as int8 and thus +takes 1/4 of the disk space per pixel (and I/O throughput) as in v1. +- Native support for multi-GPU (DDP) TRAINING. +Multi-GPU INFERENCE should still be run with `CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] -num_parts Y -part_id X`. 
+There is no cross-GPU communication in inference, so it doesn't make sense to add additional complexity with DDP. +- All nnU-Net functionality is now also accessible via API. Check the corresponding entry point in `setup.py` to see +what functions you need to call. +- Dataset fingerprint is now explicitly created and saved in a json file (see nnUNet_preprocessed). + +- Complete overhaul of plans files (read also [this](explanation_plans_files.md): + - Plans are now .json and can be opened and read more easily + - Configurations are explicitly named ("3d_fullres" , ...) + - Configurations can inherit from each other to make manual experimentation easier + - A ton of additional functionality is now included in and can be changed through the plans, for example normalization strategy, resampling etc. + - Stages of the cascade are now explicitly listed in the plans. 3d_lowres has 'next_stage' (which can also be a + list of configurations!). 3d_cascade_fullres has a 'previous_stage' entry. By manually editing plans files you can + now connect anything you want, for example 2d with 3d_fullres or whatever. Be wild! (But don't create cycles!) + - Multiple configurations can point to the same preprocessed data folder to save disk space. Careful! Only + configurations that use the same spacing, resampling, normalization etc. should share a data source! By default, + 3d_fullres and 3d_cascade_fullres share the same data + - Any number of configurations can be added to the plans (remember to give them a unique "data_identifier"!) + +Folder structures are different and more user-friendly: +- nnUNet_preprocessed + - By default, preprocessed data is now saved as: `nnUNet_preprocessed/DATASET_NAME/PLANS_IDENTIFIER_CONFIGURATION` to clearly link them to their corresponding plans and configuration + - Name of the folder containing the preprocessed images can be adapted with the `data_identifier` key. +- nnUNet_results + - Results are now sorted as follows: DATASET_NAME/TRAINERCLASS__PLANSIDENTIFIER__CONFIGURATION/FOLD + +## What other changes are planned and not yet implemented? +- Integration into MONAI (together with our friends at Nvidia) +- New pretrained weights for a large number of datasets (coming very soon)) + + +[//]: # (- nnU-Net now also natively supports an **ignore label**. Pixels with this label will not contribute to the loss. ) + +[//]: # (Use this to learn from sparsely annotated data, or excluding irrelevant areas from training. Read more [here](ignore_label.md).) \ No newline at end of file diff --git a/nnUNet/documentation/convert_msd_dataset.md b/nnUNet/documentation/convert_msd_dataset.md new file mode 100644 index 0000000..4c4ee48 --- /dev/null +++ b/nnUNet/documentation/convert_msd_dataset.md @@ -0,0 +1,3 @@ +Use `nnUNetv2_convert_MSD_dataset`. + +Read `nnUNetv2_convert_MSD_dataset -h` for usage instructions. \ No newline at end of file diff --git a/nnUNet/documentation/dataset_format.md b/nnUNet/documentation/dataset_format.md new file mode 100644 index 0000000..e11d8b2 --- /dev/null +++ b/nnUNet/documentation/dataset_format.md @@ -0,0 +1,232 @@ +# nnU-Net dataset format +The only way to bring your data into nnU-Net is by storing it in a specific format. Due to nnU-Net's roots in the +[Medical Segmentation Decathlon](http://medicaldecathlon.com/) (MSD), its dataset is heavily inspired but has since +diverged (see also [here](#how-to-use-decathlon-datasets)) from the format used in the MSD. 
+ +Datasets consist of three components: raw images, corresponding segmentation maps and a dataset.json file specifying +some metadata. + +If you are migrating from nnU-Net v1, read [this](#how-to-use-nnu-net-v1-tasks) to convert your existing Tasks. + + +## What do training cases look like? +Each training case is associated with an identifier = a unique name for that case. This identifier is used by nnU-Net to +connect images with the correct segmentation. + +A training case consists of images and their corresponding segmentation. + +**Images** is plural because nnU-Net supports arbitrarily many input channels. In order to be as flexible as possible, +nnU-net requires each input channel to be stored in a separate image (with the sole exception being RGB natural +images). So these images could for example be a T1 and a T2 MRI (or whatever else you want). The different input +channels MUST have the same geometry (same shape, spacing (if applicable) etc.) and +must be co-registered (if applicable). Input channels are identified by nnU-Net by their FILE_ENDING: a four-digit integer at the end +of the filename. Image files must therefore follow the following naming convention: {CASE_IDENTIFIER}_{XXXX}.{FILE_ENDING}. +Hereby, XXXX is the 4-digit modality/channel identifier (should be unique for each modality/chanel, e.g., “0000” for T1, “0001” for +T2 MRI, …) and FILE_ENDING is the file extension used by your image format (.png, .nii.gz, ...). See below for concrete examples. +The dataset.json file connects channel names with the channel identifiers in the 'channel_names' key (see below for details). + +Side note: Typically, each channel/modality needs to be stored in a separate file and is accessed with the XXXX channel identifier. +Exception are natural images (RGB; .png) where the three color channels can all be stored in one file (see the [road segmentation](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py) dataset as an example). + +**Segmentations** must share the same geometry with their corresponding images (same shape etc.). Segmentations are +integer maps with each value representing a semantic class. The background must be 0. If there is no background, then +do not use the label 0 for something else! Integer values of your semantic classes must be consecutive (0, 1, 2, 3, +...). Of course, not all labels have to be present in each training case. Segmentations are saved as {CASE_IDENTIFER}.{FILE_ENDING} . + +Within a training case, all image geometries (input channels, corresponding segmentation) must match. Between training +cases, they can of course differ. nnU-Net takes care of that. + +Important: The input channels must be consistent! Concretely, **all images need the same input channels in the same +order and all input channels have to be present every time**. This is also true for inference! + + +## Supported file formats +nnU-Net expects the same file format for images and segmentations! These will also be used for inference. For now, it +is thus not possible to train .png and then run inference on .jpg. + +One big change in nnU-Net V2 is the support of multiple input file types. Gone are the days of converting everything to .nii.gz! +This is implemented by abstracting the input and output of images + segmentations through `BaseReaderWriter`. nnU-Net +comes with a broad collection of Readers+Writers and you can even add your own to support your data format! +See [here](../nnunetv2/imageio/readme.md). 
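+
+To make the naming convention and file endings described above concrete, here is a small self-contained sketch (not
+part of nnU-Net; the helper name and the final check are just illustrative assumptions) that groups the files of an
+imagesTr folder by case identifier and channel identifier:
+
+```python
+from collections import defaultdict
+from pathlib import Path
+
+
+def group_by_case(images_dir: str, file_ending: str = ".nii.gz"):
+    """Group files named {CASE_IDENTIFIER}_{XXXX}{FILE_ENDING} by case identifier."""
+    cases = defaultdict(dict)
+    for f in sorted(Path(images_dir).glob(f"*{file_ending}")):
+        stem = f.name[: -len(file_ending)]       # strip the file ending
+        case_id, channel = stem[:-5], stem[-4:]  # trailing _XXXX is the channel identifier
+        cases[case_id][channel] = f
+    return cases
+
+
+# every case must provide the same set of channel identifiers (e.g. 0000-0003 for BraTS)
+cases = group_by_case("nnUNet_raw/Dataset001_BrainTumour/imagesTr")
+assert len({frozenset(c.keys()) for c in cases.values()}) <= 1, "inconsistent channels across cases"
+```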
+ +As a nice bonus, nnU-Net now also natively supports 2D input images and you no longer have to mess around with +conversions to pseudo 3D niftis. Yuck. That was disgusting. + +Note that internally (for storing and accessing preprocessed images) nnU-Net will use its own file format, irrespective +of what the raw data was provided in! This is for performance reasons. + + +By default, the following file formats are supported: +- NaturalImage2DIO: .png, .bmp, .tif +- NibabelIO: .nii.gz, .nrrd, .mha +- NibabelIOWithReorient: .nii.gz, .nrrd, .mha. This reader will reorient images to RAS! +- SimpleITKIO: .nii.gz, .nrrd, .mha +- Tiff3DIO: .tif, .tiff. 3D tif images! Since TIF does not have a standardized way of storing spacing information, +nnU-Net expects each TIF file to be accompanied by an identically named .json file that contains three numbers +(no units, no comma. Just separated by whitespace), one for each dimension. + + +The file extension lists are not exhaustive and depend on what the backend supports. For example, nibabel and SimpleITK +support more than the three given here. The file endings given here are just the ones we tested! + +IMPORTANT: nnU-Net can only be used with file formats that use lossless (or no) compression! Because the file +format is defined for an entire dataset (and not separately for images and segmentations, this could be a todo for +the future), we must ensure that there are no compression artifacts that destroy the segmentation maps. So no .jpg and +the likes! + +## Dataset folder structure +Datasets must be located in the `nnUNet_raw` folder (which you either define when installing nnU-Net or export/set every +time you intend to run nnU-Net commands!). +Each segmentation dataset is stored as a separate 'Dataset'. Datasets are associated with a dataset ID, a three digit +integer, and a dataset name (which you can freely choose): For example, Dataset005_Prostate has 'Prostate' as dataset name and +the dataset id is 5. Datasets are stored in the `nnUNet_raw` folder like this: + + nnUNet_raw/ + ├── Dataset001_BrainTumour + ├── Dataset002_Heart + ├── Dataset003_Liver + ├── Dataset004_Hippocampus + ├── Dataset005_Prostate + ├── ... + +Within each dataset folder, the following structure is expected: + + Dataset001_BrainTumour/ + ├── dataset.json + ├── imagesTr + ├── imagesTs # optional + └── labelsTr + + +When adding your custom dataset, take a look at the [dataset_conversion](../nnunetv2/dataset_conversion) folder and +pick an id that is not already taken. IDs 001-010 are for the Medical Segmentation Decathlon. + +- **imagesTr** contains the images belonging to the training cases. nnU-Net will perform pipeline configuration, training with +cross-validation, as well as finding postprocessing and the best ensemble using this data. +- **imagesTs** (optional) contains the images that belong to the test cases. nnU-Net does not use them! This could just +be a convenient location for you to store these images. Remnant of the Medical Segmentation Decathlon folder structure. +- **labelsTr** contains the images with the ground truth segmentation maps for the training cases. +- **dataset.json** contains metadata of the dataset. + +The scheme introduced [above](#what-do-training-cases-look-like) results in the following folder structure. Given +is an example for the first Dataset of the MSD: BrainTumour. This dataset hat four input channels: FLAIR (0000), +T1w (0001), T1gd (0002) and T2w (0003). Note that the imagesTs folder is optional and does not have to be present. 
+ + nnUNet_raw/Dataset001_BrainTumour/ + ├── dataset.json + ├── imagesTr + │   ├── BRATS_001_0000.nii.gz + │   ├── BRATS_001_0001.nii.gz + │   ├── BRATS_001_0002.nii.gz + │   ├── BRATS_001_0003.nii.gz + │   ├── BRATS_002_0000.nii.gz + │   ├── BRATS_002_0001.nii.gz + │   ├── BRATS_002_0002.nii.gz + │   ├── BRATS_002_0003.nii.gz + │   ├── ... + ├── imagesTs + │   ├── BRATS_485_0000.nii.gz + │   ├── BRATS_485_0001.nii.gz + │   ├── BRATS_485_0002.nii.gz + │   ├── BRATS_485_0003.nii.gz + │   ├── BRATS_486_0000.nii.gz + │   ├── BRATS_486_0001.nii.gz + │   ├── BRATS_486_0002.nii.gz + │   ├── BRATS_486_0003.nii.gz + │   ├── ... + └── labelsTr + ├── BRATS_001.nii.gz + ├── BRATS_002.nii.gz + ├── ... + +Here is another example of the second dataset of the MSD, which has only one input channel: + + nnUNet_raw/Dataset002_Heart/ + ├── dataset.json + ├── imagesTr + │   ├── la_003_0000.nii.gz + │   ├── la_004_0000.nii.gz + │   ├── ... + ├── imagesTs + │   ├── la_001_0000.nii.gz + │   ├── la_002_0000.nii.gz + │   ├── ... + └── labelsTr + ├── la_003.nii.gz + ├── la_004.nii.gz + ├── ... + +Remember: For each training case, all images must have the same geometry to ensure that their pixel arrays are aligned. Also +make sure that all your data is co-registered! + +See also [dataset format inference](dataset_format_inference.md)!! + +## dataset.json +The dataset.json contains metadata that nnU-Net needs for training. We have greatly reduced the number of required +fields since version 1! + +Here is what the dataset.json should look like at the example of the Dataset005_Prostate from the MSD: + + { + "channel_names": { # formerly modalities + "0": "T2", + "1": "ADC" + }, + "labels": { # THIS IS DIFFERENT NOW! + "background": 0, + "PZ": 1, + "TZ": 2 + }, + "numTraining": 32, + "file_ending": ".nii.gz" + "overwrite_image_reader_writer": "SimpleITKIO" # optional! If not provided nnU-Net will automatically determine the ReaderWriter + } + +The channel_names determine the normalization used by nnU-Net. If a channel is marked as 'CT', then a global +normalization based on the intensities in the foreground pixels will be used. If it is something else, per-channel +z-scoring will be used. Refer to the methods section in [our paper](https://www.nature.com/articles/s41592-020-01008-z) +for more details. nnU-Net v2 introduces a few more normalization schemes to +choose from and allows you to define your own, see [here](explanation_normalization.md) for more information. + +Important changes relative to nnU-Net v1: +- "modality" is now called "channel_names" to remove strong bias to medical images +- labels are structured differently (name -> int instead of int -> name). This was needed to support [region-based training](region_based_training.md) +- "file_ending" is added to support different input file types +- "overwrite_image_reader_writer" optional! Can be used to specify a certain (custom) ReaderWriter class that should +be used with this dataset. If not provided, nnU-Net will automatically determine the ReaderWriter +- "regions_class_order" only used in [region-based training](region_based_training.md) + +There is a utility with which you can generate the dataset.json automatically. You can find it +[here](../nnunetv2/dataset_conversion/generate_dataset_json.py). +See our examples in [dataset_conversion](../nnunetv2/dataset_conversion) for how to use it. And read its documentation! 
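+
+As a minimal sketch (values taken from the Prostate example above, paths are placeholders), such a dataset.json can
+also simply be written with the standard library; the generate_dataset_json utility linked above is the more
+convenient route:
+
+```python
+import json
+
+dataset_json = {
+    "channel_names": {"0": "T2", "1": "ADC"},       # determines the normalization per channel
+    "labels": {"background": 0, "PZ": 1, "TZ": 2},  # name -> int (new in v2)
+    "numTraining": 32,
+    "file_ending": ".nii.gz",
+    # "overwrite_image_reader_writer": "SimpleITKIO",  # optional
+}
+
+with open("nnUNet_raw/Dataset005_Prostate/dataset.json", "w") as f:
+    json.dump(dataset_json, f, indent=4)
+```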
+ +## How to use nnU-Net v1 Tasks +If you are migrating from the old nnU-Net, convert your existing datasets with `nnUNetv2_convert_old_nnUNet_dataset`! + +Example for migrating a nnU-Net v1 Task: +```bash +nnUNetv2_convert_old_nnUNet_dataset /media/isensee/raw_data/nnUNet_raw_data_base/nnUNet_raw_data/Task027_ACDC Dataset027_ACDC +``` +Use `nnUNetv2_convert_old_nnUNet_dataset -h` for detailed usage instructions. + + +## How to use decathlon datasets +See [convert_msd_dataset.md](convert_msd_dataset.md) + +## How to use 2D data with nnU-Net +2D is now natively supported (yay!). See [here](#supported-file-formats) as well as the example dataset in this +[script](../nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py). + + +## How to update an existing dataset +When updating a dataset it is best practice to remove the preprocessed data in `nnUNet_preprocessed/DatasetXXX_NAME` +to ensure a fresh start. Then replace the data in `nnUNet_raw` and rerun `nnUNetv2_plan_and_preprocess`. Optionally, +also remove the results from old trainings. + +# Example dataset conversion scripts +In the `dataset_conversion` folder (see [here](../nnunetv2/dataset_conversion)) are multiple example scripts for +converting datasets into nnU-Net format. These scripts cannot be run as they are (you need to open them and change +some paths) but they are excellent examples for you to learn how to convert your own datasets into nnU-Net format. +Just pick the dataset that is closest to yours as a starting point. +The list of dataset conversion scripts is continually updated. If you find that some publicly available dataset is +missing, feel free to open a PR to add it! diff --git a/nnUNet/documentation/dataset_format_inference.md b/nnUNet/documentation/dataset_format_inference.md new file mode 100644 index 0000000..503d431 --- /dev/null +++ b/nnUNet/documentation/dataset_format_inference.md @@ -0,0 +1,39 @@ +# Data format for Inference +Read the documentation on the overall [data format](dataset_format.md) first! + +The data format for inference must match the one used for the raw data (**specifically, the images must be in exactly +the same format as in the imagesTr folder**). As before, the filenames must start with a +unique identifier, followed by a 4-digit modality identifier. Here is an example for two different datasets: + +1) Task005_Prostate: + + This task has 2 modalities, so the files in the input folder must look like this: + + input_folder + ├── prostate_03_0000.nii.gz + ├── prostate_03_0001.nii.gz + ├── prostate_05_0000.nii.gz + ├── prostate_05_0001.nii.gz + ├── prostate_08_0000.nii.gz + ├── prostate_08_0001.nii.gz + ├── ... + + _0000 has to be the T2 image and _0001 has to be the ADC image (as specified by 'channel_names' in the +dataset.json), exactly the same as was used for training. + +2) Task002_Heart: + + imagesTs + ├── la_001_0000.nii.gz + ├── la_002_0000.nii.gz + ├── la_006_0000.nii.gz + ├── ... + + Task002 only has one modality, so each case only has one _0000.nii.gz file. + + +The segmentations in the output folder will be named {CASE_IDENTIFIER}.nii.gz (omitting the modality identifier). + +Remember that the file format used for inference (.nii.gz in this example) must be the same as was used for training +(and as was specified in 'file_ending' in the dataset.json)! 
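+
+Once such a folder is prepared, a typical inference call looks roughly like this (the dataset id, configuration and
+folds below are placeholders; adjust them to the model you trained and see `nnUNetv2_predict -h` for all options):
+
+```bash
+nnUNetv2_predict -i /path/to/input_folder -o /path/to/output_folder -d 5 -c 3d_fullres -f 0 1 2 3 4
+```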
+ \ No newline at end of file diff --git a/nnUNet/documentation/explanation_normalization.md b/nnUNet/documentation/explanation_normalization.md new file mode 100644 index 0000000..ed2b897 --- /dev/null +++ b/nnUNet/documentation/explanation_normalization.md @@ -0,0 +1,45 @@ +# Intensity normalization in nnU-Net + +The type of intensity normalization applied in nnU-Net can be controlled via the `channel_names` (former `modalities`) +entry in the dataset.json. Just like the old nnU-Net, per-channel z-scoring as well as dataset-wide z-scoring based on +foreground intensities are supported. However, there have been a few additions as well. + +Reminder: The `channel_names` entry typically looks like this: + + "channel_names": { + "0": "T2", + "1": "ADC" + }, + +It has as many entries as there are input channels for the given dataset. + +To tell you a secret, nnU-Net does not really care what your channels are called. We just use this to determine what normalization +scheme will be used for the given dataset. nnU-Net requires you to specify a normalization strategy for each of your input channels! +If you enter a channel name that is not in the following list, the default (`zscore`) will be used. + +Here is a list of currently available normalization schemes: + +- `CT`: Perform CT normalization. Specifically, collect intensity values from the foreground classes (all but the +background and ignore) from all training cases, compute the mean, standard deviation as well as the 0.5 and +99.5 percentile of the values. Then clip to the percentiles, followed by subtraction of the mean and division with the +standard deviation. The normalization that is applied is the same for each training case (for this input channel). +The values used by nnU-Net for normalization are stored in the `foreground_intensity_properties_per_channel` entry in the +corresponding plans file. This normalization is suitable for modalities presenting physical quantities such as CT +images and ADC maps. +- `noNorm` : do not perform any normalization at all +- `rescale_to_0_1`: rescale the intensities to [0, 1] +- `rgb_to_0_1`: assumes uint8 inputs. Divides by 255 to rescale uint8 to [0, 1] +- `zscore`/anything else: perform z-scoring (subtract mean and standard deviation) separately for each train case + +**Important:** The nnU-Net default is to perform 'CT' normalization for CT images and 'zscore' for everything else! If +you deviate from that path, make sure to benchmark whether that actually improves results! + +# How to implement custom normalization strategies? +- Head over to nnunetv2/preprocessing/normalization +- implement a new image normalization class by deriving from ImageNormalization +- register it in nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py:channel_name_to_normalization_mapping. +This is where you specify a channel name that should be associated with it +- use it by specifying the correct channel_name + +Normalization can only be applied to one channel at a time. There is currently no way of implementing a normalization scheme +that gets multiple channels as input to be used jointly! 
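+
+As an illustration of the steps above, a custom scheme could look roughly like the following. This is only a sketch:
+it assumes that ImageNormalization subclasses implement a `run(image, seg)` method like the schemes shipped with
+nnU-Net; check the existing classes in nnunetv2/preprocessing/normalization before copying this, as import paths and
+signatures may differ between versions.
+
+```python
+import numpy as np
+
+from nnunetv2.preprocessing.normalization.default_normalization_schemes import ImageNormalization
+
+
+class ClipTo0_500Normalization(ImageNormalization):
+    """Hypothetical example: clip intensities to [0, 500], then rescale to [0, 1]."""
+
+    def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray:
+        image = np.clip(image.astype(np.float32), 0, 500)
+        return image / 500.0
+
+
+# then map a channel name to it in
+# nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py (add an entry to
+# channel_name_to_normalization_mapping) and use that channel name in your dataset.json
+```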
\ No newline at end of file diff --git a/nnUNet/documentation/explanation_plans_files.md b/nnUNet/documentation/explanation_plans_files.md new file mode 100644 index 0000000..00f1216 --- /dev/null +++ b/nnUNet/documentation/explanation_plans_files.md @@ -0,0 +1,185 @@ +# Modifying the nnU-Net Configurations + +nnU-Net provides unprecedented out-of-the-box segmentation performance for essentially any dataset we have evaluated +it on. That said, there is always room for improvements. A fool-proof strategy for squeezing out the last bit of +performance is to start with the default nnU-Net, and then further tune it manually to a concrete dataset at hand. +**This guide is about changes to the nnU-Net configuration you can make via the plans files. It does not cover code +extensions of nnU-Net. For that, take a look [here](extending_nnunet.md)** + +In nnU-Net V2, plans files are SO MUCH MORE powerful than they were in v1. There are a lot more knobs that you can +turn without resorting to hacky solutions or even having to touch the nnU-Net code at all! And as an added bonus: +plans files are now also .json files and no longer require users to fiddle with pickle. Just open them in your text +editor of choice! + +If overwhelmed, look at our [Examples](#examples)! + +# plans.json structure + +Plans have global and local settings. Global settings are applied to all configurations in that plans file while +local settings are attached to a specific configuration. + +## Global settings + +- `foreground_intensity_properties_by_modality`: Intensity statistics of the foreground regions (all labels except +background and ignore label), computed over all training cases. Used by [CT normalization scheme](explanation_normalization.md). +- `image_reader_writer`: Name of the image reader/writer class that should be used with this dataset. You might want +to change this if, for example, you would like to run inference with files that have a different file format. The +class that is named here must be located in nnunetv2.imageio! +- `label_manager`: The name of the class that does label handling. Take a look at +nnunetv2.utilities.label_handling.LabelManager to see what it does. If you decide to change it, place your version +in nnunetv2.utilities.label_handling! +- `transpose_forward`: nnU-Net transposes the input data so that the axes with the highest resolution (lowest spacing) +come last. This is because the 2D U-Net operates on the trailing dimensions (more efficient slicing due to internal +memory layout of arrays). Future work might move this setting to affect only individual configurations. +- transpose_backward is what numpy.transpose gets as new axis ordering. +- `transpose_backward`: the axis ordering that inverts "transpose_forward" +- \[`original_median_shape_after_transp`\]: just here for your information +- \[`original_median_spacing_after_transp`\]: just here for your information +- \[`plans_name`\]: do not change. Used internally +- \[`experiment_planner_used`\]: just here as metadata so that we know what planner originally generated this file +- \[`dataset_name`\]: do not change. This is the dataset these plans are intended for + +## Local settings +Plans also have a `configurations` key in which the actual configurations are stored. `configurations` are again a +dictionary, where the keys are the configuration names and the values are the local settings for each configuration. 
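+
+Schematically (keys abbreviated, values illustrative, not a complete file), a plans.json therefore has this shape:
+
+    {
+        "dataset_name": "Dataset005_Prostate",
+        "plans_name": "nnUNetPlans",
+        "image_reader_writer": "SimpleITKIO",
+        "transpose_forward": [0, 1, 2],
+        "transpose_backward": [0, 1, 2],
+        ...                                          # remaining global settings
+        "configurations": {
+            "2d": {...},
+            "3d_fullres": {
+                "data_identifier": "nnUNetPlans_3d_fullres",
+                "spacing": [...],
+                "patch_size": [...],
+                ...                                  # remaining local settings
+            },
+            "3d_lowres": {..., "next_stage": "3d_cascade_fullres"},
+            "3d_cascade_fullres": {..., "previous_stage": "3d_lowres"}
+        }
+    }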
+ +To better understand the components describing the network topology in our plans files, please read section 6.2 +in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) +(page 13) of our paper! + +Local settings: +- `spacing`: the target spacing used in this configuration +- `patch_size`: the patch size used for training this configuration +- `data_identifier`: the preprocessed data for this configuration will be saved in + nnUNet_preprocessed/DATASET_NAME/_data_identifier_. If you add a new configuration, remember to set a unique + data_identifier in order to not create conflicts with other configurations (unless you plan to reuse the data from + another configuration, for example as is done in the cascade) +- `batch_size`: batch size used for training +- `batch_dice`: whether to use batch dice (pretend all samples in the batch are one image, compute dice loss over that) +or not (each sample in the batch is a separate image, compute dice loss for each sample and average over samples) +- `preprocessor_name`: Name of the preprocessor class used for running preprocessing. Class must be located in +nnunetv2.preprocessing.preprocessors +- `use_mask_for_norm`: whether to use the nonzero mask for normalization or not (relevant for BraTS and the like, +probably False for all other datasets). Interacts with ImageNormalization class +- `normalization_schemes`: mapping of channel identifier to ImageNormalization class name. ImageNormalization +classes must be located in nnunetv2.preprocessing.normalization. Also see [here](explanation_normalization.md) +- `resampling_fn_data`: name of resampling function to be used for resizing image data. resampling function must be +callable(data, current_spacing, new_spacing, **kwargs). It must be located in nnunetv2.preprocessing.resampling +- `resampling_fn_data_kwargs`: kwargs for resampling_fn_data +- `resampling_fn_probabilities`: name of resampling function to be used for resizing predicted class probabilities/logits. +resampling function must be `callable(data: Union[np.ndarray, torch.Tensor], current_spacing, new_spacing, **kwargs)`. It must be located in +nnunetv2.preprocessing.resampling +- `resampling_fn_probabilities_kwargs`: kwargs for resampling_fn_probabilities +- `resampling_fn_seg`: name of resampling function to be used for resizing segmentation maps (integer: 0, 1, 2, 3, etc). +resampling function must be callable(data, current_spacing, new_spacing, **kwargs). It must be located in +nnunetv2.preprocessing.resampling +- `resampling_fn_seg_kwargs`: kwargs for resampling_fn_seg +- `UNet_class_name`: UNet class name, can be used to integrate custom dynamic architectures +- `UNet_base_num_features`: The number of starting features for the UNet architecture. Default is 32. Default: Features +are doubled with each downsampling +- `unet_max_num_features`: Maximum number of features (default: capped at 320 for 3D and 512 for 2d). The purpose is to +prevent parameters from exploding too much. +- `conv_kernel_sizes`: the convolutional kernel sizes used by nnU-Net in each stage of the encoder. The decoder + mirrors the encoder and is therefore not explicitly listed here! The list is as long as `n_conv_per_stage_encoder` has + entries +- `n_conv_per_stage_encoder`: number of convolutions used per stage (=at a feature map resolution in the encoder) in the encoder. + Default is 2. 
The list has as many entries as the encoder has stages +- `n_conv_per_stage_decoder`: number of convolutions used per stage in the decoder. Also see `n_conv_per_stage_encoder` +- `num_pool_per_axis`: number of times each of the spatial axes is pooled in the network. Needed to know how to pad + image sizes during inference (num_pool = 5 means input must be divisible by 2**5=32) +- `pool_op_kernel_sizes`: the pooling kernel sizes (and at the same time strides) for each stage of the encoder +- \[`median_image_size_in_voxels`\]: the median size of the images of the training set at the current target spacing. +Do not modify this as this is not used. It is just here for your information. + +Special local settings: +- `inherits_from`: configurations can inherit from each other. This makes it easy to add new configurations that only +differ in a few local settings from another. If using this, remember to set a new `data_identifier` (if needed)! +- `previous_stage`: if this configuration is part of a cascade, we need to know what the previous stage (for example +the low resolution configuration) was. This needs to be specified here. +- `next_stage`: if this configuration is part of a cascade, we need to know what possible subsequent stages are! This +is because we need to export predictions in the correct spacing when running the validation. `next_stage` can either +be a string or a list of strings + +# Examples + +## Increasing the batch size for large datasets +If your dataset is large the training can benefit from larger batch_sizes. To do this, simply create a new +configuration in the `configurations` dict + + "configurations": { + "3d_fullres_bs40": { + "inherits_from": "3d_fullres", + "batch_size": 40 + } + } + +No need to change the data_identifier. `3d_fullres_bs40` will just use the preprocessed data from `3d_fullres`. +No need to rerun `nnUNetv2_preprocess` because we can use already existing data (if available) from `3d_fullres`. + +## Using custom preprocessors +If you would like to use a different preprocessor class then this can be specified as follows: + + "configurations": { + "3d_fullres_my_preprocesor": { + "inherits_from": "3d_fullres", + "preprocessor_name": MY_PREPROCESSOR, + "data_identifier": "3d_fullres_my_preprocesor" + } + } + +You need to run preprocessing for this new configuration: +`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_preprocesor` because it changes the preprocessing. Remember to +set a unique `data_identifier` whenever you make modifications to the preprocessed data! + +## Change target spacing + + "configurations": { + "3d_fullres_my_spacing": { + "inherits_from": "3d_fullres", + "spacing": [X, Y, Z], + "data_identifier": "3d_fullres_my_spacing" + } + } + +You need to run preprocessing for this new configuration: +`nnUNetv2_preprocess -d DATASET_ID -c 3d_fullres_my_spacing` because it changes the preprocessing. Remember to +set a unique `data_identifier` whenever you make modifications to the preprocessed data! + +## Adding a cascade to a dataset where it does not exist +Hippocampus is small. It doesn't have a cascade. It also doesn't really make sense to add a cascade here but hey for +the sake of demonstration we can do that. 
+We change the following things here: + +- `spacing`: The lowres stage should operate at a lower resolution +- we modify the `median_image_size_in_voxels` entry as a guide for what original image sizes we deal with +- we set some patch size that is inspired by `median_image_size_in_voxels` +- we need to remember that the patch size must be divisible by 2**num_pool in each axis! +- network parameters such as kernel sizes, pooling operations are changed accordingly +- we need to specify the name of the next stage +- we need to add the highres stage + +This is how this would look like (comparisons with 3d_fullres given as reference): + + "configurations": { + "3d_lowres": { + "inherits_from": "3d_fullres", + "data_identifier": "3d_lowres" + "spacing": [2.0, 2.0, 2.0], # from [1.0, 1.0, 1.0] in 3d_fullres + "median_image_size_in_voxels": [18, 25, 18], # from [36, 50, 35] + "patch_size": [20, 28, 20], # from [40, 56, 40] + "n_conv_per_stage_encoder": [2, 2, 2], # one less entry than 3d_fullres ([2, 2, 2, 2]) + "n_conv_per_stage_decoder": [2, 2], # one less entry than 3d_fullres + "num_pool_per_axis": [2, 2, 2], # one less pooling than 3d_fullres in each dimension (3d_fullres: [3, 3, 3]) + "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], # one less [2, 2, 2] + "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], # one less [3, 3, 3] + "next_stage": "3d_cascade_fullres" # name of the next stage in the cascade + }, + "3d_cascade_fullres": { # does not need a data_identifier because we can use the data of 3d_fullres + "inherits_from": "3d_fullres", + "previous_stage": "3d_lowres" # name of the previous stage + } + } + +To better understand the components describing the network topology in our plans files, please read section 6.2 +in the [supplementary information](https://static-content.springer.com/esm/art%3A10.1038%2Fs41592-020-01008-z/MediaObjects/41592_2020_1008_MOESM1_ESM.pdf) +(page 13) of our paper! \ No newline at end of file diff --git a/nnUNet/documentation/extending_nnunet.md b/nnUNet/documentation/extending_nnunet.md new file mode 100644 index 0000000..2924f4f --- /dev/null +++ b/nnUNet/documentation/extending_nnunet.md @@ -0,0 +1,37 @@ +# Extending nnU-Net +We hope that the new structure of nnU-Net v2 makes it much more intuitive on how to modify it! We cannot give an +extensive tutorial on how each and every bit of it can be modified. It is better for you to search for the position +in the repository where the thing you intend to change is implemented and start working your way through the code from +there. Setting breakpoints and debugging into nnU-Net really helps in understanding it and thus will help you make the +necessary modifications! + +Here are some things you might want to read before you start: +- Editing nnU-Net configurations through plans files is really powerful now and allows you to change a lot of things regarding +preprocessing, resampling, network topology etc. Read [this](explanation_plans_files.md)! +- [Image normalization](explanation_normalization.md) and [i/o formats](dataset_format.md#supported-file-formats) are easy to extend! 
+- Manual data splits can be defined as described [here](manual_data_splits.md)
+- You can chain arbitrary configurations together into cascades, see [this again](explanation_plans_files.md)
+- Read about our support for [region-based training](region_based_training.md)
+- If you intend to modify the training procedure (loss, sampling, data augmentation, lr scheduler, etc) then you need
+to implement your own trainer class. Best practice is to create a class that inherits from nnUNetTrainer and
+implements the necessary changes. Head over to our [trainer classes folder](../nnunetv2/training/nnUNetTrainer) for
+inspiration! There will be similar trainers for what you intend to change and you can take them as a guide. nnUNetTrainers
+are structured similarly to PyTorch Lightning trainers; this should also make things easier!
+- Integrating new network architectures can be done in two ways:
+  - Quick and dirty: implement a new nnUNetTrainer class and overwrite its `build_network_architecture` function.
+  Make sure your architecture is compatible with deep supervision (if not, use `nnUNetTrainerNoDeepSupervision`
+  as basis!) and that it can handle the patch sizes that are thrown at it! Your architecture should NOT apply any
+  nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that!
+  - The 'proper' (but difficult) way: Build a dynamically configurable architecture such as the `PlainConvUNet` class
+  used by default. It needs to have some sort of GPU memory estimation method that can be used to evaluate whether
+  certain patch sizes and topologies fit into a specified GPU memory target. Build a new `ExperimentPlanner` that can configure your new
+  class and communicate with its memory budget estimation. Run `nnUNetv2_plan_and_preprocess` while specifying your
+  custom `ExperimentPlanner` and a custom `plans_name`. Implement a nnUNetTrainer that can use the plans generated by
+  your `ExperimentPlanner` to instantiate the network architecture. Specify your plans and trainer when running `nnUNetv2_train`.
+  It always pays off to first read and understand the corresponding nnU-Net code and use it as a template for your implementation!
+- Remember that multi-GPU training, region-based training, ignore label and cascaded training are now simply integrated
+into one unified nnUNetTrainer class. No separate classes needed (remember this when implementing your own trainer
+classes and ensure support for all of these features! Or raise `NotImplementedError`)
+
+[//]: # (- Read about our support for [ignore label](ignore_label.md) and [region-based training](region_based_training.md))
diff --git a/nnUNet/documentation/how_to_use_nnunet.md b/nnUNet/documentation/how_to_use_nnunet.md
new file mode 100644
index 0000000..d962768
--- /dev/null
+++ b/nnUNet/documentation/how_to_use_nnunet.md
@@ -0,0 +1,297 @@
+## How to run nnU-Net on a new dataset
+Given some dataset, nnU-Net fully automatically configures an entire segmentation pipeline that matches its properties.
+nnU-Net covers the entire pipeline, from preprocessing to model configuration, model training, postprocessing
+all the way to ensembling. After running nnU-Net, the trained model(s) can be applied to the test cases for inference.
+
+### Dataset Format
+nnU-Net expects datasets in a structured format. This format is inspired by the data structure of
+the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). Please read
+[this](dataset_format.md) for information on how to set up datasets to be compatible with nnU-Net.
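As a rough orientation only (the linked dataset format documentation is the authoritative reference), a minimal raw dataset could look like this, with the trailing 4-digit index of each training image encoding the input channel/modality:

```
nnUNet_raw/Dataset001_Example/
├── dataset.json
├── imagesTr
│   ├── case_0001_0000.nii.gz
│   └── case_0002_0000.nii.gz
└── labelsTr
    ├── case_0001.nii.gz
    └── case_0002.nii.gz
```

The accompanying dataset.json would then be something along these lines (names and label values are made up for illustration):

```json
{
    "channel_names": {"0": "CT"},
    "labels": {"background": 0, "organ": 1},
    "numTraining": 2,
    "file_ending": ".nii.gz"
}
```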
+ +**Since version 2 we support multiple image file formats (.nii.gz, .png, .tif, ...)! Read the dataset_format +documentation to learn more!** + +**Datasets from nnU-Net v1 can be converted to V2 by running `nnUNetv2_convert_old_nnUNet_dataset INPUT_FOLDER +OUTPUT_DATASET_NAME`.** Remember that v2 calls datasets DatasetXXX_Name (not Task) where XXX is a 3-digit number. +Please provide the **path** to the old task, not just the Task name. nnU-Net V2 doesn't know where v1 tasks were! + +### Experiment planning and preprocessing +Given a new dataset, nnU-Net will extract a dataset fingerprint (a set of dataset-specific properties such as +image sizes, voxel spacings, intensity information etc). This information is used to design three U-Net configurations. +Each of these pipelines operates on its own preprocessed version of the dataset. + +The easiest way to run fingerprint extraction, experiment planning and preprocessing is to use: + +```bash +nnUNetv2_plan_and_preprocess -d DATASET_ID --verify_dataset_integrity +``` + +Where `DATASET_ID` is the dataset id (duh). We recommend `--verify_dataset_integrity` whenever it's the first time +you run this command. This will check for some of the most common error sources! + +You can also process several datasets at once by giving `-d 1 2 3 [...]`. If you already know what U-Net configuration +you need you can also specify that with `-c 3d_fullres` (make sure to adapt -np in this case!). For more information +about all the options available to you please run `nnUNetv2_plan_and_preprocess -h`. + +nnUNetv2_plan_and_preprocess will create a new subfolder in your nnUNet_preprocessed folder named after the dataset. +Once the command is completed there will be a dataset_fingerprint.json file as well as a nnUNetPlans.json file for you to look at +(in case you are interested!). There will also be subfolders containing the preprocessed data for your UNet configurations. + +[Optional] +If you prefer to keep things separate, you can also use `nnUNetv2_extract_fingerprint`, `nnUNetv2_plan_experiment` +and `nnUNetv2_preprocess` (in that order). + +### Model training +#### Overview +You pick which configurations (2d, 3d_fullres, 3d_lowres, 3d_cascade_fullres) should be trained! If you have no idea +what performs best on your data, just run all of them and let nnU-Net identify the best one. It's up to you! + +nnU-Net trains all configurations in a 5-fold cross-validation over the training cases. This is 1) needed so that +nnU-Net can estimate the performance of each configuration and tell you which one should be used for your +segmentation problem and 2) a natural way of obtaining a good model ensemble (average the output of these 5 models +for prediction) to boost performance. + +You can influence the splits nnU-Net uses for 5-fold cross-validation (see [here](manual_data_splits.md)). If you +prefer to train a single model on all training cases, this is also possible (see below). + +**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the U-Net +cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full resolution U-Net +already covers a large part of the input images.** + +Training models is done with the `nnUNetv2_train` command. 
The general structure of the command is:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD [additional options, see -h]
+```
+
+UNET_CONFIGURATION is a string that identifies the requested U-Net configuration (defaults: 2d, 3d_fullres, 3d_lowres,
+3d_cascade_fullres). DATASET_NAME_OR_ID specifies what dataset should be trained on and FOLD specifies which fold of
+the 5-fold-cross-validation is trained.
+
+nnU-Net stores a checkpoint every 50 epochs. If you need to continue a previous training, just add a `--c` to the
+training command.
+
+IMPORTANT: If you plan to use `nnUNetv2_find_best_configuration` (see below) add the `--npz` flag. This makes
+nnU-Net save the softmax outputs during the final validation. They are needed for that. Exported softmax
+predictions are very large and therefore can take up a lot of disk space, which is why this is not enabled by default.
+If you ran initially without the `--npz` flag but now require the softmax predictions, simply rerun the validation with:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID UNET_CONFIGURATION FOLD --val --npz
+```
+
+You can specify the device nnU-Net should use by using `-device DEVICE`. DEVICE can only be cpu, cuda or mps. If
+you have multiple GPUs, please select the gpu id using `CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...]` (requires device to be cuda).
+
+See `nnUNetv2_train -h` for additional options.
+
+### 2D U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 2d FOLD [--npz]
+```
+
+### 3D full resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_fullres FOLD [--npz]
+```
+
+### 3D U-Net cascade
+#### 3D low resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_lowres FOLD [--npz]
+```
+
+#### 3D full resolution U-Net
+For FOLD in [0, 1, 2, 3, 4], run:
+```bash
+nnUNetv2_train DATASET_NAME_OR_ID 3d_cascade_fullres FOLD [--npz]
+```
+**Note that the 3D full resolution U-Net of the cascade requires the five folds of the low resolution U-Net to be
+completed!**
+
+The trained models will be written to the nnUNet_results folder. Each training obtains an automatically generated
+output folder name:
+
+nnUNet_results/DatasetXXX_MYNAME/TRAINER_CLASS_NAME__PLANS_NAME__CONFIGURATION/FOLD
+
+For Dataset002_Heart (from the MSD), for example, it looks like this:
+
+    nnUNet_results/
+    ├── Dataset002_Heart
+    │── nnUNetTrainer__nnUNetPlans__2d
+    │    ├── fold_0
+    │    ├── fold_1
+    │    ├── fold_2
+    │    ├── fold_3
+    │    ├── fold_4
+    │    ├── dataset.json
+    │    ├── dataset_fingerprint.json
+    │    └── plans.json
+    └── nnUNetTrainer__nnUNetPlans__3d_fullres
+         ├── fold_0
+         ├── fold_1
+         ├── fold_2
+         ├── fold_3
+         ├── fold_4
+         ├── dataset.json
+         ├── dataset_fingerprint.json
+         └── plans.json
+
+Note that 3d_lowres and 3d_cascade_fullres do not exist here because this dataset did not trigger the cascade. In each
+model training output folder (each of the fold_x folders), the following files will be created:
+- debug.json: Contains a summary of blueprint and inferred parameters used for training this model as well as a
+bunch of additional stuff. Not easy to read, but very useful for debugging ;-)
+- checkpoint_best.pth: checkpoint files of the best model identified during training. Not used right now unless you
+explicitly tell nnU-Net to use it.
+- checkpoint_final.pth: checkpoint file of the final model (after training has ended). This is what is used for both
+validation and inference.
+- network_architecture.pdf (only if hiddenlayer is installed!): a PDF document with a figure of the network architecture in it.
+- progress.png: Shows losses, pseudo dice, learning rate and epoch times over the course of the training. At the top is
+a plot of the training (blue) and validation (red) loss during training. Also shows an approximation of
+ the dice (green) as well as a moving average of it (dotted green line). This approximation is the average Dice score
+ of the foreground classes. **It needs to be taken with a big (!)
+ grain of salt** because it is computed on randomly drawn patches from the validation
+ data at the end of each epoch, and the aggregation of TP, FP and FN for the Dice computation treats the patches as if
+ they all originate from the same volume ('global Dice'; we do not compute a Dice for each validation case and then
+ average over all cases but pretend that there is only one validation case from which we sample patches). The reason for
+ this is that the 'global Dice' is easy to compute during training and is still quite useful to evaluate whether a model
+ is training at all or not. A proper validation takes way too long to be done each epoch. It is run at the end of the training.
+- validation_raw: in this folder are the predicted validation cases after the training has finished. The summary.json file in here
+ contains the validation metrics (a mean over all cases is provided at the start of the file). If `--npz` was set then
+the compressed softmax outputs (saved as .npz files) are in here as well.
+
+During training it is often useful to watch the progress. We therefore recommend that you have a look at the generated
+progress.png when running the first training. It will be updated after each epoch.
+
+Training times largely depend on the GPU. The smallest GPU we recommend for training is the Nvidia RTX 2080ti. With
+that all network trainings take less than 2 days. Refer to our [benchmarks](benchmarking.md) to see if your system is
+performing as expected.
+
+### Using multiple GPUs for training
+
+If multiple GPUs are at your disposal, the best way of using them is to train multiple nnU-Net trainings at once, one
+on each GPU. This is because data parallelism never scales perfectly linearly, especially not with small networks such
+as the ones used by nnU-Net.
+
+Example:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] & # train on GPU 0
+CUDA_VISIBLE_DEVICES=1 nnUNetv2_train DATASET_NAME_OR_ID 2d 1 [--npz] & # train on GPU 1
+CUDA_VISIBLE_DEVICES=2 nnUNetv2_train DATASET_NAME_OR_ID 2d 2 [--npz] & # train on GPU 2
+CUDA_VISIBLE_DEVICES=3 nnUNetv2_train DATASET_NAME_OR_ID 2d 3 [--npz] & # train on GPU 3
+CUDA_VISIBLE_DEVICES=4 nnUNetv2_train DATASET_NAME_OR_ID 2d 4 [--npz] & # train on GPU 4
+...
+wait
+```
+
+**Important: The first time a training is run nnU-Net will extract the preprocessed data into uncompressed numpy
+arrays for speed reasons! This operation must be completed before starting more than one training of the same
+configuration! Wait with starting subsequent folds until the first training is using the GPU! Depending on the
+dataset size and your system this should only take a couple of minutes at most.**
+
+If you insist on running DDP multi-GPU training, we got you covered:
+
+`nnUNetv2_train DATASET_NAME_OR_ID 2d 0 [--npz] -num_gpus X`
+
+Again, note that this will be slower than running separate training on separate GPUs.
DDP only makes sense if you have +manually interfered with the nnU-Net configuration and are training larger models with larger patch and/or batch sizes! + +Important when using `-num_gpus`: +1) If you train using, say, 2 GPUs but have more GPUs in the system you need to specify which GPUs should be used via +CUDA_VISIBLE_DEVICES=0,1 (or whatever your ids are). +2) You cannot specify more GPUs than you have samples in your minibatches. If the batch size is 2, 2 GPUs is the maximum! +3) Make sure your batch size is divisible by the numbers of GPUs you use or you will not make good use of your hardware. + +In contrast to the old nnU-Net, DDP is now completely hassle free. Enjoy! + +### Automatically determine the best configuration +Once the desired configurations were trained (full cross-validation) you can tell nnU-Net to automatically identify +the best combination for you: + +```commandline +nnUNetv2_find_best_configuration DATASET_NAME_OR_ID -c CONFIGURATIONS +``` + +`CONFIGURATIONS` hereby is the list of configurations you would like to explore. Per default, ensembling is enabled +meaning that nnU-Net will generate all possible combinations of ensembles (2 configurations per ensemble). This requires +the .npz files containing the predicted probabilities of the validation set to be present (use `nnUNetv2_train` with +`--npz` flag, see above). You can disable ensembling by setting the `--disable_ensembling` flag. + +See `nnUNetv2_find_best_configuration -h` for more options. + +nnUNetv2_find_best_configuration will also automatically determine the postprocessing that should be used. +Postprocessing in nnU-Net only considers the removal of all but the largest component in the prediction (once for +foreground vs background and once for each label/region). + +Once completed, the command will print to your console exactly what commands you need to run to make predictions. It +will also create two files in the `nnUNet_results/DATASET_NAME` folder for you to inspect: +- `inference_instructions.txt` again contains the exact commands you need to use for predictions +- `inference_information.json` can be inspected to see the performance of all configurations and ensembles, as well +as the effect of the postprocessing plus some debug information. + +### Run inference +Remember that the data located in the input folder must have the file endings as the dataset you trained the model on +and must adhere to the nnU-Net naming scheme for image files (see [dataset format](dataset_format.md) and +[inference data format](dataset_format_inference.md)!) + +`nnUNetv2_find_best_configuration` (see above) will print a string to the terminal with the inference commands you need to use. +The easiest way to run inference is to simply use these commands. + +If you wish to manually specify the configuration(s) used for inference, use the following commands: + +#### Run prediction +For each of the desired configurations, run: +``` +nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d DATASET_NAME_OR_ID -c CONFIGURATION --save_probabilities +``` + +Only specify `--save_probabilities` if you intend to use ensembling. `--save_probabilities` will make the command save the predicted +probabilities alongside of the predicted segmentation masks requiring a lot of disk space. + +Please select a separate `OUTPUT_FOLDER` for each configuration! + +Note that per default, inference will be done with all 5 folds from the cross-validation as an ensemble. We very +strongly recommend you use all 5 folds. 
Thus, all 5 folds must have been trained prior to running inference. + +If you wish to make predictions with a single model, train the `all` fold and specify it in `nnUNetv2_predict` +with `-f all` + +#### Ensembling multiple configurations +If you wish to ensemble multiple predictions (typically form different configurations), you can do so with the following command: +```bash +nnUNetv2_ensemble -i FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -np NUM_PROCESSES +``` + +You can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were +generated by `nnUNetv2_predict`. Again, `nnUNetv2_ensemble -h` will tell you more about additional options. + +#### Apply postprocessing +Finally, apply the previously determined postprocessing to the (ensembled) predictions: + +```commandline +nnUNetv2_apply_postprocessing -i FOLDER_WITH_PREDICTIONS -o OUTPUT_FOLDER --pp_pkl_file POSTPROCESSING_FILE -plans_json PLANS_FILE -dataset_json DATASET_JSON_FILE +``` + +`nnUNetv2_find_best_configuration` (or its generated `inference_instructions.txt` file) will tell you where to find +the postprocessing file. If not you can just look for it in your results folder (it's creatively named +`postprocessing.pkl`). If your source folder is from an ensemble, you also need to specify a `-plans_json` file and +a `-dataset_json` file that should be used (for single configuration predictions these are automatically copied +from the respective training). You can pick these files from any of the ensemble members. + + +## How to run inference with pretrained models +See [here](run_inference_with_pretrained_models.md) + +[//]: # (## Examples) + +[//]: # () +[//]: # (To get you started we compiled two simple to follow examples:) + +[//]: # (- run a training with the 3d full resolution U-Net on the Hippocampus dataset. See [here](documentation/training_example_Hippocampus.md).) + +[//]: # (- run inference with nnU-Net's pretrained models on the Prostate dataset. See [here](documentation/inference_example_Prostate.md).) + +[//]: # () +[//]: # (Usability not good enough? Let us know!) diff --git a/nnUNet/documentation/installation_instructions.md b/nnUNet/documentation/installation_instructions.md new file mode 100644 index 0000000..409e5eb --- /dev/null +++ b/nnUNet/documentation/installation_instructions.md @@ -0,0 +1,87 @@ +# System requirements + +## Operating System +nnU-Net has been tested on Linux (Ubuntu 18.04, 20.04, 22.04; centOS, RHEL), Windows and MacOS! It should work out of the box! + +## Hardware requirements +We support GPU (recommended), CPU and Apple M1/M2 as devices (currently Apple mps does not implement 3D +convolutions, so you might have to use the CPU on those devices). + +### Hardware requirements for Training +We recommend you use a GPU for training as this will take a really long time on CPU or MPS (Apple M1/M2). +For training a GPU with at least 10 GB (popular non-datacenter options are the RTX 2080ti, RTX 3080/3090 or RTX 4080/4090) is +required. We also recommend a strong CPU to go along with the GPU. 6 cores (12 threads) +are the bare minimum! CPU requirements are mostly related to data augmentation and scale with the number of +input channels and target structures. Plus, the faster the GPU, the better the CPU should be! + +### Hardware Requirements for inference +Again we recommend a GPU to make predictions as this will be substantially faster than the other options. However, +inference times are typically still manageable on CPU and MPS (Apple M1/M2). 
If using a GPU, it should have at least +4 GB of available (unused) VRAM. + +### Example hardware configurations +Example workstation configurations for training: +- CPU: Ryzen 5800X - 5900X or 7900X would be even better! We have not yet tested Intel Alder/Raptor lake but they will likely work as well. +- GPU: RTX 3090 or RTX 4090 +- RAM: 64GB +- Storage: SSD (M.2 PCIe Gen 3 or better!) + +Example Server configuration for training: +- CPU: 2x AMD EPYC7763 for a total of 128C/256T. 16C/GPU are highly recommended for fast GPUs such as the A100! +- GPU: 8xA100 PCIe (price/performance superior to SXM variant + they use less power) +- RAM: 1 TB +- Storage: local SSD storage (PCIe Gen 3 or better) or ultra fast network storage + +(nnU-net by default uses one GPU per training. The server configuration can run up to 8 model trainings simultaneously) + +### Setting the correct number of Workers for data augmentation (training only) +Note that you will need to manually set the number of processes nnU-Net uses for data augmentation according to your +CPU/GPU ratio. For the server above (256 threads for 8 GPUs), a good value would be 24-30. You can do this by +setting the `nnUNet_n_proc_DA` environment variable (`export nnUNet_n_proc_DA=XX`). +Recommended values (assuming a recent CPU with good IPC) are 10-12 for RTX 2080 ti, 12 for a RTX 3090, 16-18 for +RTX 4090, 28-32 for A100. Optimal values may vary depending on the number of input channels/modalities and number of classes. + +# Installation instructions +We strongly recommend that you install nnU-Net in a virtual environment! Pip or anaconda are both fine. If you choose to +compile PyTorch from source (see below), you will need to use conda instead of pip. + +Use a recent version of Python! 3.9 or newer is guaranteed to work! + +**nnU-Net v2 can coexist with nnU-Net v1! Both can be installed at the same time.** + +1) Install [PyTorch](https://pytorch.org/get-started/locally/) as described on their website (conda/pip). Please +install the latest version with support for your hardware (cuda, mps, cpu). +**DO NOT JUST `pip install nnunetv2` WITHOUT PROPERLY INSTALLING PYTORCH FIRST**. For maximum speed, consider +[compiling pytorch yourself](https://github.com/pytorch/pytorch#from-source) (experienced users only!). +2) Install nnU-Net depending on your use case: + 1) For use as **standardized baseline**, **out-of-the-box segmentation algorithm** or for running + **inference with pretrained models**: + + ```pip install nnunetv2``` + + 2) For use as integrative **framework** (this will create a copy of the nnU-Net code on your computer so that you + can modify it as needed): + ```bash + git clone https://github.com/MIC-DKFZ/nnUNet.git + cd nnUNet + pip install -e . + ``` +3) nnU-Net needs to know where you intend to save raw data, preprocessed data and trained models. For this you need to + set a few environment variables. Please follow the instructions [here](setting_up_paths.md). +4) (OPTIONAL) Install [hiddenlayer](https://github.com/waleedka/hiddenlayer). hiddenlayer enables nnU-net to generate + plots of the network topologies it generates (see [Model training](how_to_use_nnunet.md#model-training)). +To install hiddenlayer, + run the following command: + ```bash + pip install --upgrade git+https://github.com/FabianIsensee/hiddenlayer.git + ``` + +Installing nnU-Net will add several new commands to your terminal. These commands are used to run the entire nnU-Net +pipeline. You can execute them from any location on your system. 
All nnU-Net commands have the prefix `nnUNetv2_` for +easy identification. + +Note that these commands simply execute python scripts. If you installed nnU-Net in a virtual environment, this +environment must be activated when executing the commands. You can see what scripts/functions are executed by +checking the entry_points in the setup.py file. + +All nnU-Net commands have a `-h` option which gives information on how to use them. diff --git a/nnUNet/documentation/manual_data_splits.md b/nnUNet/documentation/manual_data_splits.md new file mode 100644 index 0000000..f299c02 --- /dev/null +++ b/nnUNet/documentation/manual_data_splits.md @@ -0,0 +1,46 @@ +# How to generate custom splits in nnU-Net + +Sometimes, the default 5-fold cross-validation split by nnU-Net does not fit a project. Maybe you want to run 3-fold +cross-validation instead? Or maybe your training cases cannot be split randomly and require careful stratification. +Fear not, for nnU-Net has got you covered (it really can do anything <3). + +The splits nnU-Net uses are generated in the `do_split` function of nnUNetTrainer. This function will first look for +existing splits, stored as a file, and if no split exists it will create one. So if you wish to influence the split, +manually creating a split file that will then be recognized and used is the way to go! + +The split file is located in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. So it is best practice to first +populate this folder by running `nnUNetv2_plan_and_preproccess`. + +Splits are stored as a .json file. They are a simple python list. The length of that list is the number of splits it +contains (so it's 5 in the default nnU-Net). Each list entry is a dictionary with keys 'train' and 'val'. Values are +again simply lists with the train identifiers in each set. To illustrate this, I am just messing with the Dataset002 +file as an example: + +```commandline +In [1]: from batchgenerators.utilities.file_and_folder_operations import load_json + +In [2]: splits = load_json('splits_final.json') + +In [3]: len(splits) +Out[3]: 5 + +In [4]: splits[0].keys() +Out[4]: dict_keys(['train', 'val']) + +In [5]: len(splits[0]['train']) +Out[5]: 16 + +In [6]: len(splits[0]['val']) +Out[6]: 4 + +In [7]: print(splits[0]) +{'train': ['la_003', 'la_004', 'la_005', 'la_009', 'la_010', 'la_011', 'la_014', 'la_017', 'la_018', 'la_019', 'la_020', 'la_022', 'la_023', 'la_026', 'la_029', 'la_030'], +'val': ['la_007', 'la_016', 'la_021', 'la_024']} +``` + +If you are still not sure what splits are supposed to look like, simply download some reference dataset from the +[Medical Decathlon](http://medicaldecathlon.com/), start some training (to generate the splits) and manually inspect +the .json file with your text editor of choice! + +In order to generate your custom splits, all you need to do is reproduce the data structure explained above and save it as +`splits_final.json` in the `nnUNet_preprocessed/DATASETXXX_NAME` folder. Then use `nnUNetv2_train` etc. as usual. \ No newline at end of file diff --git a/nnUNet/documentation/pretraining_and_finetuning.md b/nnUNet/documentation/pretraining_and_finetuning.md new file mode 100644 index 0000000..44b46dc --- /dev/null +++ b/nnUNet/documentation/pretraining_and_finetuning.md @@ -0,0 +1,82 @@ +# Pretraining with nnU-Net + +## Intro + +So far nnU-Net only supports supervised pre-training, meaning that you train a regular nnU-Net on some source dataset +and then use the final network weights as initialization for your target dataset. 
+
+As a reminder, many training hyperparameters such as patch size and network topology differ between datasets as a
+result of the automated dataset analysis and experiment planning nnU-Net is known for. So, out of the box, it is not
+possible to simply take the network weights from some dataset and then reuse them for another.
+
+Consequently, the plans need to be aligned between the two tasks. In this README we show how this can be achieved and
+how the resulting weights can then be used for initialization.
+
+### Terminology
+
+Throughout this README we use the following terminology:
+
+- `source dataset` is the dataset you intend to run the pretraining on
+- `target dataset` is the dataset you are interested in; the one you wish to fine tune on
+
+
+## Pretraining on the source dataset
+
+In order to obtain matching network topologies we need to transfer the plans from one dataset to another. Since we are
+only interested in the target dataset, we first need to run experiment planning (and preprocessing) for it:
+
+```bash
+nnUNetv2_plan_and_preprocess -d TARGET_DATASET
+```
+
+Then we need to extract the dataset fingerprint of the source dataset, if not yet available:
+
+```bash
+nnUNetv2_extract_fingerprint -d SOURCE_DATASET
+```
+
+Now we can take the plans from the target dataset and transfer them to the source:
+
+```bash
+nnUNetv2_move_plans_between_datasets -s SOURCE_DATASET -t TARGET_DATASET -sp SOURCE_PLANS_IDENTIFIER -tp TARGET_PLANS_IDENTIFIER
+```
+
+`SOURCE_PLANS_IDENTIFIER` is hereby probably nnUNetPlans unless you changed the experiment planner in
+nnUNetv2_plan_and_preprocess. For `TARGET_PLANS_IDENTIFIER` we recommend you set something custom in order to not
+overwrite default plans.
+
+Note that EVERYTHING is transferred between the datasets. Not just the network topology, batch size and patch size but
+also the normalization scheme! Therefore, a transfer between datasets that use different normalization schemes may not
+work well (but it could, depending on the schemes!).
+
+Note on CT normalization: Yes, also the clip values, mean and std are transferred!
+
+Now you can run the preprocessing on the source task:
+
+```bash
+nnUNetv2_preprocess -d SOURCE_DATASET -plans_name TARGET_PLANS_IDENTIFIER
+```
+
+And run the training as usual:
+
+```bash
+nnUNetv2_train SOURCE_DATASET CONFIG all -p TARGET_PLANS_IDENTIFIER
+```
+
+Note how we use the 'all' fold to train on all available data. For pretraining it does not make sense to split the data.
+
+## Using pretrained weights
+
+Once pretraining is completed (or you obtain compatible weights by other means) you can use them to initialize your model:
+
+```bash
+nnUNetv2_train TARGET_DATASET CONFIG FOLD -pretrained_weights PATH_TO_CHECKPOINT
+```
+
+Specify the checkpoint in PATH_TO_CHECKPOINT.
+
+When loading pretrained weights, all layers except the segmentation layers will be used!
+
+So far there are no specific nnUNet trainers for fine tuning, so the current recommendation is to just use
+nnUNetTrainer. You can however easily write your own trainers with learning rate ramp up, fine-tuning of segmentation
+heads or shorter training time.
\ No newline at end of file
diff --git a/nnUNet/documentation/region_based_training.md b/nnUNet/documentation/region_based_training.md
new file mode 100644
index 0000000..265c890
--- /dev/null
+++ b/nnUNet/documentation/region_based_training.md
@@ -0,0 +1,75 @@
+# Region-based training
+
+## What is this about?
+In some segmentation tasks, most prominently the +[Brain Tumor Segmentation Challenge](http://braintumorsegmentation.org/), the target areas (based on which the metric +will be computed) are different from the labels provided in the training data. This is the case because for some +clinical applications, it is more relevant to detect the whole tumor, tumor core and enhancing tumor instead of the +individual labels (edema, necrosis and non-enhancing tumor, enhancing tumor). + + + +The figure shows an example BraTS case along with label-based representation of the task (top) and region-based +representation (bottom). The challenge evaluation is done on the regions. As we have shown in our +[BraTS 2018 contribution](https://arxiv.org/abs/1809.10483), directly optimizing those +overlapping areas over the individual labels yields better scoring models! + +## What can nnU-Net do? +nnU-Net's region-based training allows you to learn areas that are constructed by merging individual labels. For +some segmentation tasks this provides a benefit, as this shifts the importance allocated to different labels during training. +Most prominently, this feature can be used to represent **hierarchical classes**, for example when organs + +substructures are to be segmented. Imagine a liver segmentation problem, where vessels and tumors are also to be +segmented. The first target region could thus be the entire liver (including the substructures), while the remaining +targets are the individual substructues. + +Important: nnU-Net still requires integer label maps as input and will produce integer label maps as output! +Region-based training can be used to learn overlapping labels, but there must be a way to model these overlaps +for nnU-Net to work (see below how this is done). + +## How do you use it? + +When declaring the labels in the `dataset.json` file, BraTS would typically look like this: + +```python +... +"labels": { + "background": 0, + "edema": 1, + "non_enhancing_and_necrosis": 2, + "enhancing_tumor": 3 +}, +... +``` +(we use different int values than the challenge because nnU-Net needs consecutive integers!) + +This representation corresponds to the upper row in the figure above. + +For region-based training, the labels need to be changed to the following: + +```python +... +"labels": { + "background": 0, + "whole_tumor": [1, 2, 3], + "tumor_core": [2, 3], + "enhancing_tumor": 3 # or [3] +}, +"regions_class_order": [1, 2, 3], +... +``` +This corresponds to the bottom row in the figure above. Note how an additional entry in the dataset.json is +required: `regions_class_order`. This tells nnU-Net how to convert the region representations back to an integer map. +It essentially just tells nnU-Net what labels to place for which region in what order. Concretely, for the example +given here, nnU-Net will place the label 1 for the 'whole_tumor' region, then place the label 2 where the "tumor_core" +is and finally place the label 3 in the 'enhancing_tumor' area. With each step, part of the previously set pixels +will be overwritten with the new label! So when setting your `regions_class_order`, place encompassing regions +(like whole tumor etc) first, followed by substructures. + +**IMPORTANT** Because the conversion back to a segmentation map is sensitive to the order in which the regions are +declared ("place label X in the first region") you need to make sure that this order is not perturbed! When +automatically generating the dataset.json, make sure the dictionary keys do not get sorted alphabetically! 
Set +`sort_keys=False` in `json.dump()`!!! + +nnU-Net will perform the evaluation + model selection also on the regions, not the individual labels! + +That's all. Easy, huh? \ No newline at end of file diff --git a/nnUNet/documentation/run_inference_with_pretrained_models.md b/nnUNet/documentation/run_inference_with_pretrained_models.md new file mode 100644 index 0000000..d2698a2 --- /dev/null +++ b/nnUNet/documentation/run_inference_with_pretrained_models.md @@ -0,0 +1,7 @@ +# How to run inference with pretrained models +**Important:** Pretrained weights from nnU-Net v1 are NOT compatible with V2. You will need to retrain with the new +version. But honestly, you already have a fully trained model with which you can run inference (in v1), so +just continue using that! + +Not yet available for V2 :-( +If you wish to run inference with pretrained models, check out the old nnU-Net for now. We are working on this full steam! diff --git a/nnUNet/documentation/set_environment_variables.md b/nnUNet/documentation/set_environment_variables.md new file mode 100644 index 0000000..2dc22a5 --- /dev/null +++ b/nnUNet/documentation/set_environment_variables.md @@ -0,0 +1,78 @@ +# How to set environment variables + +nnU-Net requires some environment variables so that it always knows where the raw data, preprocessed data and trained +models are. Depending on the operating system, these environment variables need to be set in different ways. + +Variables can either be set permanently (recommended!) or you can decide to set them everytime you call nnU-Net. + +# Linux & MacOS + +## Permanent +Locate the `.bashrc` file in your home folder and add the following lines to the bottom: + +```bash +export nnUNet_raw="/media/fabian/nnUNet_raw" +export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" +export nnUNet_results="/media/fabian/nnUNet_results" +``` + +(of course you need to adapt the paths to the actual folders you intend to use). +If you are using a different shell, such as zsh, you will need to find the correct script for it. For zsh this is `.zshrc`. + +## Temporary +Just execute the following lines whenever you run nnU-Net: +```bash +export nnUNet_raw="/media/fabian/nnUNet_raw" +export nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" +export nnUNet_results="/media/fabian/nnUNet_results" +``` +(of course you need to adapt the paths to the actual folders you intend to use). + +Important: These variables will be deleted if you close your terminal! They will also only apply to the current +terminal window and DO NOT transfer to other terminals! + +Alternatively you can also just prefix them to your nnU-Net commands: + +`nnUNet_results="/media/fabian/nnUNet_results" nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" nnUNetv2_train[...]` + +## Verify that environment parameters are set +You can always execute `echo ${nnUNet_raw}` etc to print the environment variables. This will return an empty string if +they were not set. + +# Windows +Useful links: +- [https://www3.ntu.edu.sg](https://www3.ntu.edu.sg/home/ehchua/programming/howto/Environment_Variables.html#:~:text=To%20set%20(or%20change)%20a,it%20to%20an%20empty%20string.) +- [https://phoenixnap.com](https://phoenixnap.com/kb/windows-set-environment-variable) + +## Permanent +See `Set Environment Variable in Windows via GUI` [here](https://phoenixnap.com/kb/windows-set-environment-variable). +Or read about setx (command prompt). 
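For example, from a command prompt (adapt the paths to your own folders; note that `setx` only affects terminals opened afterwards, not the current session):

```commandline
setx nnUNet_raw "C:\Users\fabian\nnUNet_raw"
setx nnUNet_preprocessed "C:\Users\fabian\nnUNet_preprocessed"
setx nnUNet_results "C:\Users\fabian\nnUNet_results"
```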
+ +## Temporary +Just execute the following before you run nnU-Net: + +(powershell) +```powershell +$Env:nnUNet_raw = "/media/fabian/nnUNet_raw" +$Env:nnUNet_preprocessed = "/media/fabian/nnUNet_preprocessed" +$Env:nnUNet_results = "/media/fabian/nnUNet_results" +``` + +(command prompt) +```commandline +set nnUNet_raw="/media/fabian/nnUNet_raw" +set nnUNet_preprocessed="/media/fabian/nnUNet_preprocessed" +set nnUNet_results="/media/fabian/nnUNet_results" +``` + +(of course you need to adapt the paths to the actual folders you intend to use). + +Important: These variables will be deleted if you close your session! They will also only apply to the current +window and DO NOT transfer to other sessions! + +## Verify that environment parameters are set +Printing in Windows works differently depending on the environment you are in: + +powershell: `echo $Env:[variable_name]` + +command prompt: `echo %[variable_name]%` \ No newline at end of file diff --git a/nnUNet/documentation/setting_up_paths.md b/nnUNet/documentation/setting_up_paths.md new file mode 100644 index 0000000..87f9f8c --- /dev/null +++ b/nnUNet/documentation/setting_up_paths.md @@ -0,0 +1,38 @@ +# Setting up Paths + +nnU-Net relies on environment variables to know where raw data, preprocessed data and trained model weights are stored. +To use the full functionality of nnU-Net, the following three environment variables must be set: + +1) `nnUNet_raw`: This is where you place the raw datasets. This folder will have one subfolder for each dataset names +DatasetXXX_YYY where XXX is a 3-digit identifier (such as 001, 002, 043, 999, ...) and YYY is the (unique) +dataset name. The datasets must be in nnU-Net format, see [here](dataset_format.md). + + Example tree structure: + ``` + nnUNet_raw/Dataset001_NAME1 + ├── dataset.json + ├── imagesTr + │   ├── ... + ├── imagesTs + │   ├── ... + └── labelsTr + ├── ... + nnUNet_raw/Dataset002_NAME2 + ├── dataset.json + ├── imagesTr + │   ├── ... + ├── imagesTs + │   ├── ... + └── labelsTr + ├── ... + ``` + +2) `nnUNet_preprocessed`: This is the folder where the preprocessed data will be saved. The data will also be read from +this folder during training. It is important that this folder is located on a drive with low access latency and high +throughput (such as a nvme SSD (PCIe gen 3 is sufficient)). + +3) `nnUNet_results`: This specifies where nnU-Net will save the model weights. If pretrained models are downloaded, this +is where it will save them. + +### How to set environment variables +See [here](set_environment_variables.md). \ No newline at end of file diff --git a/nnUNet/documentation/tldr_migration_guide_from_v1.md b/nnUNet/documentation/tldr_migration_guide_from_v1.md new file mode 100644 index 0000000..f9ec951 --- /dev/null +++ b/nnUNet/documentation/tldr_migration_guide_from_v1.md @@ -0,0 +1,20 @@ +# TLDR Migration Guide from nnU-Net V1 + +- nnU-Net V2 can be installed simultaneously with V1. They won't get in each other's way +- The environment variables needed for V2 have slightly different names. Read [this](setting_up_paths.md). +- nnU-Net V2 datasets are called DatasetXXX_NAME. Not Task. +- Datasets have the same structure (imagesTr, labelsTr, dataset.json) but we now support more +[file types](dataset_format.md#supported-file-formats). The dataset.json is simplified. Use `generate_dataset_json` +from nnunetv2.dataset_conversion.generate_dataset_json.py. +- Careful: labels are now no longer declared as value:name but name:value. 
This has to do with [hierarchical labels](region_based_training.md). +- nnU-Net v2 commands start with `nnUNetv2...`. They work mostly (but not entirely) the same. Just use the `-h` option. +- You can transfer your V1 raw datasets to V2 with `nnUNetv2_convert_old_nnUNet_dataset`. You cannot transfer trained +models. Continue to use the old nnU-Net Version for making inference with those. +- These are the commands you are most likely to be using (in that order) + - `nnUNetv2_plan_and_preprocess`. Example: `nnUNetv2_plan_and_preprocess -d 2` + - `nnUNetv2_train`. Example: `nnUNetv2_train 2 3d_fullres 0` + - `nnUNetv2_find_best_configuration`. Example: `nnUNetv2_find_best_configuration 2 -c 2d 3d_fullres`. This command + will now create a `inference_instructions.txt` file in your `nnUNet_preprocessed/DatasetXXX_NAME/` folder which + tells you exactly how to do inference. + - `nnUNetv2_predict`. Example: `nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -c 3d_fullres -d 2` + - `nnUNetv2_apply_postprocessing` (see inference_instructions.txt) diff --git a/nnUNet/nnunetv2/__init__.py b/nnUNet/nnunetv2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/batch_running/__init__.py b/nnUNet/nnunetv2/batch_running/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/__init__.py b/nnUNet/nnunetv2/batch_running/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py b/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py new file mode 100644 index 0000000..ca37206 --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/benchmarking/generate_benchmarking_commands.py @@ -0,0 +1,41 @@ +if __name__ == '__main__': + """ + This code probably only works within the DKFZ infrastructure (using LSF). You will need to adapt it to your scheduler! 
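    It writes one `bsub` submission command per combination of GPU model, trainer, plans identifier, dataset,
    configuration and fold to `output_file`.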
+ """ + gpu_models = [#'NVIDIAA100_PCIE_40GB', 'NVIDIAGeForceRTX2080Ti', 'NVIDIATITANRTX', 'TeslaV100_SXM2_32GB', + 'NVIDIAA100_SXM4_40GB']#, 'TeslaV100_PCIE_32GB'] + datasets = [2, 3, 4, 5] + trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] + plans = ['nnUNetPlans'] + configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] + num_gpus = 1 + + benchmark_configurations = {d: configs for d in datasets} + + exclude_hosts = "-R \"select[hname!='e230-dgxa100-1']'\"" + resources = "-R \"tensorcore\"" + queue = "-q gpu" + preamble = "-L /bin/bash \"source ~/load_env_torch210.sh && " + train_command = 'nnUNet_compile=False nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_benchmark nnUNetv2_train' + + folds = (0, ) + + use_these_modules = { + tr: plans for tr in trainers + } + + additional_arguments = f' -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for g in gpu_models: + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmodel={g}" + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in benchmark_configurations.keys(): + for config in benchmark_configurations[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + f.write(f'{command}\"\n') \ No newline at end of file diff --git a/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py b/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py new file mode 100644 index 0000000..d966321 --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/benchmarking/summarize_benchmark_results.py @@ -0,0 +1,70 @@ +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.file_path_utilities import get_output_folder + +if __name__ == '__main__': + trainers = ['nnUNetTrainerBenchmark_5epochs', 'nnUNetTrainerBenchmark_5epochs_noDataLoading'] + datasets = [2, 3, 4, 5] + plans = ['nnUNetPlans'] + configs = ['2d', '2d_bs3x', '2d_bs6x', '3d_fullres', '3d_fullres_bs3x', '3d_fullres_bs6x'] + output_file = join(nnUNet_results, 'benchmark_results.csv') + + torch_version = '2.1.0.dev20230330'#"2.0.0"#"2.1.0.dev20230328" #"1.11.0a0+gitbc2c6ed" # + cudnn_version = 8700 # 8302 # + num_gpus = 1 + + unique_gpus = set() + + # collect results in the most janky way possible. Amazing coding skills! 
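    # nested structure being built below: all_results[trainer][plans][config][dataset][gpu_name] -> benchmark entry,
    # keeping only entries that match the torch version, cuDNN version and number of GPUs selected above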
+ all_results = {} + for tr in trainers: + all_results[tr] = {} + for p in plans: + all_results[tr][p] = {} + for c in configs: + all_results[tr][p][c] = {} + for d in datasets: + dataset_name = maybe_convert_to_dataset_name(d) + output_folder = get_output_folder(dataset_name, tr, p, c, fold=0) + expected_benchmark_file = join(output_folder, 'benchmark_result.json') + all_results[tr][p][c][d] = {} + if isfile(expected_benchmark_file): + # filter results for what we want + results = [i for i in load_json(expected_benchmark_file).values() + if i['num_gpus'] == num_gpus and i['cudnn_version'] == cudnn_version and + i['torch_version'] == torch_version] + for r in results: + all_results[tr][p][c][d][r['gpu_name']] = r + unique_gpus.add(r['gpu_name']) + + # haha. Fuck this. Collect GPUs in the code above. + # unique_gpus = np.unique([i["gpu_name"] for tr in trainers for p in plans for c in configs for d in datasets for i in all_results[tr][p][c][d]]) + + unique_gpus = list(unique_gpus) + unique_gpus.sort() + + with open(output_file, 'w') as f: + f.write('Dataset,Trainer,Plans,Config') + for g in unique_gpus: + f.write(f",{g}") + f.write("\n") + for d in datasets: + for tr in trainers: + for p in plans: + for c in configs: + gpu_results = [] + for g in unique_gpus: + if g in all_results[tr][p][c][d].keys(): + gpu_results.append(round(all_results[tr][p][c][d][g]["fastest_epoch"], ndigits=2)) + else: + gpu_results.append("MISSING") + # skip if all are missing + if all([i == 'MISSING' for i in gpu_results]): + continue + f.write(f"{d},{tr},{p},{c}") + for g in gpu_results: + f.write(f",{g}") + f.write("\n") + f.write("\n") + diff --git a/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py new file mode 100644 index 0000000..e5079bd --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon.py @@ -0,0 +1,114 @@ +from typing import Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.evaluation.evaluate_predictions import load_summary_json +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +def collect_results(trainers: dict, datasets: List, output_file: str, + configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"), + folds=tuple(np.arange(5))): + results_dirs = (nnUNet_results,) + datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets] + with open(output_file, 'w') as f: + for i, d in zip(datasets, datasets_names): + for c in configurations: + for module in trainers.keys(): + for plans in trainers[module]: + for r in results_dirs: + expected_output_folder = get_output_folder(d, module, plans, c) + if isdir(expected_output_folder): + results_folds = [] + f.write("%s,%s,%s,%s,%s" % (d, c, module, plans, r)) + for fl in folds: + expected_output_folder_fold = get_output_folder(d, module, plans, c, fl) + expected_summary_file = join(expected_output_folder_fold, "validation", + "summary.json") + if not isfile(expected_summary_file): + print('expected output file not found:', expected_summary_file) + f.write(",") + results_folds.append(np.nan) + else: + foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][ + 'Dice'] + results_folds.append(foreground_mean) + f.write(",%02.4f" % foreground_mean) + 
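                                # close the CSV row with the mean Dice over folds; np.nanmean ignores folds whose
                                # summary.json was missing (those were recorded as np.nan above)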
f.write(",%02.4f\n" % np.nanmean(results_folds)) + + +def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers): + txt = np.loadtxt(input_file, dtype=str, delimiter=',') + num_folds = txt.shape[1] - 6 + valid_configs = {} + for d in datasets: + if isinstance(d, int): + d = maybe_convert_to_dataset_name(d) + configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d]) + valid_configs[d] = [i for i in configs_in_txt if i in configs] + assert max(folds) < num_folds + + with open(output_file, 'w') as f: + f.write("name") + for d in valid_configs.keys(): + for c in valid_configs[d]: + f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4])) + f.write(',mean\n') + valid_entries = txt[:, 4] == nnUNet_results + for t in trainers.keys(): + trainer_locs = valid_entries & (txt[:, 2] == t) + for pl in trainers[t]: + f.write("%s__%s" % (t, pl)) + trainer_plan_locs = trainer_locs & (txt[:, 3] == pl) + r = [] + for d in valid_configs.keys(): + trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d) + for v in valid_configs[d]: + trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v) + if np.any(trainer_plan_d_config_locs): + # we cannot have more than one row + assert np.sum(trainer_plan_d_config_locs) == 1 + + # now check that we have all folds + selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]] + + fold_results = selected_row[[i + 5 for i in folds]] + + if '' in fold_results: + print('missing fold in', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + else: + mean_dice = np.mean([float(i) for i in fold_results]) + f.write(",%02.4f" % mean_dice) + r.append(mean_dice) + else: + print('missing:', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + f.write(",%02.4f\n" % np.mean(r)) + + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), + 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + } + all_results_file= join(nnUNet_results, 'customDecResults.csv') + datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, 1, 2, 3, 4) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + + folds = (0, ) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + diff --git a/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py new file mode 100644 index 0000000..2795d3d --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/collect_results_custom_Decathlon_2d.py @@ -0,0 +1,18 @@ +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.batch_running.collect_results_custom_Decathlon import collect_results, summarize +from nnunetv2.paths import nnUNet_results + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans', ), + } + all_results_file = join(nnUNet_results, 'hrnet_results.csv') + datasets = [2, 3, 4, 17, 20, 24, 27, 38, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, ) + configs = ('2d', ) + output_file = join(nnUNet_results, 'hrnet_results_summary_fold0.csv') + summarize(all_results_file, 
output_file, folds, configs, datasets, use_these_trainers) + diff --git a/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py b/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py new file mode 100644 index 0000000..0a75fbd --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/generate_lsf_runs_customDecathlon.py @@ -0,0 +1,86 @@ +from copy import deepcopy +import numpy as np + + +def merge(dict1, dict2): + keys = np.unique(list(dict1.keys()) + list(dict2.keys())) + keys = np.unique(keys) + res = {} + for k in keys: + all_configs = [] + if dict1.get(k) is not None: + all_configs += list(dict1[k]) + if dict2.get(k) is not None: + all_configs += list(dict2[k]) + if len(all_configs) > 0: + res[k] = tuple(np.unique(all_configs)) + return res + + +if __name__ == "__main__": + # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of + # datasets for evaluation and future development + configurations_all = { + 2: ("3d_fullres", "2d"), + 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 4: ("2d", "3d_fullres"), + 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 20: ("2d", "3d_fullres"), + 24: ("2d", "3d_fullres"), + 27: ("2d", "3d_fullres"), + 38: ("2d", "3d_fullres"), + 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 82: ("2d", "3d_fullres"), + # 83: ("2d", "3d_fullres"), + } + + configurations_3d_fr_only = { + i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] + } + + configurations_3d_c_only = { + i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] + } + + configurations_3d_lr_only = { + i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] + } + + configurations_2d_only = { + i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] + } + + num_gpus = 1 + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\" -R \"select[hname!='e230-dgx1-1']\" -R \"select[hname!='e230-dgxa100-1']\" -R \"select[hname!='e230-dgxa100-2']\" -R \"select[hname!='e230-dgxa100-3']\" -R \"select[hname!='e230-dgxa100-4']\"" + resources = "-R \"tensorcore\"" + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=33G" + queue = "-q gpu-lowprio" + preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " + train_command = 'nnUNet_results=/dkfz/cluster/gpu/checkpoints/OE0441/isensee/nnUNet_results_remake_release nnUNetv2_train' + + folds = (0, ) + # use_this = configurations_2d_only + use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + # use_this = merge(use_this, configurations_3d_c_only) + + use_these_modules = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainerDiceCELoss_noSmooth': ('nnUNetPlans',), + # 'nnUNetTrainer_DASegOrd0': ('nnUNetPlans',), + } + + additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in use_this.keys(): + for config in use_this[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + 
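+                            # Rough sketch of one generated line (values abbreviated, assembled from the settings above; not part of the original patch):
+                            #   bsub <host excludes> -R "tensorcore" -q gpu-lowprio -gpu num=1:j_exclusive=yes:gmem=33G \
+                            #     -L /bin/bash "source ~/load_env_cluster4.sh && nnUNet_results=... nnUNetv2_train 2 3d_fullres 0 -tr nnUNetTrainer -p nnUNetPlans --disable_checkpointing -num_gpus 1"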
f.write(f'{command}\"\n') + diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/__init__.py b/nnUNet/nnunetv2/batch_running/release_trainings/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py new file mode 100644 index 0000000..f934186 --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/collect_results.py @@ -0,0 +1,113 @@ +from typing import Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.evaluation.evaluate_predictions import load_summary_json +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name, convert_dataset_name_to_id +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +def collect_results(trainers: dict, datasets: List, output_file: str, + configurations=("2d", "3d_fullres", "3d_lowres", "3d_cascade_fullres"), + folds=tuple(np.arange(5))): + results_dirs = (nnUNet_results,) + datasets_names = [maybe_convert_to_dataset_name(i) for i in datasets] + with open(output_file, 'w') as f: + for i, d in zip(datasets, datasets_names): + for c in configurations: + for module in trainers.keys(): + for plans in trainers[module]: + for r in results_dirs: + expected_output_folder = get_output_folder(d, module, plans, c) + if isdir(expected_output_folder): + results_folds = [] + f.write("%s,%s,%s,%s,%s" % (d, c, module, plans, r)) + for fl in folds: + expected_output_folder_fold = get_output_folder(d, module, plans, c, fl) + expected_summary_file = join(expected_output_folder_fold, "validation", + "summary.json") + if not isfile(expected_summary_file): + print('expected output file not found:', expected_summary_file) + f.write(",") + results_folds.append(np.nan) + else: + foreground_mean = load_summary_json(expected_summary_file)['foreground_mean'][ + 'Dice'] + results_folds.append(foreground_mean) + f.write(",%02.4f" % foreground_mean) + f.write(",%02.4f\n" % np.nanmean(results_folds)) + + +def summarize(input_file, output_file, folds: Tuple[int, ...], configs: Tuple[str, ...], datasets, trainers): + txt = np.loadtxt(input_file, dtype=str, delimiter=',') + num_folds = txt.shape[1] - 6 + valid_configs = {} + for d in datasets: + if isinstance(d, int): + d = maybe_convert_to_dataset_name(d) + configs_in_txt = np.unique(txt[:, 1][txt[:, 0] == d]) + valid_configs[d] = [i for i in configs_in_txt if i in configs] + assert max(folds) < num_folds + + with open(output_file, 'w') as f: + f.write("name") + for d in valid_configs.keys(): + for c in valid_configs[d]: + f.write(",%d_%s" % (convert_dataset_name_to_id(d), c[:4])) + f.write(',mean\n') + valid_entries = txt[:, 4] == nnUNet_results + for t in trainers.keys(): + trainer_locs = valid_entries & (txt[:, 2] == t) + for pl in trainers[t]: + f.write("%s__%s" % (t, pl)) + trainer_plan_locs = trainer_locs & (txt[:, 3] == pl) + r = [] + for d in valid_configs.keys(): + trainer_plan_d_locs = trainer_plan_locs & (txt[:, 0] == d) + for v in valid_configs[d]: + trainer_plan_d_config_locs = trainer_plan_d_locs & (txt[:, 1] == v) + if np.any(trainer_plan_d_config_locs): + # we cannot have 
more than one row + assert np.sum(trainer_plan_d_config_locs) == 1 + + # now check that we have all folds + selected_row = txt[np.argwhere(trainer_plan_d_config_locs)[0,0]] + + fold_results = selected_row[[i + 5 for i in folds]] + + if '' in fold_results: + print('missing fold in', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + else: + mean_dice = np.mean([float(i) for i in fold_results]) + f.write(",%02.4f" % mean_dice) + r.append(mean_dice) + else: + print('missing:', t, pl, d, v) + f.write(",nan") + r.append(np.nan) + f.write(",%02.4f\n" % np.mean(r)) + + +if __name__ == '__main__': + use_these_trainers = { + 'nnUNetTrainer': ('nnUNetPlans',), + 'nnUNetTrainer_v1loss': ('nnUNetPlans',), + } + all_results_file = join(nnUNet_results, 'customDecResults.csv') + datasets = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 20, 24, 27, 35, 38, 48, 55, 64, 82] + collect_results(use_these_trainers, datasets, all_results_file) + + folds = (0, 1, 2, 3, 4) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summary5fold.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + + folds = (0, ) + configs = ("3d_fullres", "3d_lowres") + output_file = join(nnUNet_results, 'customDecResults_summaryfold0.csv') + summarize(all_results_file, output_file, folds, configs, datasets, use_these_trainers) + diff --git a/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py new file mode 100644 index 0000000..7c5934f --- /dev/null +++ b/nnUNet/nnunetv2/batch_running/release_trainings/nnunetv2_v1/generate_lsf_commands.py @@ -0,0 +1,93 @@ +from copy import deepcopy +import numpy as np + + +def merge(dict1, dict2): + keys = np.unique(list(dict1.keys()) + list(dict2.keys())) + keys = np.unique(keys) + res = {} + for k in keys: + all_configs = [] + if dict1.get(k) is not None: + all_configs += list(dict1[k]) + if dict2.get(k) is not None: + all_configs += list(dict2[k]) + if len(all_configs) > 0: + res[k] = tuple(np.unique(all_configs)) + return res + + +if __name__ == "__main__": + # after the Nature Methods paper we switch our evaluation to a different (more stable/high quality) set of + # datasets for evaluation and future development + configurations_all = { + # 1: ("3d_fullres", "2d"), + 2: ("3d_fullres", "2d"), + # 3: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 4: ("2d", "3d_fullres"), + 5: ("2d", "3d_fullres"), + # 6: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 7: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 8: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 9: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 10: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 17: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + 20: ("2d", "3d_fullres"), + 24: ("2d", "3d_fullres"), + 27: ("2d", "3d_fullres"), + 35: ("2d", "3d_fullres"), + 38: ("2d", "3d_fullres"), + # 55: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 64: ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + # 82: ("2d", "3d_fullres"), + # 83: ("2d", "3d_fullres"), + } + + configurations_3d_fr_only = { + i: ("3d_fullres", ) for i in configurations_all if "3d_fullres" in configurations_all[i] + } + + configurations_3d_c_only = { + i: ("3d_cascade_fullres", ) for i in configurations_all if "3d_cascade_fullres" in configurations_all[i] + } + + 
configurations_3d_lr_only = { + i: ("3d_lowres", ) for i in configurations_all if "3d_lowres" in configurations_all[i] + } + + configurations_2d_only = { + i: ("2d", ) for i in configurations_all if "2d" in configurations_all[i] + } + + num_gpus = 1 + exclude_hosts = "-R \"select[hname!='e230-dgx2-2']\" -R \"select[hname!='e230-dgx2-1']\"" + resources = "-R \"tensorcore\"" + gpu_requirements = f"-gpu num={num_gpus}:j_exclusive=yes:gmem=1G" + queue = "-q gpu-lowprio" + preamble = "-L /bin/bash \"source ~/load_env_cluster4.sh && " + train_command = 'nnUNet_keep_files_open=True nnUNet_results=/dkfz/cluster/gpu/data/OE0441/isensee/nnUNet_results_remake_release_normfix nnUNetv2_train' + + folds = (0, 1, 2, 3, 4) + # use_this = configurations_2d_only + # use_this = merge(configurations_3d_fr_only, configurations_3d_lr_only) + # use_this = merge(use_this, configurations_3d_c_only) + use_this = configurations_all + + use_these_modules = { + 'nnUNetTrainer': ('nnUNetPlans',), + } + + additional_arguments = f'--disable_checkpointing -num_gpus {num_gpus}' # '' + + output_file = "/home/isensee/deleteme.txt" + with open(output_file, 'w') as f: + for tr in use_these_modules.keys(): + for p in use_these_modules[tr]: + for dataset in use_this.keys(): + for config in use_this[dataset]: + for fl in folds: + command = f'bsub {exclude_hosts} {resources} {queue} {gpu_requirements} {preamble} {train_command} {dataset} {config} {fl} -tr {tr} -p {p}' + if additional_arguments is not None and len(additional_arguments) > 0: + command += f' {additional_arguments}' + f.write(f'{command}\"\n') + diff --git a/nnUNet/nnunetv2/configuration.py b/nnUNet/nnunetv2/configuration.py new file mode 100644 index 0000000..cdc8cb6 --- /dev/null +++ b/nnUNet/nnunetv2/configuration.py @@ -0,0 +1,10 @@ +import os + +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + +default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc']) + +ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low +# resolution axis must be 3x as large as the next largest spacing) + +default_n_proc_DA = get_allowed_n_proc_DA() diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py b/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py new file mode 100644 index 0000000..569ff6f --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset027_ACDC.py @@ -0,0 +1,87 @@ +import os +import shutil +from pathlib import Path + +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def make_out_dirs(dataset_id: int, task_name="ACDC"): + dataset_name = f"Dataset{dataset_id:03d}_{task_name}" + + out_dir = Path(nnUNet_raw.replace('"', "")) / dataset_name + out_train_dir = out_dir / "imagesTr" + out_labels_dir = out_dir / "labelsTr" + out_test_dir = out_dir / "imagesTs" + + os.makedirs(out_dir, exist_ok=True) + os.makedirs(out_train_dir, exist_ok=True) + os.makedirs(out_labels_dir, exist_ok=True) + os.makedirs(out_test_dir, exist_ok=True) + + return out_dir, out_train_dir, out_labels_dir, out_test_dir + + +def copy_files(src_data_folder: Path, train_dir: Path, labels_dir: Path, test_dir: Path): + """Copy files from the ACDC dataset to the nnUNet dataset folder. 
Returns the number of training cases.""" + patients_train = sorted([f for f in (src_data_folder / "training").iterdir() if f.is_dir()]) + patients_test = sorted([f for f in (src_data_folder / "testing").iterdir() if f.is_dir()]) + + num_training_cases = 0 + # Copy training files and corresponding labels. + for patient_dir in patients_train: + for file in patient_dir.iterdir(): + if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: + # The stem is 'patient.nii', and the suffix is '.gz'. + # We split the stem and append _0000 to the patient part. + shutil.copy(file, train_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") + num_training_cases += 1 + elif file.suffix == ".gz" and "_gt" in file.name: + shutil.copy(file, labels_dir / file.name.replace("_gt", "")) + + # Copy test files. + for patient_dir in patients_test: + for file in patient_dir.iterdir(): + if file.suffix == ".gz" and "_gt" not in file.name and "_4d" not in file.name: + shutil.copy(file, test_dir / f"{file.stem.split('.')[0]}_0000.nii.gz") + + return num_training_cases + + +def convert_acdc(src_data_folder: str, dataset_id=27): + out_dir, train_dir, labels_dir, test_dir = make_out_dirs(dataset_id=dataset_id) + num_training_cases = copy_files(Path(src_data_folder), train_dir, labels_dir, test_dir) + + generate_dataset_json( + str(out_dir), + channel_names={ + 0: "cineMRI", + }, + labels={ + "background": 0, + "RV": 1, + "MLV": 2, + "LVC": 3, + }, + file_ending=".nii.gz", + num_training_cases=num_training_cases, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input_folder", + type=str, + help="The downloaded ACDC dataset dir. Should contain extracted 'training' and 'testing' folders.", + ) + parser.add_argument( + "-d", "--dataset_id", required=False, type=int, default=27, help="nnU-Net Dataset ID, default: 27" + ) + args = parser.parse_args() + print("Converting...") + convert_acdc(args.input_folder, args.dataset_id) + print("Done!") diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py b/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py new file mode 100644 index 0000000..eca22d0 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset073_Fluo_C3DH_A549_SIM.py @@ -0,0 +1,85 @@ +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +import tifffile +from batchgenerators.utilities.file_and_folder_operations import * +import shutil + + +if __name__ == '__main__': + """ + This is going to be my test dataset for working with tif as input and output images + + All we do here is copy the files and rename them. Not file conversions take place + """ + dataset_name = 'Dataset073_Fluo_C3DH_A549_SIM' + + imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') + imagests = join(nnUNet_raw, dataset_name, 'imagesTs') + labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + # we extract the downloaded train and test datasets to two separate folders and name them Fluo-C3DH-A549-SIM_train + # and Fluo-C3DH-A549-SIM_test + train_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_train' + test_source = '/home/fabian/Downloads/Fluo-C3DH-A549-SIM_test' + + # with the old nnU-Net we had to convert all the files to nifti. This is no longer required. We can just copy the + # tif files + + # tif is broken when it comes to spacing. 
No standards. Grr. So when we use tif nnU-Net expects a separate file + # that specifies the spacing. This file needs to exist for EVERY training/test case to allow for different spacings + # between files. Important! The spacing must align with the axes. + # Here when we do print(tifffile.imread('IMAGE').shape) we get (29, 300, 350). The low resolution axis is the first. + # The spacing on the website is griven in the wrong axis order. Great. + spacing = (1, 0.126, 0.126) + + # train set + for seq in ['01', '02']: + images_dir = join(train_source, seq) + seg_dir = join(train_source, seq + '_GT', 'SEG') + # if we were to be super clean we would go by IDs but here we just trust the files are sorted the correct way. + # Simpler filenames in the cell tracking challenge would be soooo nice. + images = subfiles(images_dir, suffix='.tif', sort=True, join=False) + segs = subfiles(seg_dir, suffix='.tif', sort=True, join=False) + for i, (im, se) in enumerate(zip(images, segs)): + target_name = f'{seq}_image_{i:03d}' + # we still need the '_0000' suffix for images! Otherwise we would not be able to support multiple input + # channels distributed over separate files + shutil.copy(join(images_dir, im), join(imagestr, target_name + '_0000.tif')) + # spacing file! + save_json({'spacing': spacing}, join(imagestr, target_name + '.json')) + shutil.copy(join(seg_dir, se), join(labelstr, target_name + '.tif')) + # spacing file! + save_json({'spacing': spacing}, join(labelstr, target_name + '.json')) + + # test set, same a strain just without the segmentations + for seq in ['01', '02']: + images_dir = join(test_source, seq) + images = subfiles(images_dir, suffix='.tif', sort=True, join=False) + for i, im in enumerate(images): + target_name = f'{seq}_image_{i:03d}' + shutil.copy(join(images_dir, im), join(imagests, target_name + '_0000.tif')) + # spacing file! + save_json({'spacing': spacing}, join(imagests, target_name + '.json')) + + # now we generate the dataset json + generate_dataset_json( + join(nnUNet_raw, dataset_name), + {0: 'fluorescence_microscopy'}, + {'background': 0, 'cell': 1}, + 60, + '.tif' + ) + + # custom split to ensure we are stratifying properly. 
This dataset only has 2 folds + caseids = [i[:-4] for i in subfiles(labelstr, suffix='.tif', join=False)] + splits = [] + splits.append( + {'train': [i for i in caseids if i.startswith('01_')], 'val': [i for i in caseids if i.startswith('02_')]} + ) + splits.append( + {'train': [i for i in caseids if i.startswith('02_')], 'val': [i for i in caseids if i.startswith('01_')]} + ) + save_json(splits, join(nnUNet_preprocessed, dataset_name, 'splits_final.json')) \ No newline at end of file diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py b/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py new file mode 100644 index 0000000..90dcc6c --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset120_RoadSegmentation.py @@ -0,0 +1,87 @@ +import multiprocessing +import shutil +from multiprocessing import Pool + +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw +from skimage import io +from acvl_utils.morphology.morphology_helper import generic_filter_components +from scipy.ndimage import binary_fill_holes + + +def load_and_covnert_case(input_image: str, input_seg: str, output_image: str, output_seg: str, + min_component_size: int = 50): + seg = io.imread(input_seg) + seg[seg == 255] = 1 + image = io.imread(input_image) + image = image.sum(2) + mask = image == (3 * 255) + # the dataset has large white areas in which road segmentations can exist but no image information is available. + # Remove the road label in these areas + mask = generic_filter_components(mask, filter_fn=lambda ids, sizes: [i for j, i in enumerate(ids) if + sizes[j] > min_component_size]) + mask = binary_fill_holes(mask) + seg[mask] = 0 + io.imsave(output_seg, seg, check_contrast=False) + shutil.copy(input_image, output_image) + + +if __name__ == "__main__": + # extracted archive from https://www.kaggle.com/datasets/insaff/massachusetts-roads-dataset?resource=download + source = '/media/fabian/data/raw_datasets/Massachussetts_road_seg/road_segmentation_ideal' + + dataset_name = 'Dataset120_RoadSegmentation' + + imagestr = join(nnUNet_raw, dataset_name, 'imagesTr') + imagests = join(nnUNet_raw, dataset_name, 'imagesTs') + labelstr = join(nnUNet_raw, dataset_name, 'labelsTr') + labelsts = join(nnUNet_raw, dataset_name, 'labelsTs') + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + maybe_mkdir_p(labelsts) + + train_source = join(source, 'training') + test_source = join(source, 'testing') + + with multiprocessing.get_context("spawn").Pool(8) as p: + + # not all training images have a segmentation + valid_ids = subfiles(join(train_source, 'output'), join=False, suffix='png') + num_train = len(valid_ids) + r = [] + for v in valid_ids: + r.append( + p.starmap_async( + load_and_covnert_case, + (( + join(train_source, 'input', v), + join(train_source, 'output', v), + join(imagestr, v[:-4] + '_0000.png'), + join(labelstr, v), + 50 + ),) + ) + ) + + # test set + valid_ids = subfiles(join(test_source, 'output'), join=False, suffix='png') + for v in valid_ids: + r.append( + p.starmap_async( + load_and_covnert_case, + (( + join(test_source, 'input', v), + join(test_source, 'output', v), + join(imagests, v[:-4] + '_0000.png'), + join(labelsts, v), + 50 + ),) + ) + ) + _ = [i.get() for i in r] + + generate_dataset_json(join(nnUNet_raw, dataset_name), {0: 'R', 1: 'G', 2: 'B'}, {'background': 0, 'road': 1}, + num_train, '.png', 
dataset_name=dataset_name) diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py b/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py new file mode 100644 index 0000000..b4817d2 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset137_BraTS21.py @@ -0,0 +1,98 @@ +import multiprocessing +import shutil +from multiprocessing import Pool + +import SimpleITK as sitk +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def copy_BraTS_segmentation_and_convert_labels_to_nnUNet(in_file: str, out_file: str) -> None: + # use this for segmentation only!!! + # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3 + img = sitk.ReadImage(in_file) + img_npy = sitk.GetArrayFromImage(img) + + uniques = np.unique(img_npy) + for u in uniques: + if u not in [0, 1, 2, 4]: + raise RuntimeError('unexpected label') + + seg_new = np.zeros_like(img_npy) + seg_new[img_npy == 4] = 3 + seg_new[img_npy == 2] = 1 + seg_new[img_npy == 1] = 2 + img_corr = sitk.GetImageFromArray(seg_new) + img_corr.CopyInformation(img) + sitk.WriteImage(img_corr, out_file) + + +def convert_labels_back_to_BraTS(seg: np.ndarray): + new_seg = np.zeros_like(seg) + new_seg[seg == 1] = 2 + new_seg[seg == 3] = 4 + new_seg[seg == 2] = 1 + return new_seg + + +def load_convert_labels_back_to_BraTS(filename, input_folder, output_folder): + a = sitk.ReadImage(join(input_folder, filename)) + b = sitk.GetArrayFromImage(a) + c = convert_labels_back_to_BraTS(b) + d = sitk.GetImageFromArray(c) + d.CopyInformation(a) + sitk.WriteImage(d, join(output_folder, filename)) + + +def convert_folder_with_preds_back_to_BraTS_labeling_convention(input_folder: str, output_folder: str, num_processes: int = 12): + """ + reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the + """ + maybe_mkdir_p(output_folder) + nii = subfiles(input_folder, suffix='.nii.gz', join=False) + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + p.starmap(load_convert_labels_back_to_BraTS, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii))) + + +if __name__ == '__main__': + brats_data_dir = '/home/isensee/drives/E132-Rohdaten/BraTS_2021/training' + + task_id = 137 + task_name = "BraTS2021" + + foldername = "Dataset%03.0d_%s" % (task_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(labelstr) + + case_ids = subdirs(brats_data_dir, prefix='BraTS', join=False) + + for c in case_ids: + shutil.copy(join(brats_data_dir, c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz')) + shutil.copy(join(brats_data_dir, c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz')) + + copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, c, c + "_seg.nii.gz"), + join(labelstr, c + '.nii.gz')) + + generate_dataset_json(out_base, + channel_names={0: 'T1', 1: 'T1ce', 2: 'T2', 3: 'Flair'}, + labels={ + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': (3, ) + }, + 
num_training_cases=len(case_ids), + file_ending='.nii.gz', + regions_class_order=(1, 2, 3), + license='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', + reference='see https://www.synapse.org/#!Synapse:syn25829067/wiki/610863', + dataset_release='1.0') diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py b/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py new file mode 100644 index 0000000..1f33cd7 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset218_Amos2022_task1.py @@ -0,0 +1,70 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_amos_task1(amos_base_dir: str, nnunet_dataset_id: int = 218): + """ + AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into + the train set. Having a 5-fold cross-validation is superior to a single train:val split + """ + task_name = "AMOS2022_postChallenge_task1" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + imagests = join(out_base, "imagesTs") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) + + training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] + tr_ctr = 0 + for tr in training_identifiers: + if int(tr.split("_")[-1]) <= 410: # these are the CT images + tr_ctr += 1 + shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] + for ts in test_identifiers: + if int(ts.split("_")[-1]) <= 500: # these are the CT images + shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) + + val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] + for vl in val_identifiers: + if int(vl.split("_")[-1]) <= 409: # these are the CT images + tr_ctr += 1 + shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) + + generate_dataset_json(out_base, {0: "CT"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, + num_training_cases=tr_ctr, file_ending='.nii.gz', + dataset_name=task_name, reference='https://amos22.grand-challenge.org/', + release='https://zenodo.org/record/7262581', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="This is the dataset as released AFTER the challenge event. It has the " + "validation set gt in it! We just use the validation images as additional " + "training cases because AMOS doesn't specify how they should be used. nnU-Net's" + " 5-fold CV is better than some random train:val split.") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. 
" + "Use this link: https://zenodo.org/record/7262581." + "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!") + parser.add_argument('-d', required=False, type=int, default=218, help='nnU-Net Dataset ID, default: 218') + args = parser.parse_args() + amos_base = args.input_folder + convert_amos_task1(amos_base, args.d) + + diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py b/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py new file mode 100644 index 0000000..9a5e2c6 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset219_Amos2022_task2.py @@ -0,0 +1,65 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_amos_task2(amos_base_dir: str, nnunet_dataset_id: int = 219): + """ + AMOS doesn't say anything about how the validation set is supposed to be used. So we just incorporate that into + the train set. Having a 5-fold cross-validation is superior to a single train:val split + """ + task_name = "AMOS2022_postChallenge_task2" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + imagests = join(out_base, "imagesTs") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(imagests) + maybe_mkdir_p(labelstr) + + dataset_json_source = load_json(join(amos_base_dir, 'dataset.json')) + + training_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['training']] + for tr in training_identifiers: + shutil.copy(join(amos_base_dir, 'imagesTr', tr + '.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsTr', tr + '.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + test_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['test']] + for ts in test_identifiers: + shutil.copy(join(amos_base_dir, 'imagesTs', ts + '.nii.gz'), join(imagests, f'{ts}_0000.nii.gz')) + + val_identifiers = [i['image'].split('/')[-1][:-7] for i in dataset_json_source['validation']] + for vl in val_identifiers: + shutil.copy(join(amos_base_dir, 'imagesVa', vl + '.nii.gz'), join(imagestr, f'{vl}_0000.nii.gz')) + shutil.copy(join(amos_base_dir, 'labelsVa', vl + '.nii.gz'), join(labelstr, f'{vl}.nii.gz')) + + generate_dataset_json(out_base, {0: "either_CT_or_MR"}, labels={v: int(k) for k,v in dataset_json_source['labels'].items()}, + num_training_cases=len(training_identifiers) + len(val_identifiers), file_ending='.nii.gz', + dataset_name=task_name, reference='https://amos22.grand-challenge.org/', + release='https://zenodo.org/record/7262581', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="This is the dataset as released AFTER the challenge event. It has the " + "validation set gt in it! We just use the validation images as additional " + "training cases because AMOS doesn't specify how they should be used. nnU-Net's" + " 5-fold CV is better than some random train:val split.") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted AMOS2022 (https://amos22.grand-challenge.org/) data. " + "Use this link: https://zenodo.org/record/7262581." 
+ "You need to specify the folder with the imagesTr, imagesVal, labelsTr etc subfolders here!") + parser.add_argument('-d', required=False, type=int, default=219, help='nnU-Net Dataset ID, default: 219') + args = parser.parse_args() + amos_base = args.input_folder + convert_amos_task2(amos_base, args.d) + + # /home/isensee/Downloads/amos22/amos22/ + diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py b/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py new file mode 100644 index 0000000..20a794c --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset220_KiTS2023.py @@ -0,0 +1,50 @@ +from batchgenerators.utilities.file_and_folder_operations import * +import shutil +from nnunetv2.dataset_conversion.generate_dataset_json import generate_dataset_json +from nnunetv2.paths import nnUNet_raw + + +def convert_kits2023(kits_base_dir: str, nnunet_dataset_id: int = 220): + task_name = "KiTS2023" + + foldername = "Dataset%03.0d_%s" % (nnunet_dataset_id, task_name) + + # setting up nnU-Net folders + out_base = join(nnUNet_raw, foldername) + imagestr = join(out_base, "imagesTr") + labelstr = join(out_base, "labelsTr") + maybe_mkdir_p(imagestr) + maybe_mkdir_p(labelstr) + + cases = subdirs(kits_base_dir, prefix='case_', join=False) + for tr in cases: + shutil.copy(join(kits_base_dir, tr, 'imaging.nii.gz'), join(imagestr, f'{tr}_0000.nii.gz')) + shutil.copy(join(kits_base_dir, tr, 'segmentation.nii.gz'), join(labelstr, f'{tr}.nii.gz')) + + generate_dataset_json(out_base, {0: "CT"}, + labels={ + "background": 0, + "kidney": (1, 2, 3), + "masses": (2, 3), + "tumor": 2 + }, + regions_class_order=(1, 3, 2), + num_training_cases=len(cases), file_ending='.nii.gz', + dataset_name=task_name, reference='none', + release='prerelease', + overwrite_image_reader_writer='NibabelIOWithReorient', + description="KiTS2023") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('input_folder', type=str, + help="The downloaded and extracted KiTS2023 dataset (must have case_XXXXX subfolders)") + parser.add_argument('-d', required=False, type=int, default=220, help='nnU-Net Dataset ID, default: 220') + args = parser.parse_args() + amos_base = args.input_folder + convert_kits2023(amos_base, args.d) + + # /media/isensee/raw_data/raw_datasets/kits23/dataset + diff --git a/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py b/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py new file mode 100644 index 0000000..80b295d --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/Dataset988_dummyDataset4.py @@ -0,0 +1,32 @@ +import os + +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + +if __name__ == '__main__': + # creates a dummy dataset where there are no files in imagestr and labelstr + source_dataset = 'Dataset004_Hippocampus' + + target_dataset = 'Dataset987_dummyDataset4' + target_dataset_dir = join(nnUNet_raw, target_dataset) + maybe_mkdir_p(target_dataset_dir) + + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, source_dataset)) + + # the returned dataset will have absolute paths. We should use relative paths so that you can freely copy + # datasets around between systems. As long as the source dataset is there it will continue working even if + # nnUNet_raw is in different locations + + # paths must be relative to target_dataset_dir!!! 
+ for k in dataset.keys(): + dataset[k]['label'] = os.path.relpath(dataset[k]['label'], target_dataset_dir) + dataset[k]['images'] = [os.path.relpath(i, target_dataset_dir) for i in dataset[k]['images']] + + # load old dataset.json + dataset_json = load_json(join(nnUNet_raw, source_dataset, 'dataset.json')) + dataset_json['dataset'] = dataset + + # save + save_json(dataset_json, join(target_dataset_dir, 'dataset.json'), sort_keys=False) diff --git a/nnUNet/nnunetv2/dataset_conversion/__init__.py b/nnUNet/nnunetv2/dataset_conversion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py b/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py new file mode 100644 index 0000000..97aac29 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/convert_MSD_dataset.py @@ -0,0 +1,132 @@ +import argparse +import multiprocessing +import shutil +from multiprocessing import Pool +from typing import Optional +import SimpleITK as sitk +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.dataset_name_id_conversion import find_candidate_datasets +from nnunetv2.configuration import default_num_processes +import numpy as np + + +def split_4d_nifti(filename, output_folder): + img_itk = sitk.ReadImage(filename) + dim = img_itk.GetDimension() + file_base = os.path.basename(filename) + if dim == 3: + shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz")) + return + elif dim != 4: + raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename)) + else: + img_npy = sitk.GetArrayFromImage(img_itk) + spacing = img_itk.GetSpacing() + origin = img_itk.GetOrigin() + direction = np.array(img_itk.GetDirection()).reshape(4,4) + # now modify these to remove the fourth dimension + spacing = tuple(list(spacing[:-1])) + origin = tuple(list(origin[:-1])) + direction = tuple(direction[:-1, :-1].reshape(-1)) + for i, t in enumerate(range(img_npy.shape[0])): + img = img_npy[t] + img_itk_new = sitk.GetImageFromArray(img) + img_itk_new.SetSpacing(spacing) + img_itk_new.SetOrigin(origin) + img_itk_new.SetDirection(direction) + sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % i)) + + +def convert_msd_dataset(source_folder: str, overwrite_target_id: Optional[int] = None, + num_processes: int = default_num_processes) -> None: + if source_folder.endswith('/') or source_folder.endswith('\\'): + source_folder = source_folder[:-1] + + labelsTr = join(source_folder, 'labelsTr') + imagesTs = join(source_folder, 'imagesTs') + imagesTr = join(source_folder, 'imagesTr') + assert isdir(labelsTr), f"labelsTr subfolder missing in source folder" + assert isdir(imagesTs), f"imagesTs subfolder missing in source folder" + assert isdir(imagesTr), f"imagesTr subfolder missing in source folder" + dataset_json = join(source_folder, 'dataset.json') + assert isfile(dataset_json), f"dataset.json missing in source_folder" + + # infer source dataset id and name + task, dataset_name = os.path.basename(source_folder).split('_') + task_id = int(task[4:]) + + # check if target dataset id is taken + target_id = task_id if overwrite_target_id is None else overwrite_target_id + existing_datasets = find_candidate_datasets(target_id) + assert len(existing_datasets) == 0, f"Target dataset id {target_id} is already taken, please consider changing " \ + f"it using overwrite_target_id. 
Conflicting dataset: {existing_datasets} (check nnUNet_results, nnUNet_preprocessed and nnUNet_raw!)" + + target_dataset_name = f"Dataset{target_id:03d}_{dataset_name}" + target_folder = join(nnUNet_raw, target_dataset_name) + target_imagesTr = join(target_folder, 'imagesTr') + target_imagesTs = join(target_folder, 'imagesTs') + target_labelsTr = join(target_folder, 'labelsTr') + maybe_mkdir_p(target_imagesTr) + maybe_mkdir_p(target_imagesTs) + maybe_mkdir_p(target_labelsTr) + + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + results = [] + + # convert 4d train images + source_images = [i for i in subfiles(imagesTr, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + source_images = [join(imagesTr, i) for i in source_images] + + results.append( + p.starmap_async( + split_4d_nifti, zip(source_images, [target_imagesTr] * len(source_images)) + ) + ) + + # convert 4d test images + source_images = [i for i in subfiles(imagesTs, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + source_images = [join(imagesTs, i) for i in source_images] + + results.append( + p.starmap_async( + split_4d_nifti, zip(source_images, [target_imagesTs] * len(source_images)) + ) + ) + + # copy segmentations + source_images = [i for i in subfiles(labelsTr, suffix='.nii.gz', join=False) if + not i.startswith('.') and not i.startswith('_')] + for s in source_images: + shutil.copy(join(labelsTr, s), join(target_labelsTr, s)) + + [i.get() for i in results] + + dataset_json = load_json(dataset_json) + dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} + dataset_json['file_ending'] = ".nii.gz" + dataset_json["channel_names"] = dataset_json["modality"] + del dataset_json["modality"] + del dataset_json["training"] + del dataset_json["test"] + save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) + + +def entry_point(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', type=str, required=True, + help='Downloaded and extracted MSD dataset folder. CANNOT be nnUNetv1 dataset! Example: ' + '/home/fabian/Downloads/Task05_Prostate') + parser.add_argument('-overwrite_id', type=int, required=False, default=None, + help='Overwrite the dataset id. If not set we use the id of the MSD task (inferred from ' + 'folder name). Only use this if you already have an equivalently numbered dataset!') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'Number of processes used. 
Default: {default_num_processes}') + args = parser.parse_args() + convert_msd_dataset(args.i, args.overwrite_id, args.np) + + +if __name__ == '__main__': + convert_msd_dataset('/home/fabian/Downloads/Task05_Prostate', overwrite_target_id=201) diff --git a/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py b/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py new file mode 100644 index 0000000..fb77533 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/convert_raw_dataset_from_old_nnunet_format.py @@ -0,0 +1,53 @@ +import shutil +from copy import deepcopy + +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, isdir, load_json, save_json +from nnunetv2.paths import nnUNet_raw + + +def convert(source_folder, target_dataset_name): + """ + remember that old tasks were called TaskXXX_YYY and new ones are called DatasetXXX_YYY + source_folder + """ + if isdir(join(nnUNet_raw, target_dataset_name)): + raise RuntimeError(f'Target dataset name {target_dataset_name} already exists. Aborting... ' + f'(we might break something). If you are sure you want to proceed, please manually ' + f'delete {join(nnUNet_raw, target_dataset_name)}') + maybe_mkdir_p(join(nnUNet_raw, target_dataset_name)) + shutil.copytree(join(source_folder, 'imagesTr'), join(nnUNet_raw, target_dataset_name, 'imagesTr')) + shutil.copytree(join(source_folder, 'labelsTr'), join(nnUNet_raw, target_dataset_name, 'labelsTr')) + if isdir(join(source_folder, 'imagesTs')): + shutil.copytree(join(source_folder, 'imagesTs'), join(nnUNet_raw, target_dataset_name, 'imagesTs')) + if isdir(join(source_folder, 'labelsTs')): + shutil.copytree(join(source_folder, 'labelsTs'), join(nnUNet_raw, target_dataset_name, 'labelsTs')) + if isdir(join(source_folder, 'imagesVal')): + shutil.copytree(join(source_folder, 'imagesVal'), join(nnUNet_raw, target_dataset_name, 'imagesVal')) + if isdir(join(source_folder, 'labelsVal')): + shutil.copytree(join(source_folder, 'labelsVal'), join(nnUNet_raw, target_dataset_name, 'labelsVal')) + shutil.copy(join(source_folder, 'dataset.json'), join(nnUNet_raw, target_dataset_name)) + + dataset_json = load_json(join(nnUNet_raw, target_dataset_name, 'dataset.json')) + del dataset_json['tensorImageSize'] + del dataset_json['numTest'] + del dataset_json['training'] + del dataset_json['test'] + dataset_json['channel_names'] = deepcopy(dataset_json['modality']) + del dataset_json['modality'] + + dataset_json['labels'] = {j: int(i) for i, j in dataset_json['labels'].items()} + dataset_json['file_ending'] = ".nii.gz" + save_json(dataset_json, join(nnUNet_raw, target_dataset_name, 'dataset.json'), sort_keys=False) + + +def convert_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("input_folder", type=str, + help='Raw old nnUNet dataset. This must be the folder with imagesTr,labelsTr etc subfolders! ' + 'Please provide the PATH to the old Task, not just the task name. nnU-Net V2 does not ' + 'know where v1 tasks are.') + parser.add_argument("output_dataset_name", type=str, + help='New dataset NAME (not path!). 
Must follow the DatasetXXX_NAME convention!') + args = parser.parse_args() + convert(args.input_folder, args.output_dataset_name) diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py new file mode 100644 index 0000000..d59fc8e --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py @@ -0,0 +1,75 @@ +import SimpleITK as sitk +import shutil + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json, nifti_files + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +def sparsify_segmentation(seg: np.ndarray, label_manager: LabelManager, percent_of_slices: float) -> np.ndarray: + assert label_manager.has_ignore_label, "This preprocessor only works with datasets that have an ignore label!" + seg_new = np.ones_like(seg) * label_manager.ignore_label + x, y, z = seg.shape + # x + num_slices = max(1, round(x * percent_of_slices)) + selected_slices = np.random.choice(x, num_slices, replace=False) + seg_new[selected_slices] = seg[selected_slices] + # y + num_slices = max(1, round(y * percent_of_slices)) + selected_slices = np.random.choice(y, num_slices, replace=False) + seg_new[:, selected_slices] = seg[:, selected_slices] + # z + num_slices = max(1, round(z * percent_of_slices)) + selected_slices = np.random.choice(z, num_slices, replace=False) + seg_new[:, :, selected_slices] = seg[:, :, selected_slices] + return seg_new + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_regions_ignore' + dataset_id = 996 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. 
If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # additionally optimize entire hippocampus region, remove Posterior + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels'] = { + 'background': 0, + 'hippocampus': (1, 2), + 'anterior': 1, + 'ignore': 3 + } + dj['regions_class_order'] = (2, 1) + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) + + # now add ignore label to segmentation images + np.random.seed(1234) + lm = LabelManager(label_dict=dj['labels'], regions_class_order=dj.get('regions_class_order')) + + segs = nifti_files(join(nnUNet_raw, dataset_name, 'labelsTr')) + for s in segs: + seg_itk = sitk.ReadImage(s) + seg_npy = sitk.GetArrayFromImage(seg_itk) + seg_npy = sparsify_segmentation(seg_npy, lm, 0.1 / 3) + seg_itk_new = sitk.GetImageFromArray(seg_npy) + seg_itk_new.CopyInformation(seg_itk) + sitk.WriteImage(seg_itk_new, s) + diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py new file mode 100644 index 0000000..b40c534 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py @@ -0,0 +1,37 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_regions' + dataset_id = 997 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError( + f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. 
If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # additionally optimize entire hippocampus region, remove Posterior + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels'] = { + 'background': 0, + 'hippocampus': (1, 2), + 'anterior': 1 + } + dj['regions_class_order'] = (2, 1) + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py new file mode 100644 index 0000000..1781a27 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py @@ -0,0 +1,33 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join, load_json, save_json + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus_ignore' + dataset_id = 998 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) + + # set class 2 to ignore label + dj = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dj['labels']['ignore'] = 2 + del dj['labels']['Posterior'] + save_json(dj, join(nnUNet_raw, dataset_name, 'dataset.json'), sort_keys=False) diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py new file mode 100644 index 0000000..33075da --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py @@ -0,0 +1,27 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join + +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.paths import nnUNet_raw + + +if __name__ == '__main__': + dataset_name = 'IntegrationTest_Hippocampus' + dataset_id = 999 + dataset_name = f"Dataset{dataset_id:03d}_{dataset_name}" + + try: + existing_dataset_name = maybe_convert_to_dataset_name(dataset_id) + if existing_dataset_name != dataset_name: + raise FileExistsError(f"A different dataset with id {dataset_id} already exists :-(: {existing_dataset_name}. 
If " + f"you intent to delete it, remember to also remove it in nnUNet_preprocessed and " + f"nnUNet_results!") + except RuntimeError: + pass + + if isdir(join(nnUNet_raw, dataset_name)): + shutil.rmtree(join(nnUNet_raw, dataset_name)) + + source_dataset = maybe_convert_to_dataset_name(4) + shutil.copytree(join(nnUNet_raw, source_dataset), join(nnUNet_raw, dataset_name)) diff --git a/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py b/nnUNet/nnunetv2/dataset_conversion/datasets_for_integration_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py b/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py new file mode 100644 index 0000000..2f2e115 --- /dev/null +++ b/nnUNet/nnunetv2/dataset_conversion/generate_dataset_json.py @@ -0,0 +1,103 @@ +from typing import Tuple + +from batchgenerators.utilities.file_and_folder_operations import save_json, join + + +def generate_dataset_json(output_folder: str, + channel_names: dict, + labels: dict, + num_training_cases: int, + file_ending: str, + regions_class_order: Tuple[int, ...] = None, + dataset_name: str = None, reference: str = None, release: str = None, license: str = None, + description: str = None, + overwrite_image_reader_writer: str = None, **kwargs): + """ + Generates a dataset.json file in the output folder + + channel_names: + Channel names must map the index to the name of the channel, example: + { + 0: 'T1', + 1: 'CT' + } + Note that the channel names may influence the normalization scheme!! Learn more in the documentation. + + labels: + This will tell nnU-Net what labels to expect. Important: This will also determine whether you use region-based training or not. + Example regular labels: + { + 'background': 0, + 'left atrium': 1, + 'some other label': 2 + } + Example region-based training: + { + 'background': 0, + 'whole tumor': (1, 2, 3), + 'tumor core': (2, 3), + 'enhancing tumor': 3 + } + + Remember that nnU-Net expects consecutive values for labels! nnU-Net also expects 0 to be background! + + num_training_cases: is used to double check all cases are there! + + file_ending: needed for finding the files correctly. IMPORTANT! File endings must match between images and + segmentations! + + dataset_name, reference, release, license, description: self-explanatory and not used by nnU-Net. Just for + completeness and as a reminder that these would be great! + + overwrite_image_reader_writer: If you need a special IO class for your dataset you can derive it from + BaseReaderWriter, place it into nnunet.imageio and reference it here by name + + kwargs: whatever you put here will be placed in the dataset.json as well + + """ + has_regions: bool = any([isinstance(i, (tuple, list)) and len(i) > 1 for i in labels.values()]) + if has_regions: + assert regions_class_order is not None, f"You have defined regions but regions_class_order is not set. " \ + f"You need that." + # channel names need strings as keys + keys = list(channel_names.keys()) + for k in keys: + if not isinstance(k, str): + channel_names[str(k)] = channel_names[k] + del channel_names[k] + + # labels need ints as values + for l in labels.keys(): + value = labels[l] + if isinstance(value, (tuple, list)): + value = tuple([int(i) for i in value]) + labels[l] = value + else: + labels[l] = int(labels[l]) + + dataset_json = { + 'channel_names': channel_names, # previously this was called 'modality'. I didnt like this so this is + # channel_names now. Live with it. 
+ 'labels': labels, + 'numTraining': num_training_cases, + 'file_ending': file_ending, + } + + if dataset_name is not None: + dataset_json['name'] = dataset_name + if reference is not None: + dataset_json['reference'] = reference + if release is not None: + dataset_json['release'] = release + if license is not None: + dataset_json['licence'] = license + if description is not None: + dataset_json['description'] = description + if overwrite_image_reader_writer is not None: + dataset_json['overwrite_image_reader_writer'] = overwrite_image_reader_writer + if regions_class_order is not None: + dataset_json['regions_class_order'] = regions_class_order + + dataset_json.update(kwargs) + + save_json(dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) diff --git a/nnUNet/nnunetv2/ensembling/__init__.py b/nnUNet/nnunetv2/ensembling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/ensembling/ensemble.py b/nnUNet/nnunetv2/ensembling/ensemble.py new file mode 100644 index 0000000..68b378b --- /dev/null +++ b/nnUNet/nnunetv2/ensembling/ensemble.py @@ -0,0 +1,206 @@ +import argparse +import multiprocessing +import shutil +from copy import deepcopy +from multiprocessing import Pool +from typing import List, Union, Tuple + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, subfiles, \ + maybe_mkdir_p, isdir, save_pickle, load_pickle, isfile +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def average_probabilities(list_of_files: List[str]) -> np.ndarray: + assert len(list_of_files), 'At least one file must be given in list_of_files' + avg = None + for f in list_of_files: + if avg is None: + avg = np.load(f)['probabilities'] + # maybe increase precision to prevent rounding errors + if avg.dtype != np.float32: + avg = avg.astype(np.float32) + else: + avg += np.load(f)['probabilities'] + avg /= len(list_of_files) + return avg + + +def merge_files(list_of_files, + output_filename_truncated: str, + output_file_ending: str, + image_reader_writer: BaseReaderWriter, + label_manager: LabelManager, + save_probabilities: bool = False): + # load the pkl file associated with the first file in list_of_files + properties = load_pickle(list_of_files[0][:-4] + '.pkl') + # load and average predictions + probabilities = average_probabilities(list_of_files) + segmentation = label_manager.convert_logits_to_segmentation(probabilities) + image_reader_writer.write_seg(segmentation, output_filename_truncated + output_file_ending, properties) + if save_probabilities: + np.savez_compressed(output_filename_truncated + '.npz', probabilities=probabilities) + save_pickle(probabilities, output_filename_truncated + '.pkl') + + +def ensemble_folders(list_of_input_folders: List[str], + output_folder: str, + save_merged_probabilities: bool = False, + num_processes: int = default_num_processes, + dataset_json_file_or_dict: str = None, + plans_json_file_or_dict: str = None): + """we need too much shit for this function. Problem is that we now have to support region-based training plus + multiple input/output formats so there isn't really a way around this. + + If plans and dataset json are not specified, we assume each of the folders has a corresponding plans.json + and/or dataset.json in it. 
These are usually copied into those folders by nnU-Net during prediction. + We just pick the dataset.json and plans.json from the first of the folders and we DONT check whether the 5 + folders contain the same plans etc! This can be a feature if results from different datasets are to be merged (only + works if label dict in dataset.json is the same between these datasets!!!)""" + if dataset_json_file_or_dict is not None: + if isinstance(dataset_json_file_or_dict, str): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + else: + dataset_json = load_json(join(list_of_input_folders[0], 'dataset.json')) + + if plans_json_file_or_dict is not None: + if isinstance(plans_json_file_or_dict, str): + plans = load_json(plans_json_file_or_dict) + else: + plans = plans_json_file_or_dict + else: + plans = load_json(join(list_of_input_folders[0], 'plans.json')) + + plans_manager = PlansManager(plans) + + # now collect the files in each of the folders and enforce that all files are present in all folders + files_per_folder = [set(subfiles(i, suffix='.npz', join=False)) for i in list_of_input_folders] + # first build a set with all files + s = deepcopy(files_per_folder[0]) + for f in files_per_folder[1:]: + s.update(f) + for f in files_per_folder: + assert len(s.difference(f)) == 0, "Not all folders contain the same files for ensembling. Please only " \ + "provide folders that contain the predictions" + lists_of_lists_of_files = [[join(fl, fi) for fl in list_of_input_folders] for fi in s] + output_files_truncated = [join(output_folder, fi[:-4]) for fi in s] + + image_reader_writer = plans_manager.image_reader_writer_class() + label_manager = plans_manager.get_label_manager(dataset_json) + + maybe_mkdir_p(output_folder) + shutil.copy(join(list_of_input_folders[0], 'dataset.json'), output_folder) + + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + num_preds = len(s) + _ = pool.starmap( + merge_files, + zip( + lists_of_lists_of_files, + output_files_truncated, + [dataset_json['file_ending']] * num_preds, + [image_reader_writer] * num_preds, + [label_manager] * num_preds, + [save_merged_probabilities] * num_preds + ) + ) + + +def entry_point_ensemble_folders(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', nargs='+', type=str, required=True, + help='list of input folders') + parser.add_argument('-o', type=str, required=True, help='output folder') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"Numbers of processes used for ensembling. 
Default: {default_num_processes}") + parser.add_argument('--save_npz', action='store_true', required=False, help='Set this flag to store output ' + 'probabilities in separate .npz files') + + args = parser.parse_args() + ensemble_folders(args.i, args.o, args.save_npz, args.np) + + +def ensemble_crossvalidations(list_of_trained_model_folders: List[str], + output_folder: str, + folds: Union[Tuple[int, ...], List[int]] = (0, 1, 2, 3, 4), + num_processes: int = default_num_processes, + overwrite: bool = True) -> None: + """ + Feature: different configurations can now have different splits + """ + dataset_json = load_json(join(list_of_trained_model_folders[0], 'dataset.json')) + plans_manager = PlansManager(join(list_of_trained_model_folders[0], 'plans.json')) + + # first collect all unique filenames + files_per_folder = {} + unique_filenames = set() + for tr in list_of_trained_model_folders: + files_per_folder[tr] = {} + for f in folds: + if not isdir(join(tr, f'fold_{f}', 'validation')): + raise RuntimeError(f'Expected model output directory does not exist. You must train all requested ' + f'folds of the speficied model.\nModel: {tr}\nFold: {f}') + files_here = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) + if len(files_here) == 0: + raise RuntimeError(f"No .npz files found in folder {join(tr, f'fold_{f}', 'validation')}. Rerun your " + f"validation with the --npz flag. Use nnUNetv2_train [...] --val --npz.") + files_per_folder[tr][f] = subfiles(join(tr, f'fold_{f}', 'validation'), suffix='.npz', join=False) + unique_filenames.update(files_per_folder[tr][f]) + + # verify that all trained_model_folders have all predictions + ok = True + for tr, fi in files_per_folder.items(): + all_files_here = set() + for f in folds: + all_files_here.update(fi[f]) + diff = unique_filenames.difference(all_files_here) + if len(diff) > 0: + ok = False + print(f'model {tr} does not seem to contain all predictions. Missing: {diff}') + if not ok: + raise RuntimeError('There were missing files, see print statements above this one') + + # now we need to collect where these files are + file_mapping = [] + for tr in list_of_trained_model_folders: + file_mapping.append({}) + for f in folds: + for fi in files_per_folder[tr][f]: + # check for duplicates + assert fi not in file_mapping[-1].keys(), f"Duplicate detected. Case {fi} is present in more than " \ + f"one fold of model {tr}." 
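# Each entry of file_mapping (filled in below) maps a case filename to the single fold of this model
# that predicted it; merge_files then averages the per-model 'probabilities' arrays voxel-wise via
# average_probabilities. A minimal standalone sketch of that averaging step (file paths hypothetical):
#
#   import numpy as np
#   probs = [np.load(f)['probabilities'].astype(np.float32)
#            for f in ('model_a/case_001.npz', 'model_b/case_001.npz')]
#   mean_probs = np.mean(np.stack(probs), axis=0)   # shape (c, x, y, z), c = number of labels or regions
#   # for plain label-based training an argmax over axis 0 would yield the segmentation;
#   # region-based training instead thresholds each region map (handled by the LabelManager above).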
+ file_mapping[-1][fi] = join(tr, f'fold_{f}', 'validation', fi) + + lists_of_lists_of_files = [[fm[i] for fm in file_mapping] for i in unique_filenames] + output_files_truncated = [join(output_folder, fi[:-4]) for fi in unique_filenames] + + image_reader_writer = plans_manager.image_reader_writer_class() + maybe_mkdir_p(output_folder) + label_manager = plans_manager.get_label_manager(dataset_json) + + if not overwrite: + tmp = [isfile(i + dataset_json['file_ending']) for i in output_files_truncated] + lists_of_lists_of_files = [lists_of_lists_of_files[i] for i in range(len(tmp)) if not tmp[i]] + output_files_truncated = [output_files_truncated[i] for i in range(len(tmp)) if not tmp[i]] + + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + num_preds = len(lists_of_lists_of_files) + _ = pool.starmap( + merge_files, + zip( + lists_of_lists_of_files, + output_files_truncated, + [dataset_json['file_ending']] * num_preds, + [image_reader_writer] * num_preds, + [label_manager] * num_preds, + [False] * num_preds + ) + ) + + shutil.copy(join(list_of_trained_model_folders[0], 'plans.json'), join(output_folder, 'plans.json')) + shutil.copy(join(list_of_trained_model_folders[0], 'dataset.json'), join(output_folder, 'dataset.json')) diff --git a/nnUNet/nnunetv2/evaluation/__init__.py b/nnUNet/nnunetv2/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py b/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py new file mode 100644 index 0000000..6db2129 --- /dev/null +++ b/nnUNet/nnunetv2/evaluation/accumulate_cv_results.py @@ -0,0 +1,55 @@ +import shutil +from typing import Union, List, Tuple + +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, maybe_mkdir_p, subfiles, isfile + +from nnunetv2.configuration import default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def accumulate_cv_results(trained_model_folder, + merged_output_folder: str, + folds: Union[List[int], Tuple[int, ...]], + num_processes: int = default_num_processes, + overwrite: bool = True): + """ + There are a lot of things that can get fucked up, so the simplest way to deal with potential problems is to + collect the cv results into a separate folder and then evaluate them again. No messing with summary_json files! + """ + + if overwrite and isdir(merged_output_folder): + shutil.rmtree(merged_output_folder) + maybe_mkdir_p(merged_output_folder) + + dataset_json = load_json(join(trained_model_folder, 'dataset.json')) + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + rw = plans_manager.image_reader_writer_class() + shutil.copy(join(trained_model_folder, 'dataset.json'), join(merged_output_folder, 'dataset.json')) + shutil.copy(join(trained_model_folder, 'plans.json'), join(merged_output_folder, 'plans.json')) + + did_we_copy_something = False + for f in folds: + expected_validation_folder = join(trained_model_folder, f'fold_{f}', 'validation') + if not isdir(expected_validation_folder): + raise RuntimeError(f"fold {f} of model {trained_model_folder} is missing. 
Please train it!") + predicted_files = subfiles(expected_validation_folder, suffix=dataset_json['file_ending'], join=False) + for pf in predicted_files: + if overwrite and isfile(join(merged_output_folder, pf)): + raise RuntimeError(f'More than one of your folds has a prediction for case {pf}') + if overwrite or not isfile(join(merged_output_folder, pf)): + shutil.copy(join(expected_validation_folder, pf), join(merged_output_folder, pf)) + did_we_copy_something = True + + if did_we_copy_something or not isfile(join(merged_output_folder, 'summary.json')): + label_manager = plans_manager.get_label_manager(dataset_json) + compute_metrics_on_folder(join(nnUNet_raw, plans_manager.dataset_name, 'labelsTr'), + merged_output_folder, + join(merged_output_folder, 'summary.json'), + rw, + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else + label_manager.foreground_labels, + label_manager.ignore_label, + num_processes) diff --git a/nnUNet/nnunetv2/evaluation/evaluate_predictions.py b/nnUNet/nnunetv2/evaluation/evaluate_predictions.py new file mode 100644 index 0000000..e692f78 --- /dev/null +++ b/nnUNet/nnunetv2/evaluation/evaluate_predictions.py @@ -0,0 +1,264 @@ +import multiprocessing +import os +from copy import deepcopy +from multiprocessing import Pool +from typing import Tuple, List, Union, Optional + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import subfiles, join, save_json, load_json, \ + isfile +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json, \ + determine_reader_writer_from_file_ending +from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO +# the Evaluator class of the previous nnU-Net was great and all but man was it overengineered. 
Keep it simple +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def label_or_region_to_key(label_or_region: Union[int, Tuple[int]]): + return str(label_or_region) + + +def key_to_label_or_region(key: str): + try: + return int(key) + except ValueError: + key = key.replace('(', '') + key = key.replace(')', '') + splitted = key.split(',') + return tuple([int(i) for i in splitted]) + + +def save_summary_json(results: dict, output_file: str): + """ + stupid json does not support tuples as keys (why does it have to be so shitty) so we need to convert that shit + ourselves + """ + results_converted = deepcopy(results) + # convert keys in mean metrics + results_converted['mean'] = {label_or_region_to_key(k): results['mean'][k] for k in results['mean'].keys()} + # convert metric_per_case + for i in range(len(results_converted["metric_per_case"])): + results_converted["metric_per_case"][i]['metrics'] = \ + {label_or_region_to_key(k): results["metric_per_case"][i]['metrics'][k] + for k in results["metric_per_case"][i]['metrics'].keys()} + # sort_keys=True will make foreground_mean the first entry and thus easy to spot + save_json(results_converted, output_file, sort_keys=True) + + +def load_summary_json(filename: str): + results = load_json(filename) + # convert keys in mean metrics + results['mean'] = {key_to_label_or_region(k): results['mean'][k] for k in results['mean'].keys()} + # convert metric_per_case + for i in range(len(results["metric_per_case"])): + results["metric_per_case"][i]['metrics'] = \ + {key_to_label_or_region(k): results["metric_per_case"][i]['metrics'][k] + for k in results["metric_per_case"][i]['metrics'].keys()} + return results + + +def labels_to_list_of_regions(labels: List[int]): + return [(i,) for i in labels] + + +def region_or_label_to_mask(segmentation: np.ndarray, region_or_label: Union[int, Tuple[int, ...]]) -> np.ndarray: + if np.isscalar(region_or_label): + return segmentation == region_or_label + else: + mask = np.zeros_like(segmentation, dtype=bool) + for r in region_or_label: + mask[segmentation == r] = True + return mask + + +def compute_tp_fp_fn_tn(mask_ref: np.ndarray, mask_pred: np.ndarray, ignore_mask: np.ndarray = None): + if ignore_mask is None: + use_mask = np.ones_like(mask_ref, dtype=bool) + else: + use_mask = ~ignore_mask + tp = np.sum((mask_ref & mask_pred) & use_mask) + fp = np.sum(((~mask_ref) & mask_pred) & use_mask) + fn = np.sum((mask_ref & (~mask_pred)) & use_mask) + tn = np.sum(((~mask_ref) & (~mask_pred)) & use_mask) + return tp, fp, fn, tn + + +def compute_metrics(reference_file: str, prediction_file: str, image_reader_writer: BaseReaderWriter, + labels_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]], + ignore_label: int = None) -> dict: + # load images + seg_ref, seg_ref_dict = image_reader_writer.read_seg(reference_file) + seg_pred, seg_pred_dict = image_reader_writer.read_seg(prediction_file) + # spacing = seg_ref_dict['spacing'] + + ignore_mask = seg_ref == ignore_label if ignore_label is not None else None + + results = {} + results['reference_file'] = reference_file + results['prediction_file'] = prediction_file + results['metrics'] = {} + for r in labels_or_regions: + results['metrics'][r] = {} + mask_ref = region_or_label_to_mask(seg_ref, r) + mask_pred = region_or_label_to_mask(seg_pred, r) + tp, fp, fn, tn = compute_tp_fp_fn_tn(mask_ref, mask_pred, ignore_mask) + if tp + fp + fn == 0: + 
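# If a label/region is absent from both the reference and the prediction (within the evaluated mask),
# Dice = 2*TP / (2*TP + FP + FN) and IoU = TP / (TP + FP + FN) are undefined (0/0), so NaN is stored;
# np.nanmean in compute_metrics_on_folder below then skips such cases when averaging per class.
# Worked example for the defined case: TP=10, FP=2, FN=3 gives Dice = 20/25 = 0.8 and IoU = 10/15 ≈ 0.667.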
results['metrics'][r]['Dice'] = np.nan + results['metrics'][r]['IoU'] = np.nan + else: + results['metrics'][r]['Dice'] = 2 * tp / (2 * tp + fp + fn) + results['metrics'][r]['IoU'] = tp / (tp + fp + fn) + results['metrics'][r]['FP'] = fp + results['metrics'][r]['TP'] = tp + results['metrics'][r]['FN'] = fn + results['metrics'][r]['TN'] = tn + results['metrics'][r]['n_pred'] = fp + tp + results['metrics'][r]['n_ref'] = fn + tp + return results + + +def compute_metrics_on_folder(folder_ref: str, folder_pred: str, output_file: str, + image_reader_writer: BaseReaderWriter, + file_ending: str, + regions_or_labels: Union[List[int], List[Union[int, Tuple[int, ...]]]], + ignore_label: int = None, + num_processes: int = default_num_processes, + chill: bool = True) -> dict: + """ + output_file must end with .json; can be None + """ + if output_file is not None: + assert output_file.endswith('.json'), 'output_file should end with .json' + files_pred = subfiles(folder_pred, suffix=file_ending, join=False) + files_ref = subfiles(folder_ref, suffix=file_ending, join=False) + if not chill: + present = [isfile(join(folder_pred, i)) for i in files_ref] + assert all(present), "Not all files in folder_pred exist in folder_ref" + files_ref = [join(folder_ref, i) for i in files_pred] + files_pred = [join(folder_pred, i) for i in files_pred] + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + # for i in list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), [ignore_label] * len(files_pred))): + # compute_metrics(*i) + results = pool.starmap( + compute_metrics, + list(zip(files_ref, files_pred, [image_reader_writer] * len(files_pred), [regions_or_labels] * len(files_pred), + [ignore_label] * len(files_pred))) + ) + + # mean metric per class + metric_list = list(results[0]['metrics'][regions_or_labels[0]].keys()) + means = {} + for r in regions_or_labels: + means[r] = {} + for m in metric_list: + means[r][m] = np.nanmean([i['metrics'][r][m] for i in results]) + + # foreground mean + foreground_mean = {} + for m in metric_list: + values = [] + for k in means.keys(): + if k == 0 or k == '0': + continue + values.append(means[k][m]) + foreground_mean[m] = np.mean(values) + + [recursive_fix_for_json_export(i) for i in results] + recursive_fix_for_json_export(means) + recursive_fix_for_json_export(foreground_mean) + result = {'metric_per_case': results, 'mean': means, 'foreground_mean': foreground_mean} + if output_file is not None: + save_summary_json(result, output_file) + return result + # print('DONE') + + +def compute_metrics_on_folder2(folder_ref: str, folder_pred: str, dataset_json_file: str, plans_file: str, + output_file: str = None, + num_processes: int = default_num_processes, + chill: bool = False): + dataset_json = load_json(dataset_json_file) + # get file ending + file_ending = dataset_json['file_ending'] + + # get reader writer class + example_file = subfiles(folder_ref, suffix=file_ending, join=True)[0] + rw = determine_reader_writer_from_dataset_json(dataset_json, example_file)() + + # maybe auto set output file + if output_file is None: + output_file = join(folder_pred, 'summary.json') + + lm = PlansManager(plans_file).get_label_manager(dataset_json) + compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, + lm.foreground_regions if lm.has_regions else lm.foreground_labels, lm.ignore_label, + num_processes, chill=chill) + + +def compute_metrics_on_folder_simple(folder_ref: str, folder_pred: str, 
labels: Union[Tuple[int, ...], List[int]], + output_file: str = None, + num_processes: int = default_num_processes, + ignore_label: int = None, + chill: bool = False): + example_file = subfiles(folder_ref, join=True)[0] + file_ending = os.path.splitext(example_file)[-1] + rw = determine_reader_writer_from_file_ending(file_ending, example_file, allow_nonmatching_filename=True, + verbose=False)() + # maybe auto set output file + if output_file is None: + output_file = join(folder_pred, 'summary.json') + compute_metrics_on_folder(folder_ref, folder_pred, output_file, rw, file_ending, + labels, ignore_label=ignore_label, num_processes=num_processes, chill=chill) + + +def evaluate_folder_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') + parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') + parser.add_argument('-djfile', type=str, required=True, + help='dataset.json file') + parser.add_argument('-pfile', type=str, required=True, + help='plans.json file') + parser.add_argument('-o', type=str, required=False, default=None, + help='Output file. Optional. Default: pred_folder/summary.json') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'number of processes used. Optional. Default: {default_num_processes}') + parser.add_argument('--chill', action='store_true', help='dont crash if folder_pred doesnt have all files that are present in folder_gt') + args = parser.parse_args() + compute_metrics_on_folder2(args.gt_folder, args.pred_folder, args.djfile, args.pfile, args.o, args.np, chill=args.chill) + + +def evaluate_simple_entry_point(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('gt_folder', type=str, help='folder with gt segmentations') + parser.add_argument('pred_folder', type=str, help='folder with predicted segmentations') + parser.add_argument('-l', type=int, nargs='+', required=True, + help='list of labels') + parser.add_argument('-il', type=int, required=False, default=None, + help='ignore label') + parser.add_argument('-o', type=str, required=False, default=None, + help='Output file. Optional. Default: pred_folder/summary.json') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f'number of processes used. Optional. 
Default: {default_num_processes}') + parser.add_argument('--chill', action='store_true', help='dont crash if folder_pred doesnt have all files that are present in folder_gt') + + args = parser.parse_args() + compute_metrics_on_folder_simple(args.gt_folder, args.pred_folder, args.l, args.o, args.np, args.il, chill=args.chill) + + +if __name__ == '__main__': + folder_ref = '/media/fabian/data/nnUNet_raw/Dataset004_Hippocampus/labelsTr' + folder_pred = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation' + output_file = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetModule__nnUNetPlans__3d_fullres/fold_0/validation/summary.json' + image_reader_writer = SimpleITKIO() + file_ending = '.nii.gz' + regions = labels_to_list_of_regions([1, 2]) + ignore_label = None + num_processes = 12 + compute_metrics_on_folder(folder_ref, folder_pred, output_file, image_reader_writer, file_ending, regions, ignore_label, + num_processes) diff --git a/nnUNet/nnunetv2/evaluation/find_best_configuration.py b/nnUNet/nnunetv2/evaluation/find_best_configuration.py new file mode 100644 index 0000000..c36008b --- /dev/null +++ b/nnUNet/nnunetv2/evaluation/find_best_configuration.py @@ -0,0 +1,333 @@ +import argparse +import os.path +from copy import deepcopy +from typing import Union, List, Tuple + +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isdir, save_json + +from nnunetv2.configuration import default_num_processes +from nnunetv2.ensembling.ensemble import ensemble_crossvalidations +from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder, load_summary_json +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results +from nnunetv2.postprocessing.remove_connected_components import determine_postprocessing +from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name, get_output_folder, \ + convert_identifier_to_trainer_plans_config, get_ensemble_name, folds_tuple_to_string +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + +default_trained_models = tuple([ + {'plans': 'nnUNetPlans', 'configuration': '2d', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_fullres', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_lowres', 'trainer': 'nnUNetTrainer'}, + {'plans': 'nnUNetPlans', 'configuration': '3d_cascade_fullres', 'trainer': 'nnUNetTrainer'}, +]) + + +def filter_available_models(model_dict: Union[List[dict], Tuple[dict, ...]], dataset_name_or_id: Union[str, int]): + valid = [] + for trained_model in model_dict: + plans_manager = PlansManager(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), + trained_model['plans'] + '.json')) + # check if configuration exists + # 3d_cascade_fullres and 3d_lowres do not exist for each dataset so we allow them to be absent IF they are not + # specified in the plans file + if trained_model['configuration'] not in plans_manager.available_configurations: + print(f"Configuration {trained_model['configuration']} not found in plans {trained_model['plans']}.\n" + f"Inferred plans file: {join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id), trained_model['plans'] + '.json')}.") + continue + + # check if trained model output folder exists. This is a requirement. No mercy here. 
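# For reference, get_output_folder resolves to a path of the form (illustrative example):
#   <nnUNet_results>/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres/
# i.e. <nnUNet_results>/<dataset_name>/<trainer>__<plans>__<configuration>/, which is expected to
# contain fold_0 ... fold_4 subfolders with validation predictions (run training with --npz) so that
# cross-validation ensembling and evaluation can be performed later.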
+ expected_output_folder = get_output_folder(dataset_name_or_id, trained_model['trainer'], trained_model['plans'], + trained_model['configuration'], fold=None) + if not isdir(expected_output_folder): + raise RuntimeError(f"Trained model {trained_model} does not have an output folder. " + f"Expected: {expected_output_folder}. Please run the training for this model! (don't forget " + f"the --npz flag if you want to ensemble multiple configurations)") + + valid.append(trained_model) + return valid + + +def generate_inference_command(dataset_name_or_id: Union[int, str], configuration_name: str, + plans_identifier: str = 'nnUNetPlans', trainer_name: str = 'nnUNetTrainer', + folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), + folder_with_segs_from_prev_stage: str = None, + input_folder: str = 'INPUT_FOLDER', + output_folder: str = 'OUTPUT_FOLDER', + save_npz: bool = False): + fold_str = '' + for f in folds: + fold_str += f' {f}' + + predict_command = '' + trained_model_folder = get_output_folder(dataset_name_or_id, trainer_name, plans_identifier, configuration_name, fold=None) + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + configuration_manager = plans_manager.get_configuration(configuration_name) + if 'previous_stage' in plans_manager.available_configurations: + prev_stage = configuration_manager.previous_stage_name + predict_command += generate_inference_command(dataset_name_or_id, prev_stage, plans_identifier, trainer_name, + folds, None, output_folder='OUTPUT_FOLDER_PREV_STAGE') + '\n' + folder_with_segs_from_prev_stage = 'OUTPUT_FOLDER_PREV_STAGE' + + predict_command = f'nnUNetv2_predict -d {dataset_name_or_id} -i {input_folder} -o {output_folder} -f {fold_str} ' \ + f'-tr {trainer_name} -c {configuration_name} -p {plans_identifier}' + if folder_with_segs_from_prev_stage is not None: + predict_command += f' -prev_stage_predictions {folder_with_segs_from_prev_stage}' + if save_npz: + predict_command += ' --save_probabilities' + return predict_command + + +def find_best_configuration(dataset_name_or_id, + allowed_trained_models: Union[List[dict], Tuple[dict, ...]] = default_trained_models, + allow_ensembling: bool = True, + num_processes: int = default_num_processes, + overwrite: bool = True, + folds: Union[List[int], Tuple[int, ...]] = (0, 1, 2, 3, 4), + strict: bool = False): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + all_results = {} + + allowed_trained_models = filter_available_models(deepcopy(allowed_trained_models), dataset_name_or_id) + + for m in allowed_trained_models: + output_folder = get_output_folder(dataset_name_or_id, m['trainer'], m['plans'], m['configuration'], fold=None) + if not isdir(output_folder) and strict: + raise RuntimeError(f'{dataset_name}: The output folder of plans {m["plans"]} configuration ' + f'{m["configuration"]} is missing. Please train the model (all requested folds!) 
first!') + identifier = os.path.basename(output_folder) + merged_output_folder = join(output_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + accumulate_cv_results(output_folder, merged_output_folder, folds, num_processes, overwrite) + all_results[identifier] = { + 'source': merged_output_folder, + 'result': load_summary_json(join(merged_output_folder, 'summary.json'))['foreground_mean']['Dice'] + } + + if allow_ensembling: + for i in range(len(allowed_trained_models)): + for j in range(i + 1, len(allowed_trained_models)): + m1, m2 = allowed_trained_models[i], allowed_trained_models[j] + + output_folder_1 = get_output_folder(dataset_name_or_id, m1['trainer'], m1['plans'], m1['configuration'], fold=None) + output_folder_2 = get_output_folder(dataset_name_or_id, m2['trainer'], m2['plans'], m2['configuration'], fold=None) + identifier = get_ensemble_name(output_folder_1, output_folder_2, folds) + + output_folder_ensemble = join(nnUNet_results, dataset_name, 'ensembles', identifier) + + ensemble_crossvalidations([output_folder_1, output_folder_2], output_folder_ensemble, folds, + num_processes, overwrite=overwrite) + + # evaluate ensembled predictions + plans_manager = PlansManager(join(output_folder_1, 'plans.json')) + dataset_json = load_json(join(output_folder_1, 'dataset.json')) + label_manager = plans_manager.get_label_manager(dataset_json) + rw = plans_manager.image_reader_writer_class() + + compute_metrics_on_folder(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), + output_folder_ensemble, + join(output_folder_ensemble, 'summary.json'), + rw, + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else + label_manager.foreground_labels, + label_manager.ignore_label, + num_processes) + all_results[identifier] = \ + { + 'source': output_folder_ensemble, + 'result': load_summary_json(join(output_folder_ensemble, 'summary.json'))['foreground_mean']['Dice'] + } + + # pick best and report inference command + best_score = max([i['result'] for i in all_results.values()]) + best_keys = [k for k in all_results.keys() if all_results[k]['result'] == best_score] # may never happen but theoretically + # there can be a tie. 
Let's pick the first model in this case because it's going to be the simpler one (ensembles + # come after single configs) + best_key = best_keys[0] + + print() + print('***All results:***') + for k, v in all_results.items(): + print(f'{k}: {v["result"]}') + print(f'\n*Best*: {best_key}: {all_results[best_key]["result"]}') + print() + + print('***Determining postprocessing for best model/ensemble***') + determine_postprocessing(all_results[best_key]['source'], join(nnUNet_preprocessed, dataset_name, 'gt_segmentations'), + plans_file_or_dict=join(all_results[best_key]['source'], 'plans.json'), + dataset_json_file_or_dict=join(all_results[best_key]['source'], 'dataset.json'), + num_processes=num_processes, keep_postprocessed_files=True) + + # in addition to just reading the console output (how it was previously) we should return the information + # needed to run the full inference via API + return_dict = { + 'folds': folds, + 'dataset_name_or_id': dataset_name_or_id, + 'considered_models': allowed_trained_models, + 'ensembling_allowed': allow_ensembling, + 'all_results': {i: j['result'] for i, j in all_results.items()}, + 'best_model_or_ensemble': { + 'result_on_crossval_pre_pp': all_results[best_key]["result"], + 'result_on_crossval_post_pp': load_json(join(all_results[best_key]['source'], 'postprocessed', 'summary.json'))['foreground_mean']['Dice'], + 'postprocessing_file': join(all_results[best_key]['source'], 'postprocessing.pkl'), + 'some_plans_file': join(all_results[best_key]['source'], 'plans.json'), + # just needed for label handling, can + # come from any of the ensemble members (if any) + 'selected_model_or_models': [] + } + } + # convert best key to inference command: + if best_key.startswith('ensemble___'): + prefix, m1, m2, folds_string = best_key.split('___') + tr1, pl1, c1 = convert_identifier_to_trainer_plans_config(m1) + tr2, pl2, c2 = convert_identifier_to_trainer_plans_config(m2) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c1, + 'trainer': tr1, + 'plans_identifier': pl1, + }) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c2, + 'trainer': tr2, + 'plans_identifier': pl2, + }) + else: + tr, pl, c = convert_identifier_to_trainer_plans_config(best_key) + return_dict['best_model_or_ensemble']['selected_model_or_models'].append( + { + 'configuration': c, + 'trainer': tr, + 'plans_identifier': pl, + }) + + save_json(return_dict, join(nnUNet_results, dataset_name, 'inference_information.json')) # save this so that we don't have to run this + # everything someone wants to be reminded of the inference commands. They can just load this and give it to + # print_inference_instructions + + # print it + print_inference_instructions(return_dict, instructions_file=join(nnUNet_results, dataset_name, 'inference_instructions.txt')) + return return_dict + + +def print_inference_instructions(inference_info_dict: dict, instructions_file: str = None): + def _print_and_maybe_write_to_file(string): + print(string) + if f_handle is not None: + f_handle.write(f'{string}\n') + + f_handle = open(instructions_file, 'w') if instructions_file is not None else None + print() + _print_and_maybe_write_to_file('***Run inference like this:***\n') + output_folders = [] + + dataset_name_or_id = inference_info_dict['dataset_name_or_id'] + if len(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']) > 1: + is_ensemble = True + _print_and_maybe_write_to_file('An ensemble won! 
What a surprise! Run the following commands to run predictions with the ensemble members:\n') + else: + is_ensemble = False + + for j, i in enumerate(inference_info_dict['best_model_or_ensemble']['selected_model_or_models']): + tr, c, pl = i['trainer'], i['configuration'], i['plans_identifier'] + if is_ensemble: + output_folder_name = f"OUTPUT_FOLDER_MODEL_{j+1}" + else: + output_folder_name = f"OUTPUT_FOLDER" + output_folders.append(output_folder_name) + + _print_and_maybe_write_to_file(generate_inference_command(dataset_name_or_id, c, pl, tr, inference_info_dict['folds'], + save_npz=is_ensemble, output_folder=output_folder_name)) + + if is_ensemble: + output_folder_str = output_folders[0] + for o in output_folders[1:]: + output_folder_str += f' {o}' + output_ensemble = f"OUTPUT_FOLDER" + _print_and_maybe_write_to_file('\nThe run ensembling with:\n') + _print_and_maybe_write_to_file(f"nnUNetv2_ensemble -i {output_folder_str} -o {output_ensemble} -np {default_num_processes}") + + _print_and_maybe_write_to_file("\n***Once inference is completed, run postprocessing like this:***\n") + _print_and_maybe_write_to_file(f"nnUNetv2_apply_postprocessing -i OUTPUT_FOLDER -o OUTPUT_FOLDER_PP " + f"-pp_pkl_file {inference_info_dict['best_model_or_ensemble']['postprocessing_file']} -np {default_num_processes} " + f"-plans_json {inference_info_dict['best_model_or_ensemble']['some_plans_file']}") + + +def dumb_trainer_config_plans_to_trained_models_dict(trainers: List[str], configs: List[str], plans: List[str]): + """ + function is called dumb because it's dumb + """ + ret = [] + for t in trainers: + for c in configs: + for p in plans: + ret.append( + {'plans': p, 'configuration': c, 'trainer': t} + ) + return tuple(ret) + + +def find_best_configuration_entry_point(): + parser = argparse.ArgumentParser() + parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id') + parser.add_argument('-p', nargs='+', required=False, default=['nnUNetPlans'], + help='List of plan identifiers. Default: nnUNetPlans') + parser.add_argument('-c', nargs='+', required=False, default=['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'], + help="List of configurations. Default: ['2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres']") + parser.add_argument('-tr', nargs='+', required=False, default=['nnUNetTrainer'], + help='List of trainers. Default: nnUNetTrainer') + parser.add_argument('-np', required=False, default=default_num_processes, type=int, + help='Number of processes to use for ensembling, postprocessing etc') + parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), + help='Folds to use. Default: 0 1 2 3 4') + parser.add_argument('--disable_ensembling', action='store_true', required=False, + help='Set this flag to disable ensembling') + parser.add_argument('--no_overwrite', action='store_true', + help='If set we will not overwrite already ensembled files etc. May speed up concecutive ' + 'runs of this command (why would oyu want to do that?) 
at the risk of not updating ' + 'outdated results.') + args = parser.parse_args() + + model_dict = dumb_trainer_config_plans_to_trained_models_dict(args.tr, args.c, args.p) + dataset_name = maybe_convert_to_dataset_name(args.dataset_name_or_id) + + find_best_configuration(dataset_name, model_dict, allow_ensembling=not args.disable_ensembling, + num_processes=args.np, overwrite=not args.no_overwrite, folds=args.f, + strict=False) + + +def accumulate_crossval_results_entry_point(): + parser = argparse.ArgumentParser('Copies all predicted segmentations from the individual folds into one joint ' + 'folder and evaluates them') + parser.add_argument('dataset_name_or_id', type=str, help='Dataset Name or id') + parser.add_argument('-c', type=str, required=True, + default='3d_fullres', + help="Configuration") + parser.add_argument('-o', type=str, required=False, default=None, + help="Output folder. If not specified, the output folder will be located in the trained " \ + "model directory (named crossval_results_folds_XXX).") + parser.add_argument('-f', nargs='+', type=int, default=(0, 1, 2, 3, 4), + help='Folds to use. Default: 0 1 2 3 4') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='Plan identifier in which to search for the specified configuration. Default: nnUNetPlans') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='Trainer class. Default: nnUNetTrainer') + args = parser.parse_args() + trained_model_folder = get_output_folder(args.dataset_name_or_id, args.tr, args.p, args.c) + + if args.o is None: + merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(args.f)}') + else: + merged_output_folder = args.o + + accumulate_cv_results(trained_model_folder, merged_output_folder, args.f) + + +if __name__ == '__main__': + find_best_configuration(4, + default_trained_models, + True, + 8, + False, + (0, 1, 2, 3, 4)) diff --git a/nnUNet/nnunetv2/experiment_planning/__init__.py b/nnUNet/nnunetv2/experiment_planning/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py new file mode 100644 index 0000000..8280518 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/dataset_fingerprint/fingerprint_extractor.py @@ -0,0 +1,200 @@ +import multiprocessing +import os +from time import sleep +from typing import List, Type, Union + +import nibabel as nib +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p +from tqdm import tqdm + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +class DatasetFingerprintExtractor(object): + def __init__(self, dataset_name_or_id: Union[str, int], num_processes: int = 8, verbose: bool = False): + """ + 
extracts the dataset fingerprint used for experiment planning. The dataset fingerprint will be saved as a + json file in the input_folder + + Philosophy here is to do only what we really need. Don't store stuff that we can easily read from somewhere + else. Don't compute stuff we don't need (except for intensity_statistics_per_channel) + """ + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + self.verbose = verbose + + self.dataset_name = dataset_name + self.input_folder = join(nnUNet_raw, dataset_name) + self.num_processes = num_processes + self.dataset_json = load_json(join(self.input_folder, 'dataset.json')) + self.dataset = get_filenames_of_train_images_and_targets(self.input_folder, self.dataset_json) + + # We don't want to use all foreground voxels because that can accumulate a lot of data (out of memory). It is + # also not critically important to get all pixels as long as there are enough. Let's use 10e7 voxels in total + # (for the entire dataset) + self.num_foreground_voxels_for_intensitystats = 10e7 + + @staticmethod + def collect_foreground_intensities(segmentation: np.ndarray, images: np.ndarray, seed: int = 1234, + num_samples: int = 10000): + """ + images=image with multiple channels = shape (c, x, y(, z)) + """ + assert len(images.shape) == 4 + assert len(segmentation.shape) == 4 + + assert not np.any(np.isnan(segmentation)), "Segmentation contains NaN values. grrrr.... :-(" + assert not np.any(np.isnan(images)), "Images contains NaN values. grrrr.... :-(" + + rs = np.random.RandomState(seed) + + intensities_per_channel = [] + # we don't use the intensity_statistics_per_channel at all, it's just something that might be nice to have + intensity_statistics_per_channel = [] + + # segmentation is 4d: 1,x,y,z. We need to remove the empty dimension for the following code to work + foreground_mask = segmentation[0] > 0 + + for i in range(len(images)): + foreground_pixels = images[i][foreground_mask] + num_fg = len(foreground_pixels) + # sample with replacement so that we don't get issues with cases that have less than num_samples + # foreground_pixels. We could also just sample less in those cases but that would than cause these + # training cases to be underrepresented + intensities_per_channel.append( + rs.choice(foreground_pixels, num_samples, replace=True) if num_fg > 0 else []) + intensity_statistics_per_channel.append({ + 'mean': np.mean(foreground_pixels) if num_fg > 0 else np.nan, + 'median': np.median(foreground_pixels) if num_fg > 0 else np.nan, + 'min': np.min(foreground_pixels) if num_fg > 0 else np.nan, + 'max': np.max(foreground_pixels) if num_fg > 0 else np.nan, + 'percentile_99_5': np.percentile(foreground_pixels, 99.5) if num_fg > 0 else np.nan, + 'percentile_00_5': np.percentile(foreground_pixels, 0.5) if num_fg > 0 else np.nan, + + }) + + return intensities_per_channel, intensity_statistics_per_channel + + @staticmethod + def analyze_case(image_files: List[str], segmentation_file: str, reader_writer_class: Type[BaseReaderWriter], + num_samples: int = 10000): + rw = reader_writer_class() + images, properties_images = rw.read_images(image_files) + segmentation, properties_seg = rw.read_seg(segmentation_file) + + # we no longer crop and save the cropped images before this is run. Instead we run the cropping on the fly. + # Downside is that we need to do this twice (once here and once during preprocessing). Upside is that we don't + # need to save the cropped data anymore. 
Given that cropping is not too expensive it makes sense to do it this + # way. This is only possible because we are now using our new input/output interface. + data_cropped, seg_cropped, bbox = crop_to_nonzero(images, segmentation) + + foreground_intensities_per_channel, foreground_intensity_stats_per_channel = \ + DatasetFingerprintExtractor.collect_foreground_intensities(seg_cropped, data_cropped, + num_samples=num_samples) + + spacing = properties_images['spacing'] + + shape_before_crop = images.shape[1:] + shape_after_crop = data_cropped.shape[1:] + relative_size_after_cropping = np.prod(shape_after_crop) / np.prod(shape_before_crop) + return shape_after_crop, spacing, foreground_intensities_per_channel, foreground_intensity_stats_per_channel, \ + relative_size_after_cropping + + def run(self, overwrite_existing: bool = False) -> dict: + # we do not save the properties file in self.input_folder because that folder might be read-only. We can only + # reliably write in nnUNet_preprocessed and nnUNet_results, so nnUNet_preprocessed it is + preprocessed_output_folder = join(nnUNet_preprocessed, self.dataset_name) + maybe_mkdir_p(preprocessed_output_folder) + properties_file = join(preprocessed_output_folder, 'dataset_fingerprint.json') + + if not isfile(properties_file) or overwrite_existing: + reader_writer_class = determine_reader_writer_from_dataset_json(self.dataset_json, + # yikes. Rip the following line + self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0]) + + # determine how many foreground voxels we need to sample per training case + num_foreground_samples_per_case = int(self.num_foreground_voxels_for_intensitystats // + len(self.dataset)) + + r = [] + with multiprocessing.get_context("spawn").Pool(self.num_processes) as p: + for k in self.dataset.keys(): + r.append(p.starmap_async(DatasetFingerprintExtractor.analyze_case, + ((self.dataset[k]['images'], self.dataset[k]['label'], reader_writer_class, + num_foreground_samples_per_case),))) + remaining = list(range(len(self.dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + with tqdm(desc=None, total=len(self.dataset), disable=self.verbose) as pbar: + while len(remaining) > 0: + all_alive = all([j.is_alive() for j in workers]) + if not all_alive: + raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' + 'OK jokes aside.\n' + 'One of your background processes is missing. This could be because of ' + 'an error (look for an error message) or because it was killed ' + 'by your OS due to running out of RAM. If you don\'t see ' + 'an error message, out of RAM is likely the problem. 
In that case ' + 'reducing the number of workers might help') + done = [i for i in remaining if r[i].ready()] + for _ in done: + pbar.update() + remaining = [i for i in remaining if i not in done] + sleep(0.1) + + # results = ptqdm(DatasetFingerprintExtractor.analyze_case, + # (training_images_per_case, training_labels_per_case), + # processes=self.num_processes, zipped=True, reader_writer_class=reader_writer_class, + # num_samples=num_foreground_samples_per_case, disable=self.verbose) + results = [i.get()[0] for i in r] + + shapes_after_crop = [r[0] for r in results] + spacings = [r[1] for r in results] + foreground_intensities_per_channel = [np.concatenate([r[2][i] for r in results]) for i in + range(len(results[0][2]))] + # we drop this so that the json file is somewhat human readable + # foreground_intensity_stats_by_case_and_modality = [r[3] for r in results] + median_relative_size_after_cropping = np.median([r[4] for r in results], 0) + + num_channels = len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()) + intensity_statistics_per_channel = {} + for i in range(num_channels): + intensity_statistics_per_channel[i] = { + 'mean': float(np.mean(foreground_intensities_per_channel[i])), + 'median': float(np.median(foreground_intensities_per_channel[i])), + 'std': float(np.std(foreground_intensities_per_channel[i])), + 'min': float(np.min(foreground_intensities_per_channel[i])), + 'max': float(np.max(foreground_intensities_per_channel[i])), + 'percentile_99_5': float(np.percentile(foreground_intensities_per_channel[i], 99.5)), + 'percentile_00_5': float(np.percentile(foreground_intensities_per_channel[i], 0.5)), + } + + fingerprint = { + "spacings": spacings, + "shapes_after_crop": shapes_after_crop, + 'foreground_intensity_properties_per_channel': intensity_statistics_per_channel, + "median_relative_size_after_cropping": median_relative_size_after_cropping + } + + try: + save_json(fingerprint, properties_file) + except Exception as e: + if isfile(properties_file): + os.remove(properties_file) + raise e + else: + fingerprint = load_json(properties_file) + return fingerprint + + +if __name__ == '__main__': + dfe = DatasetFingerprintExtractor(2, 8) + dfe.run(overwrite_existing=False) diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/__init__.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py new file mode 100644 index 0000000..55d841e --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/default_experiment_planner.py @@ -0,0 +1,534 @@ +import os.path +import shutil +from copy import deepcopy +from functools import lru_cache +from typing import List, Union, Tuple, Type + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import load_json, join, save_json, isfile, maybe_mkdir_p +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_instancenorm + +from nnunetv2.configuration import ANISO_THRESHOLD +from nnunetv2.experiment_planning.experiment_planners.network_topology import get_pool_and_conv_props +from nnunetv2.imageio.reader_writer_registry import 
determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme +from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ + get_filenames_of_train_images_and_targets + + +class ExperimentPlanner(object): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + """ + overwrite_target_spacing only affects 3d_fullres! (but by extension 3d_lowres which starts with fullres may + also be affected + """ + + self.dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + self.suppress_transpose = suppress_transpose + self.raw_dataset_folder = join(nnUNet_raw, self.dataset_name) + preprocessed_folder = join(nnUNet_preprocessed, self.dataset_name) + self.dataset_json = load_json(join(self.raw_dataset_folder, 'dataset.json')) + self.dataset = get_filenames_of_train_images_and_targets(self.raw_dataset_folder, self.dataset_json) + + # load dataset fingerprint + if not isfile(join(preprocessed_folder, 'dataset_fingerprint.json')): + raise RuntimeError('Fingerprint missing for this dataset. Please run nnUNet_extract_dataset_fingerprint') + + self.dataset_fingerprint = load_json(join(preprocessed_folder, 'dataset_fingerprint.json')) + + self.anisotropy_threshold = ANISO_THRESHOLD + + self.UNet_base_num_features = 32 + self.UNet_class = PlainConvUNet + # the following two numbers are really arbitrary and were set to reproduce nnU-Net v1's configurations as + # much as possible + self.UNet_reference_val_3d = 560000000 # 455600128 550000000 + self.UNet_reference_val_2d = 85000000 # 83252480 + self.UNet_reference_com_nfeatures = 32 + self.UNet_reference_val_corresp_GB = 8 + self.UNet_reference_val_corresp_bs_2d = 12 + self.UNet_reference_val_corresp_bs_3d = 2 + self.UNet_vram_target_GB = gpu_memory_target_in_gb + self.UNet_featuremap_min_edge_length = 4 + self.UNet_blocks_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + self.UNet_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + self.UNet_min_batch_size = 2 + self.UNet_max_features_2d = 512 + self.UNet_max_features_3d = 320 + + self.lowres_creation_threshold = 0.25 # if the patch size of fullres is less than 25% of the voxels in the + # median shape then we need a lowres config as well + + self.preprocessor_name = preprocessor_name + self.plans_identifier = plans_name + self.overwrite_target_spacing = overwrite_target_spacing + assert overwrite_target_spacing is None or len(overwrite_target_spacing), 'if overwrite_target_spacing is ' \ + 'used then three floats must be ' \ + 'given (as list or tuple)' + assert overwrite_target_spacing is None or all([isinstance(i, float) for i in overwrite_target_spacing]), \ + 'if overwrite_target_spacing is used then three floats must be given (as list or tuple)' + + self.plans = None + + def determine_reader_writer(self): + example_image = 
self.dataset[self.dataset.keys().__iter__().__next__()]['images'][0] + return determine_reader_writer_from_dataset_json(self.dataset_json, example_image) + + @staticmethod + @lru_cache(maxsize=None) + def static_estimate_VRAM_usage(patch_size: Tuple[int], + n_stages: int, + strides: Union[int, List[int], Tuple[int, ...]], + UNet_class: Union[Type[PlainConvUNet], Type[ResidualEncoderUNet]], + num_input_channels: int, + features_per_stage: Tuple[int], + blocks_per_stage_encoder: Union[int, Tuple[int]], + blocks_per_stage_decoder: Union[int, Tuple[int]], + num_labels: int): + """ + Works for PlainConvUNet, ResidualEncoderUNet + """ + dim = len(patch_size) + conv_op = convert_dim_to_conv_op(dim) + norm_op = get_matching_instancenorm(conv_op) + net = UNet_class(num_input_channels, n_stages, + features_per_stage, + conv_op, + 3, + strides, + blocks_per_stage_encoder, + num_labels, + blocks_per_stage_decoder, + norm_op=norm_op) + return net.compute_conv_feature_map_size(patch_size) + + def determine_resampling(self, *args, **kwargs): + """ + returns what functions to use for resampling data and seg, respectively. Also returns kwargs + resampling function must be callable(data, current_spacing, new_spacing, **kwargs) + + determine_resampling is called within get_plans_for_configuration to allow for different functions for each + configuration + """ + resampling_data = resample_data_or_seg_to_shape + resampling_data_kwargs = { + "is_seg": False, + "order": 3, + "order_z": 0, + "force_separate_z": None, + } + resampling_seg = resample_data_or_seg_to_shape + resampling_seg_kwargs = { + "is_seg": True, + "order": 1, + "order_z": 0, + "force_separate_z": None, + } + return resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs + + def determine_segmentation_softmax_export_fn(self, *args, **kwargs): + """ + function must be callable(data, new_shape, current_spacing, new_spacing, **kwargs). The new_shape should be + used as target. current_spacing and new_spacing are merely there in case we want to use it somehow + + determine_segmentation_softmax_export_fn is called within get_plans_for_configuration to allow for different + functions for each configuration + + """ + resampling_fn = resample_data_or_seg_to_shape + resampling_fn_kwargs = { + "is_seg": False, + "order": 1, + "order_z": 0, + "force_separate_z": None, + } + return resampling_fn, resampling_fn_kwargs + + def determine_fullres_target_spacing(self) -> np.ndarray: + """ + per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data + and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training + + For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic + (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low + resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially + impact performance (due to the low number of slices). 
+ """ + if self.overwrite_target_spacing is not None: + return np.array(self.overwrite_target_spacing) + + spacings = self.dataset_fingerprint['spacings'] + sizes = self.dataset_fingerprint['shapes_after_crop'] + + target = np.percentile(np.vstack(spacings), 50, 0) + + # todo sizes_after_resampling = [compute_new_shape(j, i, target) for i, j in zip(spacings, sizes)] + + target_size = np.percentile(np.vstack(sizes), 50, 0) + # we need to identify datasets for which a different target spacing could be beneficial. These datasets have + # the following properties: + # - one axis which much lower resolution than the others + # - the lowres axis has much less voxels than the others + # - (the size in mm of the lowres axis is also reduced) + worst_spacing_axis = np.argmax(target) + other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] + other_spacings = [target[i] for i in other_axes] + other_sizes = [target_size[i] for i in other_axes] + + has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) + has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes) + + if has_aniso_spacing and has_aniso_voxels: + spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] + target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) + # don't let the spacing of that axis get higher than the other axes + if target_spacing_of_that_axis < max(other_spacings): + target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5 + target[worst_spacing_axis] = target_spacing_of_that_axis + return target + + def determine_normalization_scheme_and_whether_mask_is_used_for_norm(self) -> Tuple[List[str], List[bool]]: + if 'channel_names' not in self.dataset_json.keys(): + print('WARNING: "modalities" should be renamed to "channel_names" in dataset.json. This will be ' + 'enforced soon!') + modalities = self.dataset_json['channel_names'] if 'channel_names' in self.dataset_json.keys() else \ + self.dataset_json['modality'] + normalization_schemes = [get_normalization_scheme(m) for m in modalities.values()] + if self.dataset_fingerprint['median_relative_size_after_cropping'] < (3 / 4.): + use_nonzero_mask_for_norm = [i.leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true for i in + normalization_schemes] + else: + use_nonzero_mask_for_norm = [False] * len(normalization_schemes) + assert all([i in (True, False) for i in use_nonzero_mask_for_norm]), 'use_nonzero_mask_for_norm must be ' \ + 'True or False and cannot be None' + normalization_schemes = [i.__name__ for i in normalization_schemes] + return normalization_schemes, use_nonzero_mask_for_norm + + def determine_transpose(self): + if self.suppress_transpose: + return [0, 1, 2], [0, 1, 2] + + # todo we should use shapes for that as well. 
Not quite sure how yet + target_spacing = self.determine_fullres_target_spacing() + + max_spacing_axis = np.argmax(target_spacing) + remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis] + transpose_forward = [max_spacing_axis] + remaining_axes + transpose_backward = [np.argwhere(np.array(transpose_forward) == i)[0][0] for i in range(3)] + return transpose_forward, transpose_backward + + def get_plans_for_configuration(self, + spacing: Union[np.ndarray, Tuple[float, ...], List[float]], + median_shape: Union[np.ndarray, Tuple[int, ...], List[int]], + data_identifier: str, + approximate_n_voxels_dataset: float) -> dict: + assert all([i > 0 for i in spacing]), f"Spacing must be > 0! Spacing: {spacing}" + # print(spacing, median_shape, approximate_n_voxels_dataset) + # find an initial patch size + # we first use the spacing to get an aspect ratio + tmp = 1 / np.array(spacing) + + # we then upscale it so that it initially is certainly larger than what we need (rescale to have the same + # volume as a patch of size 256 ** 3) + # this may need to be adapted when using absurdly large GPU memory targets. Increasing this now would not be + # ideal because large initial patch sizes increase computation time because more iterations in the while loop + # further down may be required. + if len(spacing) == 3: + initial_patch_size = [round(i) for i in tmp * (256 ** 3 / np.prod(tmp)) ** (1 / 3)] + elif len(spacing) == 2: + initial_patch_size = [round(i) for i in tmp * (2048 ** 2 / np.prod(tmp)) ** (1 / 2)] + else: + raise RuntimeError() + + # clip initial patch size to median_shape. It makes little sense to have it be larger than that. Note that + # this is different from how nnU-Net v1 does it! + # todo patch size can still get too large because we pad the patch size to a multiple of 2**n + initial_patch_size = np.array([min(i, j) for i, j in zip(initial_patch_size, median_shape[:len(spacing)])]) + + # use that to get the network topology. Note that this changes the patch_size depending on the number of + # pooling operations (must be divisible by 2**num_pool in each axis) + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, initial_patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + # now estimate vram consumption + num_stages = len(pool_op_kernel_sizes) + estimate = self.static_estimate_VRAM_usage(tuple(patch_size), + num_stages, + tuple([tuple(i) for i in pool_op_kernel_sizes]), + self.UNet_class, + len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()), + tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else + self.UNet_max_features_3d, + self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(len(pool_op_kernel_sizes))]), + self.UNet_blocks_per_stage_encoder[:num_stages], + self.UNet_blocks_per_stage_decoder[:num_stages - 1], + len(self.dataset_json['labels'].keys())) + + # how large is the reference for us here (batch size etc)? + # adapt for our vram target + reference = (self.UNet_reference_val_2d if len(spacing) == 2 else self.UNet_reference_val_3d) * \ + (self.UNet_vram_target_GB / self.UNet_reference_val_corresp_GB) + + while estimate > reference: + # print(patch_size) + # patch size seems to be too large, so we need to reduce it. 
Reduce the axis that currently violates the + # aspect ratio the most (that is the largest relative to median shape) + axis_to_be_reduced = np.argsort(patch_size / median_shape[:len(spacing)])[-1] + + # we cannot simply reduce that axis by shape_must_be_divisible_by[axis_to_be_reduced] because this + # may cause us to skip some valid sizes, for example shape_must_be_divisible_by is 64 for a shape of 256. + # If we subtracted that we would end up with 192, skipping 224 which is also a valid patch size + # (224 / 2**5 = 7; 7 < 2 * self.UNet_featuremap_min_edge_length(4) so it's valid). So we need to first + # subtract shape_must_be_divisible_by, then recompute it and then subtract the + # recomputed shape_must_be_divisible_by. Annoying. + tmp = deepcopy(patch_size) + tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + _, _, _, _, shape_must_be_divisible_by = \ + get_pool_and_conv_props(spacing, tmp, + self.UNet_featuremap_min_edge_length, + 999999) + patch_size[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced] + + # now recompute topology + network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, \ + shape_must_be_divisible_by = get_pool_and_conv_props(spacing, patch_size, + self.UNet_featuremap_min_edge_length, + 999999) + + num_stages = len(pool_op_kernel_sizes) + estimate = self.static_estimate_VRAM_usage(tuple(patch_size), + num_stages, + tuple([tuple(i) for i in pool_op_kernel_sizes]), + self.UNet_class, + len(self.dataset_json['channel_names'].keys() + if 'channel_names' in self.dataset_json.keys() + else self.dataset_json['modality'].keys()), + tuple([min(self.UNet_max_features_2d if len(patch_size) == 2 else + self.UNet_max_features_3d, + self.UNet_reference_com_nfeatures * 2 ** i) for + i in range(len(pool_op_kernel_sizes))]), + self.UNet_blocks_per_stage_encoder[:num_stages], + self.UNet_blocks_per_stage_decoder[:num_stages - 1], + len(self.dataset_json['labels'].keys())) + + # alright now let's determine the batch size. This will give self.UNet_min_batch_size if the while loop was + # executed. If not, additional vram headroom is used to increase batch size + ref_bs = self.UNet_reference_val_corresp_bs_2d if len(spacing) == 2 else self.UNet_reference_val_corresp_bs_3d + batch_size = round((reference / estimate) * ref_bs) + + # we need to cap the batch size to cover at most 5% of the entire dataset. Overfitting precaution. 
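+        # Worked example with illustrative numbers (not taken from any real dataset): a 128x128x128 patch has
+        # ~2.1e6 voxels, so with approximate_n_voxels_dataset of ~1e9 the 5% cap is round(5e7 / 2.1e6) = 24 and
+        # a VRAM-derived batch size of, say, 40 would be clipped to 24.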
We cannot + # go smaller than self.UNet_min_batch_size though + bs_corresponding_to_5_percent = round( + approximate_n_voxels_dataset * 0.05 / np.prod(patch_size, dtype=np.float64)) + batch_size = max(min(batch_size, bs_corresponding_to_5_percent), self.UNet_min_batch_size) + + resampling_data, resampling_data_kwargs, resampling_seg, resampling_seg_kwargs = self.determine_resampling() + resampling_softmax, resampling_softmax_kwargs = self.determine_segmentation_softmax_export_fn() + + normalization_schemes, mask_is_used_for_norm = \ + self.determine_normalization_scheme_and_whether_mask_is_used_for_norm() + num_stages = len(pool_op_kernel_sizes) + plan = { + 'data_identifier': data_identifier, + 'preprocessor_name': self.preprocessor_name, + 'batch_size': batch_size, + 'patch_size': patch_size, + 'median_image_size_in_voxels': median_shape, + 'spacing': spacing, + 'normalization_schemes': normalization_schemes, + 'use_mask_for_norm': mask_is_used_for_norm, + 'UNet_class_name': self.UNet_class.__name__, + 'UNet_base_num_features': self.UNet_base_num_features, + 'n_conv_per_stage_encoder': self.UNet_blocks_per_stage_encoder[:num_stages], + 'n_conv_per_stage_decoder': self.UNet_blocks_per_stage_decoder[:num_stages - 1], + 'num_pool_per_axis': network_num_pool_per_axis, + 'pool_op_kernel_sizes': pool_op_kernel_sizes, + 'conv_kernel_sizes': conv_kernel_sizes, + 'unet_max_num_features': self.UNet_max_features_3d if len(spacing) == 3 else self.UNet_max_features_2d, + 'resampling_fn_data': resampling_data.__name__, + 'resampling_fn_seg': resampling_seg.__name__, + 'resampling_fn_data_kwargs': resampling_data_kwargs, + 'resampling_fn_seg_kwargs': resampling_seg_kwargs, + 'resampling_fn_probabilities': resampling_softmax.__name__, + 'resampling_fn_probabilities_kwargs': resampling_softmax_kwargs, + } + return plan + + def plan_experiment(self): + """ + MOVE EVERYTHING INTO THE PLANS. MAXIMUM FLEXIBILITY + + Ideally I would like to move transpose_forward/backward into the configurations so that this can also be done + differently for each configuration but this would cause problems with identifying the correct axes for 2d. There + surely is a way around that but eh. I'm feeling lazy and featuritis must also not be pushed to the extremes. + + So for now if you want a different transpose_forward/backward you need to create a new planner. Also not too + hard. 
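The "create a new planner" route mentioned above would look roughly like the following hedged sketch (the class name is made up and it assumes ExperimentPlanner is importable from this module):

class MyNoTransposePlanner(ExperimentPlanner):
    # keep the on-disk axis order instead of moving the lowest-resolution axis to the front
    def determine_transpose(self):
        return [0, 1, 2], [0, 1, 2]

This mirrors what the existing suppress_transpose flag already does.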
+ """ + + # first get transpose + transpose_forward, transpose_backward = self.determine_transpose() + + # get fullres spacing and transpose it + fullres_spacing = self.determine_fullres_target_spacing() + fullres_spacing_transposed = fullres_spacing[transpose_forward] + + # get transposed new median shape (what we would have after resampling) + new_shapes = [compute_new_shape(j, i, fullres_spacing) for i, j in + zip(self.dataset_fingerprint['spacings'], self.dataset_fingerprint['shapes_after_crop'])] + new_median_shape = np.median(new_shapes, 0) + new_median_shape_transposed = new_median_shape[transpose_forward] + + approximate_n_voxels_dataset = float(np.prod(new_median_shape_transposed, dtype=np.float64) * + self.dataset_json['numTraining']) + # only run 3d if this is a 3d dataset + if new_median_shape_transposed[0] != 1: + plan_3d_fullres = self.get_plans_for_configuration(fullres_spacing_transposed, + new_median_shape_transposed, + self.generate_data_identifier('3d_fullres'), + approximate_n_voxels_dataset) + # maybe add 3d_lowres as well + patch_size_fullres = plan_3d_fullres['patch_size'] + median_num_voxels = np.prod(new_median_shape_transposed, dtype=np.float64) + num_voxels_in_patch = np.prod(patch_size_fullres, dtype=np.float64) + + plan_3d_lowres = None + lowres_spacing = deepcopy(plan_3d_fullres['spacing']) + + spacing_increase_factor = 1.03 # used to be 1.01 but that is slow with new GPU memory estimation! + + while num_voxels_in_patch / median_num_voxels < self.lowres_creation_threshold: + # we incrementally increase the target spacing. We start with the anisotropic axis/axes until it/they + # is/are similar (factor 2) to the other ax(i/e)s. + max_spacing = max(lowres_spacing) + if np.any((max_spacing / lowres_spacing) > 2): + lowres_spacing[(max_spacing / lowres_spacing) > 2] *= spacing_increase_factor + else: + lowres_spacing *= spacing_increase_factor + median_num_voxels = np.prod(plan_3d_fullres['spacing'] / lowres_spacing * new_median_shape_transposed, + dtype=np.float64) + # print(lowres_spacing) + plan_3d_lowres = self.get_plans_for_configuration(lowres_spacing, + [round(i) for i in plan_3d_fullres['spacing'] / + lowres_spacing * new_median_shape_transposed], + self.generate_data_identifier('3d_lowres'), + float(np.prod(median_num_voxels) * + self.dataset_json['numTraining'])) + num_voxels_in_patch = np.prod(plan_3d_lowres['patch_size'], dtype=np.int64) + print(f'Attempting to find 3d_lowres config. ' + f'\nCurrent spacing: {lowres_spacing}. ' + f'\nCurrent patch size: {plan_3d_lowres["patch_size"]}. ' + f'\nCurrent median shape: {plan_3d_fullres["spacing"] / lowres_spacing * new_median_shape_transposed}') + if plan_3d_lowres is not None: + plan_3d_lowres['batch_dice'] = False + plan_3d_fullres['batch_dice'] = True + else: + plan_3d_fullres['batch_dice'] = False + else: + plan_3d_fullres = None + plan_3d_lowres = None + + # 2D configuration + plan_2d = self.get_plans_for_configuration(fullres_spacing_transposed[1:], + new_median_shape_transposed[1:], + self.generate_data_identifier('2d'), approximate_n_voxels_dataset) + plan_2d['batch_dice'] = True + + print('2D U-Net configuration:') + print(plan_2d) + print() + + # median spacing and shape, just for reference when printing the plans + median_spacing = np.median(self.dataset_fingerprint['spacings'], 0)[transpose_forward] + median_shape = np.median(self.dataset_fingerprint['shapes_after_crop'], 0)[transpose_forward] + + # instead of writing all that into the plans we just copy the original file. 
More files, but less crowded + # per file. + shutil.copy(join(self.raw_dataset_folder, 'dataset.json'), + join(nnUNet_preprocessed, self.dataset_name, 'dataset.json')) + + # json is stupid and I hate it... "Object of type int64 is not JSON serializable" -> my ass + plans = { + 'dataset_name': self.dataset_name, + 'plans_name': self.plans_identifier, + 'original_median_spacing_after_transp': [float(i) for i in median_spacing], + 'original_median_shape_after_transp': [int(round(i)) for i in median_shape], + 'image_reader_writer': self.determine_reader_writer().__name__, + 'transpose_forward': [int(i) for i in transpose_forward], + 'transpose_backward': [int(i) for i in transpose_backward], + 'configurations': {'2d': plan_2d}, + 'experiment_planner_used': self.__class__.__name__, + 'label_manager': 'LabelManager', + 'foreground_intensity_properties_per_channel': self.dataset_fingerprint[ + 'foreground_intensity_properties_per_channel'] + } + + if plan_3d_lowres is not None: + plans['configurations']['3d_lowres'] = plan_3d_lowres + if plan_3d_fullres is not None: + plans['configurations']['3d_lowres']['next_stage'] = '3d_cascade_fullres' + print('3D lowres U-Net configuration:') + print(plan_3d_lowres) + print() + if plan_3d_fullres is not None: + plans['configurations']['3d_fullres'] = plan_3d_fullres + print('3D fullres U-Net configuration:') + print(plan_3d_fullres) + print() + if plan_3d_lowres is not None: + plans['configurations']['3d_cascade_fullres'] = { + 'inherits_from': '3d_fullres', + 'previous_stage': '3d_lowres' + } + + self.plans = plans + self.save_plans(plans) + return plans + + def save_plans(self, plans): + recursive_fix_for_json_export(plans) + + plans_file = join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json') + + # we don't want to overwrite potentially existing custom configurations every time this is executed. So let's + # read the plans file if it already exists and keep any non-default configurations + if isfile(plans_file): + old_plans = load_json(plans_file) + old_configurations = old_plans['configurations'] + for c in plans['configurations'].keys(): + if c in old_configurations.keys(): + del (old_configurations[c]) + plans['configurations'].update(old_configurations) + + maybe_mkdir_p(join(nnUNet_preprocessed, self.dataset_name)) + save_json(plans, plans_file, sort_keys=False) + print('Plans were saved to %s' % join(nnUNet_preprocessed, self.dataset_name, self.plans_identifier + '.json')) + + def generate_data_identifier(self, configuration_name: str) -> str: + """ + configurations are unique within each plans file but differnet plans file can have configurations with the + same name. 
In order to distinguish the assiciated data we need a data identifier that reflects not just the + config but also the plans it originates from + """ + return self.plans_identifier + '_' + configuration_name + + def load_plans(self, fname: str): + self.plans = load_json(fname) + + +if __name__ == '__main__': + ExperimentPlanner(2, 8).plan_experiment() diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py new file mode 100644 index 0000000..8b75b46 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/network_topology.py @@ -0,0 +1,105 @@ +from copy import deepcopy +import numpy as np + + +def get_shape_must_be_divisible_by(net_numpool_per_axis): + return 2 ** np.array(net_numpool_per_axis) + + +def pad_shape(shape, must_be_divisible_by): + """ + pads shape so that it is divisible by must_be_divisible_by + :param shape: + :param must_be_divisible_by: + :return: + """ + if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)): + must_be_divisible_by = [must_be_divisible_by] * len(shape) + else: + assert len(must_be_divisible_by) == len(shape) + + new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))] + + for i in range(len(shape)): + if shape[i] % must_be_divisible_by[i] == 0: + new_shp[i] -= must_be_divisible_by[i] + new_shp = np.array(new_shp).astype(int) + return new_shp + + +def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool): + """ + this is the same as get_pool_and_conv_props_v2 from old nnunet + + :param spacing: + :param patch_size: + :param min_feature_map_size: min edge length of feature maps in bottleneck + :param max_numpool: + :return: + """ + # todo review this code + dim = len(spacing) + + current_spacing = deepcopy(list(spacing)) + current_size = deepcopy(list(patch_size)) + + pool_op_kernel_sizes = [[1] * len(spacing)] + conv_kernel_sizes = [] + + num_pool_per_axis = [0] * dim + kernel_size = [1] * dim + + while True: + # exclude axes that we cannot pool further because of min_feature_map_size constraint + valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size] + if len(valid_axes_for_pool) < 1: + break + + spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool] + + # find axis that are within factor of 2 within smallest spacing + min_spacing_of_valid = min(spacings_of_axes) + valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2] + + # max_numpool constraint + valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool] + + if len(valid_axes_for_pool) == 1: + if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size: + pass + else: + break + if len(valid_axes_for_pool) < 1: + break + + # now we need to find kernel sizes + # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within + # factor 2 of min_spacing. 
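+        # (Illustration: with a current spacing of (5.0, 0.8, 0.8) the two in-plane axes get kernel size 3 right
+        # away, while the coarse z axis keeps kernel size 1 until repeated in-plane pooling has brought the
+        # current spacings within a factor of 2 of each other.)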
Once they are 3 they remain 3 + for d in range(dim): + if kernel_size[d] == 3: + continue + else: + if spacings_of_axes[d] / min(current_spacing) < 2: + kernel_size[d] = 3 + + other_axes = [i for i in range(dim) if i not in valid_axes_for_pool] + + pool_kernel_sizes = [0] * dim + for v in valid_axes_for_pool: + pool_kernel_sizes[v] = 2 + num_pool_per_axis[v] += 1 + current_spacing[v] *= 2 + current_size[v] = np.ceil(current_size[v] / 2) + for nv in other_axes: + pool_kernel_sizes[nv] = 1 + + pool_op_kernel_sizes.append(pool_kernel_sizes) + conv_kernel_sizes.append(deepcopy(kernel_size)) + #print(conv_kernel_sizes) + + must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis) + patch_size = pad_shape(patch_size, must_be_divisible_by) + + # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here + conv_kernel_sizes.append([3]*dim) + return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md b/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md new file mode 100644 index 0000000..e2e4e18 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/readme.md @@ -0,0 +1,38 @@ +What do experiment planners need to do (these are notes for myself while rewriting nnU-Net, they are provided as is +without further explanations. These notes also include new features): +- (done) preprocessor name should be configurable via cli +- (done) gpu memory target should be configurable via cli +- (done) plans name should be configurable via cli +- (done) data name should be specified in plans (plans specify the data they want to use, this will allow us to manually + edit plans files without having to copy the data folders) +- plans must contain: + - (done) transpose forward/backward + - (done) preprocessor name (can differ for each config) + - (done) spacing + - (done) normalization scheme + - (done) target spacing + - (done) conv and pool op kernel sizes + - (done) base num features for architecture + - (done) data identifier + - num conv per stage? + - (done) use mask for norm + - [NO. Handled by LabelManager & dataset.json] num segmentation outputs + - [NO. Handled by LabelManager & dataset.json] ignore class + - [NO. Handled by LabelManager & dataset.json] list of regions or classes + - [NO. Handled by LabelManager & dataset.json] regions class order, if applicable + - (done) resampling function to be used + - (done) the image reader writer class that should be used + + +dataset.json +mandatory: +- numTraining +- labels (value 'ignore' has special meaning. 
Cannot have more than one ignore_label) +- modalities +- file_ending + +optional +- overwrite_image_reader_writer (if absent, auto) +- regions +- region_class_order +- \ No newline at end of file diff --git a/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py b/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py new file mode 100644 index 0000000..52ca938 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/experiment_planners/resencUNet_planner.py @@ -0,0 +1,54 @@ +from typing import Union, List, Tuple + +from torch import nn + +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet + + +class ResEncUNetPlanner(ExperimentPlanner): + def __init__(self, dataset_name_or_id: Union[str, int], + gpu_memory_target_in_gb: float = 8, + preprocessor_name: str = 'DefaultPreprocessor', plans_name: str = 'nnUNetResEncUNetPlans', + overwrite_target_spacing: Union[List[float], Tuple[float, ...]] = None, + suppress_transpose: bool = False): + super().__init__(dataset_name_or_id, gpu_memory_target_in_gb, preprocessor_name, plans_name, + overwrite_target_spacing, suppress_transpose) + + self.UNet_base_num_features = 32 + self.UNet_class = ResidualEncoderUNet + # the following two numbers are really arbitrary and were set to reproduce default nnU-Net's configurations as + # much as possible + self.UNet_reference_val_3d = 680000000 + self.UNet_reference_val_2d = 135000000 + self.UNet_reference_com_nfeatures = 32 + self.UNet_reference_val_corresp_GB = 8 + self.UNet_reference_val_corresp_bs_2d = 12 + self.UNet_reference_val_corresp_bs_3d = 2 + self.UNet_featuremap_min_edge_length = 4 + self.UNet_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6) + self.UNet_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1) + self.UNet_min_batch_size = 2 + self.UNet_max_features_2d = 512 + self.UNet_max_features_3d = 320 + + +if __name__ == '__main__': + # we know both of these networks run with batch size 2 and 12 on ~8-10GB, respectively + net = ResidualEncoderUNet(input_channels=1, n_stages=6, features_per_stage=(32, 64, 128, 256, 320, 320), + conv_op=nn.Conv3d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm3d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((128, 128, 128))) # -> 558319104. 
The value you see above was finetuned + # from this one to match the regular nnunetplans more closely + + net = ResidualEncoderUNet(input_channels=1, n_stages=7, features_per_stage=(32, 64, 128, 256, 512, 512, 512), + conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2, 2, 2, 2), + n_blocks_per_stage=(1, 3, 4, 6, 6, 6, 6), num_classes=3, + n_conv_per_stage_decoder=(1, 1, 1, 1, 1, 1), + conv_bias=True, norm_op=nn.InstanceNorm2d, norm_op_kwargs={}, dropout_op=None, + nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True}, deep_supervision=True) + print(net.compute_conv_feature_map_size((512, 512))) # -> 129793792 + diff --git a/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py new file mode 100644 index 0000000..eb94840 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_api.py @@ -0,0 +1,138 @@ +import shutil +from typing import List, Type, Optional, Tuple, Union + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p, subfiles, load_json + +from nnunetv2.experiment_planning.dataset_fingerprint.fingerprint_extractor import DatasetFingerprintExtractor +from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner +from nnunetv2.experiment_planning.verify_dataset_integrity import verify_dataset_integrity +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name, maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager +from nnunetv2.configuration import default_num_processes +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +def extract_fingerprint_dataset(dataset_id: int, + fingerprint_extractor_class: Type[ + DatasetFingerprintExtractor] = DatasetFingerprintExtractor, + num_processes: int = default_num_processes, check_dataset_integrity: bool = False, + clean: bool = True, verbose: bool = True): + """ + Returns the fingerprint as a dictionary (additionally to saving it) + """ + dataset_name = convert_id_to_dataset_name(dataset_id) + print(dataset_name) + + if check_dataset_integrity: + verify_dataset_integrity(join(nnUNet_raw, dataset_name), num_processes) + + fpe = fingerprint_extractor_class(dataset_id, num_processes, verbose=verbose) + return fpe.run(overwrite_existing=clean) + + +def extract_fingerprints(dataset_ids: List[int], fingerprint_extractor_class_name: str = 'DatasetFingerprintExtractor', + num_processes: int = default_num_processes, check_dataset_integrity: bool = False, + clean: bool = True, verbose: bool = True): + """ + clean = False will not actually run this. This is just a switch for use with nnUNetv2_plan_and_preprocess where + we don't want to rerun fingerprint extraction every time. 
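A hedged usage sketch of the API function documented above (the dataset IDs are placeholders):

from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints

# verify integrity once, then (re)extract the fingerprints for datasets 2 and 4
extract_fingerprints([2, 4], check_dataset_integrity=True, clean=True)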
+ """ + fingerprint_extractor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + fingerprint_extractor_class_name, + current_module="nnunetv2.experiment_planning") + for d in dataset_ids: + extract_fingerprint_dataset(d, fingerprint_extractor_class, num_processes, check_dataset_integrity, clean, + verbose) + + +def plan_experiment_dataset(dataset_id: int, + experiment_planner_class: Type[ExperimentPlanner] = ExperimentPlanner, + gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', + overwrite_target_spacing: Optional[Tuple[float, ...]] = None, + overwrite_plans_name: Optional[str] = None) -> dict: + """ + overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! + """ + kwargs = {} + if overwrite_plans_name is not None: + kwargs['plans_name'] = overwrite_plans_name + return experiment_planner_class(dataset_id, + gpu_memory_target_in_gb=gpu_memory_target_in_gb, + preprocessor_name=preprocess_class_name, + overwrite_target_spacing=[float(i) for i in overwrite_target_spacing] if + overwrite_target_spacing is not None else overwrite_target_spacing, + suppress_transpose=False, # might expose this later, + **kwargs + ).plan_experiment() + + +def plan_experiments(dataset_ids: List[int], experiment_planner_class_name: str = 'ExperimentPlanner', + gpu_memory_target_in_gb: float = 8, preprocess_class_name: str = 'DefaultPreprocessor', + overwrite_target_spacing: Optional[Tuple[float, ...]] = None, + overwrite_plans_name: Optional[str] = None): + """ + overwrite_target_spacing ONLY applies to 3d_fullres and 3d_cascade fullres! + """ + experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + experiment_planner_class_name, + current_module="nnunetv2.experiment_planning") + for d in dataset_ids: + plan_experiment_dataset(d, experiment_planner, gpu_memory_target_in_gb, preprocess_class_name, + overwrite_target_spacing, overwrite_plans_name) + + +def preprocess_dataset(dataset_id: int, + plans_identifier: str = 'nnUNetPlans', + configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), + num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), + verbose: bool = False) -> None: + if not isinstance(num_processes, list): + num_processes = list(num_processes) + if len(num_processes) == 1: + num_processes = num_processes * len(configurations) + if len(num_processes) != len(configurations): + raise RuntimeError( + f'The list provided with num_processes must either have len 1 or as many elements as there are ' + f'configurations (see --help). Number of configurations: {len(configurations)}, length ' + f'of num_processes: ' + f'{len(num_processes)}') + + dataset_name = convert_id_to_dataset_name(dataset_id) + print(f'Preprocessing dataset {dataset_name}') + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') + plans_manager = PlansManager(plans_file) + for n, c in zip(num_processes, configurations): + print(f'Configuration: {c}...') + if c not in plans_manager.available_configurations: + print( + f"INFO: Configuration {c} not found in plans file {plans_identifier + '.json'} of " + f"dataset {dataset_name}. 
Skipping.") + continue + configuration_manager = plans_manager.get_configuration(c) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + preprocessor.run(dataset_id, c, plans_identifier, num_processes=n) + + # copy the gt to a folder in the nnUNet_preprocessed so that we can do validation even if the raw data is no + # longer there (useful for compute cluster where only the preprocessed data is available) + from distutils.file_util import copy_file + maybe_mkdir_p(join(nnUNet_preprocessed, dataset_name, 'gt_segmentations')) + dataset_json = load_json(join(nnUNet_raw, dataset_name, 'dataset.json')) + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) + # only copy files that are newer than the ones already present + for k in dataset: + copy_file(dataset[k]['label'], + join(nnUNet_preprocessed, dataset_name, 'gt_segmentations', k + dataset_json['file_ending']), + update=True) + + + +def preprocess(dataset_ids: List[int], + plans_identifier: str = 'nnUNetPlans', + configurations: Union[Tuple[str], List[str]] = ('2d', '3d_fullres', '3d_lowres'), + num_processes: Union[int, Tuple[int, ...], List[int]] = (8, 4, 8), + verbose: bool = False): + for d in dataset_ids: + preprocess_dataset(d, plans_identifier, configurations, num_processes, verbose) diff --git a/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py new file mode 100644 index 0000000..a600653 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py @@ -0,0 +1,201 @@ +from nnunetv2.configuration import default_num_processes +from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, plan_experiments, preprocess + + +def extract_fingerprint_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', + help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' + '\'DatasetFingerprintExtractor\'.') + parser.add_argument('-np', type=int, default=default_num_processes, required=False, + help=f'[OPTIONAL] Number of processes used for fingerprint extraction. ' + f'Default: {default_num_processes}') + parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", + help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for " + "each dataset!") + parser.add_argument("--clean", required=False, default=False, action="store_true", + help='[OPTIONAL] Set this flag to overwrite existing fingerprints. If this flag is not set and a ' + 'fingerprint already exists, the fingerprint extractor will not run.') + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progrewss bar! 
' + 'Recommended for cluster environments') + args, unrecognized_args = parser.parse_known_args() + extract_fingerprints(args.d, args.fpe, args.np, args.verify_dataset_integrity, args.clean, args.verbose) + + +def plan_experiment_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, + help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' + '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. ' + 'It\'s an all in one solution now. Wuch. Such amazing.') + parser.add_argument('-gpu_memory_target', default=8, type=float, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will ' + 'affect patch and batch size and will ' + 'definitely affect your models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' + 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' + 'models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' + 'configurations. Default: None [no changes]. Changing this will affect image size and ' + 'potentially patch and batch ' + 'size. This will definitely affect your models performance! Only use this if you really ' + 'know what you are doing and NEVER use this without running the default nnU-Net first ' + '(as a baseline). Changing the target spacing for the other configurations is currently ' + 'not implemented. New target spacing must be a list of three numbers!') + parser.add_argument('-overwrite_plans_name', default=None, required=False, + help='[OPTIONAL] DANGER ZONE! If you used -gpu_memory_target, -preprocessor_name or ' + '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' + 'differently named plans file such that the nnunet default plans are not ' + 'overwritten. You will then need to specify your custom plans file with -p whenever ' + 'running other nnunet commands (training, inference etc)') + args, unrecognized_args = parser.parse_known_args() + plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, + args.overwrite_plans_name) + + +def preprocess_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. 
Can of course also be just one dataset") + parser.add_argument('-plans_name', default='nnUNetPlans', required=False, + help='[OPTIONAL] You can use this to specify a custom plans file that you may have generated') + parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', + help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3f_fullres ' + '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' + 'from 3f_fullres. Configurations that do not exist for some dataset will be skipped.') + parser.add_argument('-np', type=int, nargs='+', default=[8, 4, 8], required=False, + help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " + "this number of processes is used for all configurations specified with -c. If it's a " + "list of numbers this list must have as many elements as there are configurations. We " + "then iterate over zip(configs, num_processes) to determine then umber of processes " + "used for each configuration. More processes is always faster (up to the number of " + "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't " + "know what that is then dont touch it, or at least don't increase it!). DANGER: More " + "often than not the number of processes that can be used is limited by the amount of " + "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " + "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH!. Default: 8 processes for 2d, 4 " + "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progrewss bar! ' + 'Recommended for cluster environments') + args, unrecognized_args = parser.parse_known_args() + if args.np is None: + default_np = { + '2d': 4, + '3d_lowres': 8, + '3d_fullres': 4 + } + np = {default_np[c] if c in default_np.keys() else 4 for c in args.c} + else: + np = args.np + preprocess(args.d, args.plans_name, configurations=args.c, num_processes=np, verbose=args.verbose) + + +def plan_and_preprocess_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, + help="[REQUIRED] List of dataset IDs. Example: 2 4 5. This will run fingerprint extraction, experiment " + "planning and preprocessing for these datasets. Can of course also be just one dataset") + parser.add_argument('-fpe', type=str, required=False, default='DatasetFingerprintExtractor', + help='[OPTIONAL] Name of the Dataset Fingerprint Extractor class that should be used. Default is ' + '\'DatasetFingerprintExtractor\'.') + parser.add_argument('-npfp', type=int, default=8, required=False, + help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 8') + parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true", + help="[RECOMMENDED] set this flag to check the dataset integrity. This is useful and should be done once for " + "each dataset!") + parser.add_argument('--no_pp', default=False, action='store_true', required=False, + help='[OPTIONAL] Set this to only run fingerprint extraction and experiment planning (no ' + 'preprocesing). Useful for debugging.') + parser.add_argument("--clean", required=False, default=False, action="store_true", + help='[OPTIONAL] Set this flag to overwrite existing fingerprints. 
If this flag is not set and a ' + 'fingerprint already exists, the fingerprint extractor will not run. REQUIRED IF YOU ' + 'CHANGE THE DATASET FINGERPRINT EXTRACTOR OR MAKE CHANGES TO THE DATASET!') + parser.add_argument('-pl', type=str, default='ExperimentPlanner', required=False, + help='[OPTIONAL] Name of the Experiment Planner class that should be used. Default is ' + '\'ExperimentPlanner\'. Note: There is no longer a distinction between 2d and 3d planner. ' + 'It\'s an all in one solution now. Wuch. Such amazing.') + parser.add_argument('-gpu_memory_target', default=8, type=int, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom GPU memory target. Default: 8 [GB]. Changing this will ' + 'affect patch and batch size and will ' + 'definitely affect your models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-preprocessor_name', default='DefaultPreprocessor', type=str, required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom preprocessor class. This class must be located in ' + 'nnunetv2.preprocessing. Default: \'DefaultPreprocessor\'. Changing this may affect your ' + 'models performance! Only use this if you really know what you ' + 'are doing and NEVER use this without running the default nnU-Net first (as a baseline).') + parser.add_argument('-overwrite_target_spacing', default=None, nargs='+', required=False, + help='[OPTIONAL] DANGER ZONE! Sets a custom target spacing for the 3d_fullres and 3d_cascade_fullres ' + 'configurations. Default: None [no changes]. Changing this will affect image size and ' + 'potentially patch and batch ' + 'size. This will definitely affect your models performance! Only use this if you really ' + 'know what you are doing and NEVER use this without running the default nnU-Net first ' + '(as a baseline). Changing the target spacing for the other configurations is currently ' + 'not implemented. New target spacing must be a list of three numbers!') + parser.add_argument('-overwrite_plans_name', default='nnUNetPlans', required=False, + help='[OPTIONAL] uSE A CUSTOM PLANS IDENTIFIER. If you used -gpu_memory_target, ' + '-preprocessor_name or ' + '-overwrite_target_spacing it is best practice to use -overwrite_plans_name to generate a ' + 'differently named plans file such that the nnunet default plans are not ' + 'overwritten. You will then need to specify your custom plans file with -p whenever ' + 'running other nnunet commands (training, inference etc)') + parser.add_argument('-c', required=False, default=['2d', '3d_fullres', '3d_lowres'], nargs='+', + help='[OPTIONAL] Configurations for which the preprocessing should be run. Default: 2d 3f_fullres ' + '3d_lowres. 3d_cascade_fullres does not need to be specified because it uses the data ' + 'from 3f_fullres. Configurations that do not exist for some dataset will be skipped.') + parser.add_argument('-np', type=int, nargs='+', default=None, required=False, + help="[OPTIONAL] Use this to define how many processes are to be used. If this is just one number then " + "this number of processes is used for all configurations specified with -c. If it's a " + "list of numbers this list must have as many elements as there are configurations. We " + "then iterate over zip(configs, num_processes) to determine then umber of processes " + "used for each configuration. 
More processes is always faster (up to the number of " + "threads your PC can support, so 8 for a 4 core CPU with hyperthreading. If you don't " + "know what that is then dont touch it, or at least don't increase it!). DANGER: More " + "often than not the number of processes that can be used is limited by the amount of " + "RAM available. Image resampling takes up a lot of RAM. MONITOR RAM USAGE AND " + "DECREASE -np IF YOUR RAM FILLS UP TOO MUCH!. Default: 8 processes for 2d, 4 " + "for 3d_fullres, 8 for 3d_lowres and 4 for everything else") + parser.add_argument('--verbose', required=False, action='store_true', + help='Set this to print a lot of stuff. Useful for debugging. Will disable progrewss bar! ' + 'Recommended for cluster environments') + args = parser.parse_args() + + # fingerprint extraction + print("Fingerprint extraction...") + extract_fingerprints(args.d, args.fpe, args.npfp, args.verify_dataset_integrity, args.clean, args.verbose) + + # experiment planning + print('Experiment planning...') + plan_experiments(args.d, args.pl, args.gpu_memory_target, args.preprocessor_name, args.overwrite_target_spacing, args.overwrite_plans_name) + + # manage default np + if args.np is None: + default_np = {"2d": 8, "3d_fullres": 4, "3d_lowres": 8} + np = [default_np[c] if c in default_np.keys() else 4 for c in args.c] + else: + np = args.np + # preprocessing + if not args.no_pp: + print('Preprocessing...') + preprocess(args.d, args.overwrite_plans_name, args.c, np, args.verbose) + + +if __name__ == '__main__': + plan_and_preprocess_entry() diff --git a/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py new file mode 100644 index 0000000..d78b689 --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/plans_for_pretraining/move_plans_between_datasets.py @@ -0,0 +1,79 @@ +import argparse +from typing import Union + +from batchgenerators.utilities.file_and_folder_operations import join, isdir, isfile, load_json, subfiles, save_json + +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.utilities.file_path_utilities import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets + + +def move_plans_between_datasets( + source_dataset_name_or_id: Union[int, str], + target_dataset_name_or_id: Union[int, str], + source_plans_identifier: str, + target_plans_identifier: str = None): + source_dataset_name = maybe_convert_to_dataset_name(source_dataset_name_or_id) + target_dataset_name = maybe_convert_to_dataset_name(target_dataset_name_or_id) + + if target_plans_identifier is None: + target_plans_identifier = source_plans_identifier + + source_folder = join(nnUNet_preprocessed, source_dataset_name) + assert isdir(source_folder), f"Cannot move plans because preprocessed directory of source dataset is missing. " \ + f"Run nnUNetv2_plan_and_preprocess for source dataset first!" + + source_plans_file = join(source_folder, source_plans_identifier + '.json') + assert isfile(source_plans_file), f"Source plans are missing. Run the corresponding experiment planning first! 
" \ + f"Expected file: {source_plans_file}" + + source_plans = load_json(source_plans_file) + source_plans['dataset_name'] = target_dataset_name + + # we need to change data_identifier to use target_plans_identifier + if target_plans_identifier != source_plans_identifier: + for c in source_plans['configurations'].keys(): + old_identifier = source_plans['configurations'][c]["data_identifier"] + if old_identifier.startswith(source_plans_identifier): + new_identifier = target_plans_identifier + old_identifier[len(source_plans_identifier):] + else: + new_identifier = target_plans_identifier + '_' + old_identifier + source_plans['configurations'][c]["data_identifier"] = new_identifier + + # we need to change the reader writer class! + target_raw_data_dir = join(nnUNet_raw, target_dataset_name) + target_dataset_json = load_json(join(target_raw_data_dir, 'dataset.json')) + + # we may need to change the reader/writer + # pick any file from the source dataset + dataset = get_filenames_of_train_images_and_targets(target_raw_data_dir, target_dataset_json) + example_image = dataset[dataset.keys().__iter__().__next__()]['images'][0] + rw = determine_reader_writer_from_dataset_json(target_dataset_json, example_image, allow_nonmatching_filename=True, + verbose=False) + + source_plans["image_reader_writer"] = rw.__name__ + + save_json(source_plans, join(nnUNet_preprocessed, target_dataset_name, target_plans_identifier + '.json'), + sort_keys=False) + + +def entry_point_move_plans_between_datasets(): + parser = argparse.ArgumentParser() + parser.add_argument('-s', type=str, required=True, + help='Source dataset name or id') + parser.add_argument('-t', type=str, required=True, + help='Target dataset name or id') + parser.add_argument('-sp', type=str, required=True, + help='Source plans identifier. If your plans are named "nnUNetPlans.json" then the ' + 'identifier would be nnUNetPlans') + parser.add_argument('-tp', type=str, required=False, default=None, + help='Target plans identifier. Default is None meaning the source plans identifier will ' + 'be kept. Not recommended if the source plans identifier is a default nnU-Net identifier ' + 'such as nnUNetPlans!!!') + args = parser.parse_args() + move_plans_between_datasets(args.s, args.t, args.sp, args.tp) + + +if __name__ == '__main__': + move_plans_between_datasets(2, 4, 'nnUNetPlans', 'nnUNetPlansFrom2') diff --git a/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py b/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py new file mode 100644 index 0000000..502611c --- /dev/null +++ b/nnUNet/nnunetv2/experiment_planning/verify_dataset_integrity.py @@ -0,0 +1,234 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import multiprocessing +import re +from multiprocessing import Pool +from typing import Type + +import numpy as np +import pandas as pd +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ + get_filenames_of_train_images_and_targets + + +def verify_labels(label_file: str, readerclass: Type[BaseReaderWriter], expected_labels: List[int]) -> bool: + rw = readerclass() + seg, properties = rw.read_seg(label_file) + found_labels = np.sort(pd.unique(seg.ravel())) # np.unique(seg) + unexpected_labels = [i for i in found_labels if i not in expected_labels] + if len(found_labels) == 0 and found_labels[0] == 0: + print('WARNING: File %s only has label 0 (which should be background). This may be intentional or not, ' + 'up to you.' % label_file) + if len(unexpected_labels) > 0: + print("Error: Unexpected labels found in file %s.\nExpected: %s\nFound: %s" % (label_file, expected_labels, + found_labels)) + return False + return True + + +def check_cases(image_files: List[str], label_file: str, expected_num_channels: int, + readerclass: Type[BaseReaderWriter]) -> bool: + rw = readerclass() + ret = True + + images, properties_image = rw.read_images(image_files) + segmentation, properties_seg = rw.read_seg(label_file) + + # check for nans + if np.any(np.isnan(images)): + print(f'Images contain NaN pixel values. You need to fix that by ' + f'replacing NaN values with something that makes sense for your images!\nImages:\n{image_files}') + ret = False + if np.any(np.isnan(segmentation)): + print(f'Segmentation contains NaN pixel values. You need to fix that.\nSegmentation:\n{label_file}') + ret = False + + # check shapes + shape_image = images.shape[1:] + shape_seg = segmentation.shape[1:] + if not all([i == j for i, j in zip(shape_image, shape_seg)]): + print('Error: Shape mismatch between segmentation and corresponding images. \nShape images: %s. ' + '\nShape seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (shape_image, shape_seg, image_files, label_file)) + ret = False + + # check spacings + spacing_images = properties_image['spacing'] + spacing_seg = properties_seg['spacing'] + if not np.allclose(spacing_seg, spacing_images): + print('Error: Spacing mismatch between segmentation and corresponding images. \nSpacing images: %s. ' + '\nSpacing seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (shape_image, shape_seg, image_files, label_file)) + ret = False + + # check modalities + if not len(images) == expected_num_channels: + print('Error: Unexpected number of modalities. \nExpected: %d. \nGot: %d. \nImages: %s\n' + % (expected_num_channels, len(images), image_files)) + ret = False + + # nibabel checks + if 'nibabel_stuff' in properties_image.keys(): + # this image was read with NibabelIO + affine_image = properties_image['nibabel_stuff']['original_affine'] + affine_seg = properties_seg['nibabel_stuff']['original_affine'] + if not np.allclose(affine_image, affine_seg): + print('WARNING: Affine is not the same for image and seg! \nAffine image: %s \nAffine seg: %s\n' + 'Image files: %s. \nSeg file: %s.\nThis can be a problem but doesn\'t have to be. 
Please run ' + 'nnUNet_plot_dataset_pngs to verify if everything is OK!\n' + % (affine_image, affine_seg, image_files, label_file)) + + # sitk checks + if 'sitk_stuff' in properties_image.keys(): + # this image was read with SimpleITKIO + # spacing has already been checked, only check direction and origin + origin_image = properties_image['sitk_stuff']['origin'] + origin_seg = properties_seg['sitk_stuff']['origin'] + if not np.allclose(origin_image, origin_seg): + print('Warning: Origin mismatch between segmentation and corresponding images. \nOrigin images: %s. ' + '\nOrigin seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (origin_image, origin_seg, image_files, label_file)) + direction_image = properties_image['sitk_stuff']['direction'] + direction_seg = properties_seg['sitk_stuff']['direction'] + if not np.allclose(direction_image, direction_seg): + print('Warning: Direction mismatch between segmentation and corresponding images. \nDirection images: %s. ' + '\nDirection seg: %s. \nImage files: %s. \nSeg file: %s\n' % + (direction_image, direction_seg, image_files, label_file)) + + return ret + + +def verify_dataset_integrity(folder: str, num_processes: int = 8) -> None: + """ + folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json + checks if the expected number of training cases and labels are present + for each case, if possible, checks whether the pixel grids are aligned + checks whether the labels really only contain values they should + :param folder: + :return: + """ + assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder + dataset_json = load_json(join(folder, "dataset.json")) + + if not 'dataset' in dataset_json.keys(): + assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder + assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder + + # make sure all required keys are there + dataset_keys = list(dataset_json.keys()) + required_keys = ['labels', "channel_names", "numTraining", "file_ending"] + assert all([i in dataset_keys for i in required_keys]), 'not all required keys are present in dataset.json.' \ + '\n\nRequired: \n%s\n\nPresent: \n%s\n\nMissing: ' \ + '\n%s\n\nUnused by nnU-Net:\n%s' % \ + (str(required_keys), + str(dataset_keys), + str([i for i in required_keys if i not in dataset_keys]), + str([i for i in dataset_keys if i not in required_keys])) + + expected_num_training = dataset_json['numTraining'] + num_modalities = len(dataset_json['channel_names'].keys() + if 'channel_names' in dataset_json.keys() + else dataset_json['modality'].keys()) + file_ending = dataset_json['file_ending'] + + dataset = get_filenames_of_train_images_and_targets(folder, dataset_json) + + # check if the right number of training cases is present + assert len(dataset) == expected_num_training, 'Did not find the expected number of training cases ' \ + '(%d). 
Found %d instead.\nExamples: %s' % \ + (expected_num_training, len(dataset), + list(dataset.keys())[:5]) + + # check if corresponding labels are present + if 'dataset' in dataset_json.keys(): + # just check if everything is there + ok = True + missing_images = [] + missing_labels = [] + for k in dataset: + for i in dataset[k]['images']: + if not isfile(i): + missing_images.append(i) + ok = False + if not isfile(dataset[k]['label']): + missing_labels.append(dataset[k]['label']) + ok = False + if not ok: + raise FileNotFoundError(f"Some expeted files were missing. Make sure you are properly referencing them " + f"in the dataset.json. Or use imagesTr & labelsTr folders!\nMissing images:" + f"\n{missing_images}\n\nMissing labels:\n{missing_labels}") + else: + # old code that uses imagestr and labelstr folders + labelfiles = subfiles(join(folder, 'labelsTr'), suffix=file_ending, join=False) + label_identifiers = [i[:-len(file_ending)] for i in labelfiles] + labels_present = [i in label_identifiers for i in dataset.keys()] + missing = [i for j, i in enumerate(dataset.keys()) if not labels_present[j]] + assert all(labels_present), 'not all training cases have a label file in labelsTr. Fix that. Missing: %s' % missing + + labelfiles = [v['label'] for v in dataset.values()] + image_files = [v['images'] for v in dataset.values()] + + # no plans exist yet, so we can't use PlansManager and gotta roll with the default. It's unlikely to cause + # problems anyway + label_manager = LabelManager(dataset_json['labels'], regions_class_order=dataset_json.get('regions_class_order')) + expected_labels = label_manager.all_labels + if label_manager.has_ignore_label: + expected_labels.append(label_manager.ignore_label) + labels_valid_consecutive = np.ediff1d(expected_labels) == 1 + assert all( + labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction' + + # determine reader/writer class + reader_writer_class = determine_reader_writer_from_dataset_json(dataset_json, dataset[dataset.keys().__iter__().__next__()]['images'][0]) + + # check whether only the desired labels are present + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + result = p.starmap( + verify_labels, + zip([join(folder, 'labelsTr', i) for i in labelfiles], [reader_writer_class] * len(labelfiles), + [expected_labels] * len(labelfiles)) + ) + if not all(result): + raise RuntimeError( + 'Some segmentation images contained unexpected labels. Please check text output above to see which one(s).') + + # check whether shapes and spacings match between images and labels + result = p.starmap( + check_cases, + zip(image_files, labelfiles, [num_modalities] * expected_num_training, + [reader_writer_class] * expected_num_training) + ) + if not all(result): + raise RuntimeError( + 'Some images have errors. Please check text output above to see which one(s) and what\'s going on.') + + # check for nans + # check all same orientation nibabel + print('\n####################') + print('verify_dataset_integrity Done. 
\nIf you didn\'t see any error messages then your dataset is most likely OK!') + print('####################\n') + + +if __name__ == "__main__": + # investigate geometry issues + example_folder = join(nnUNet_raw, 'Dataset250_COMPUTING_it0') + num_processes = 6 + verify_dataset_integrity(example_folder, num_processes) diff --git a/nnUNet/nnunetv2/imageio/__init__.py b/nnUNet/nnunetv2/imageio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/imageio/base_reader_writer.py b/nnUNet/nnunetv2/imageio/base_reader_writer.py new file mode 100644 index 0000000..d71226f --- /dev/null +++ b/nnUNet/nnunetv2/imageio/base_reader_writer.py @@ -0,0 +1,113 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Tuple, Union, List +import numpy as np + + +class BaseReaderWriter(ABC): + @staticmethod + def _check_all_same(input_list): + # compare all entries to the first + for i in input_list[1:]: + if not len(i) == len(input_list[0]): + return False + all_same = all(i[j] == input_list[0][j] for j in range(len(i))) + if not all_same: + return False + return True + + @staticmethod + def _check_all_same_array(input_list): + # compare all entries to the first + for i in input_list[1:]: + if not all([a == b for a, b in zip(i.shape, input_list[0].shape)]): + return False + all_same = np.allclose(i, input_list[0]) + if not all_same: + return False + return True + + @abstractmethod + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + """ + Reads a sequence of images and returns a 4d (!) np.ndarray along with a dictionary. The 4d array must have the + modalities (or color channels, or however you would like to call them) in its first axis, followed by the + spatial dimensions (so shape must be c,x,y,z where c is the number of modalities (can be 1)). + Use the dictionary to store necessary meta information that is lost when converting to numpy arrays, for + example the Spacing, Orientation and Direction of the image. This dictionary will be handed over to write_seg + for exporting the predicted segmentations, so make sure you have everything you need in there! + + IMPORTANT: dict MUST have a 'spacing' key with a tuple/list of length 3 with the voxel spacing of the np.ndarray. + Example: my_dict = {'spacing': (3, 0.5, 0.5), ...}. This is needed for planning and + preprocessing. The ordering of the numbers must correspond to the axis ordering in the returned numpy array. So + if the array has shape c,x,y,z and the spacing is (a,b,c) then a must be the spacing of x, b the spacing of y + and c the spacing of z. + + In the case of 2D images, the returned array should have shape (c, 1, x, y) and the spacing should be + (999, sp_x, sp_y). Make sure 999 is larger than sp_x and sp_y! 
Example: shape=(3, 1, 224, 224), + spacing=(999, 1, 1) + + For images that don't have a spacing, set the spacing to 1 (2d exception with 999 for the first axis still applies!) + + :param image_fnames: + :return: + 1) a np.ndarray of shape (c, x, y, z) where c is the number of image channels (can be 1) and x, y, z are + the spatial dimensions (set x=1 for 2D! Example: (3, 1, 224, 224) for RGB image). + 2) a dictionary with metadata. This can be anything. BUT it HAS to inclue a {'spacing': (a, b, c)} where a + is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set + a=999 (largest spacing value! Make it larger than b and c) + + """ + pass + + @abstractmethod + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + """ + Same requirements as BaseReaderWriter.read_image. Returned segmentations must have shape 1,x,y,z. Multiple + segmentations are not (yet?) allowed + + If images and segmentations can be read the same way you can just `return self.read_image((image_fname,))` + :param seg_fname: + :return: + 1) a np.ndarray of shape (1, x, y, z) where x, y, z are + the spatial dimensions (set x=1 for 2D! Example: (1, 1, 224, 224) for 2D segmentation). + 2) a dictionary with metadata. This can be anything. BUT it HAS to inclue a {'spacing': (a, b, c)} where a + is the spacing of x, b of y and c of z! If an image doesn't have spacing, just set this to 1. For 2D, set + a=999 (largest spacing value! Make it larger than b and c) + """ + pass + + @abstractmethod + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + """ + Export the predicted segmentation to the desired file format. The given seg array will have the same shape and + orientation as the corresponding image data, so you don't need to do any resampling or whatever. Just save :-) + + properties is the same dictionary you created during read_images/read_seg so you can use the information here + to restore metadata + + IMPORTANT: Segmentations are always 3D! If your input images were 2d then the segmentation will have shape + 1,x,y. You need to catch that and export accordingly (for 2d images you need to convert the 3d segmentation + to 2d via seg = seg[0])! + + :param seg: A segmentation (np.ndarray, integer) of shape (x, y, z). For 2D segmentations this will be (1, y, z)! + :param output_fname: + :param properties: the dictionary that you created in read_images (the ones this segmentation is based on). + Use this to restore metadata + :return: + """ + pass \ No newline at end of file diff --git a/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py b/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py new file mode 100644 index 0000000..6dd7718 --- /dev/null +++ b/nnUNet/nnunetv2/imageio/natural_image_reager_writer.py @@ -0,0 +1,73 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
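+# Illustrative sketch (hypothetical, not part of nnU-Net): a minimal adapter satisfying the BaseReaderWriter
+# contract documented above — read_images returns a (c, x, y, z) float array plus a dict with a mandatory
+# 'spacing' key, read_seg returns (1, x, y, z), and write_seg restores metadata from that dict. The class name
+# and the spacing of 1.0 are assumptions for data without meaningful spacing; kept commented out on purpose.
+#
+# class MyNpyIO(BaseReaderWriter):
+#     supported_file_endings = ['.npy']
+#
+#     def read_images(self, image_fnames):
+#         # each file holds one 3d channel -> prepend a channel axis, then stack to (c, x, y, z)
+#         images = [np.load(f)[None] for f in image_fnames]
+#         return np.vstack(images).astype(np.float32), {'spacing': (1.0, 1.0, 1.0)}
+#
+#     def read_seg(self, seg_fname):
+#         return self.read_images((seg_fname,))
+#
+#     def write_seg(self, seg, output_fname, properties):
+#         # seg arrives in the same geometry as the images it was predicted from, so it can be saved directly
+#         np.save(output_fname, seg.astype(np.uint8))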
+ +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from skimage import io + + +class NaturalImage2DIO(BaseReaderWriter): + """ + ONLY SUPPORTS 2D IMAGES!!! + """ + + # there are surely more we could add here. Everything that can be read by skimage.io should be supported + supported_file_endings = [ + '.png', + # '.jpg', + # '.jpeg', # jpg not supported because we cannot allow lossy compression! segmentation maps! + '.bmp', + '.tif' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + for f in image_fnames: + npy_img = io.imread(f) + if len(npy_img.shape) == 3: + # rgb image, last dimension should be the color channel and the size of that channel should be 3 + # (or 4 if we have alpha) + assert npy_img.shape[-1] == 3 or npy_img.shape[-1] == 4, "If image has three dimensions then the last " \ + "dimension must have shape 3 or 4 " \ + f"(RGB or RGBA). Image shape here is {npy_img.shape}" + # move RGB(A) to front, add additional dim so that we have shape (1, c, X, Y), where c is either 3 or 4 + images.append(npy_img.transpose((2, 0, 1))[:, None]) + elif len(npy_img.shape) == 2: + # grayscale image + images.append(npy_img[None, None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + return np.vstack(images).astype(np.float32), {'spacing': (999, 1, 1)} + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + io.imsave(output_fname, seg[0].astype(np.uint8), check_contrast=False) + + +if __name__ == '__main__': + images = ('/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/imagesTr/img-11_0000.png',) + segmentation = '/media/fabian/data/nnUNet_raw/Dataset120_RoadSegmentation/labelsTr/img-11.png' + imgio = NaturalImage2DIO() + img, props = imgio.read_images(images) + seg, segprops = imgio.read_seg(segmentation) \ No newline at end of file diff --git a/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py b/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py new file mode 100644 index 0000000..e4fa3f5 --- /dev/null +++ b/nnUNet/nnunetv2/imageio/nibabel_reader_writer.py @@ -0,0 +1,204 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union, List +import numpy as np +from nibabel import io_orientation + +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import nibabel + + +class NibabelIO(BaseReaderWriter): + """ + Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be + consistent. 
This is of course considered properly in segmentation export as well. + + IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg! + """ + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + original_affines = [] + + spacings_for_nnunet = [] + for f in image_fnames: + nib_image = nibabel.load(f) + assert len(nib_image.shape) == 3, 'only 3d images are supported by NibabelIO' + original_affine = nib_image.affine + + original_affines.append(original_affine) + + # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) + spacings_for_nnunet.append( + [float(i) for i in nib_image.header.get_zooms()[::-1]] + ) + + # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. + images.append(nib_image.get_fdata().transpose((2, 1, 0))[None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same_array(original_affines): + print('WARNING! Not all input images have the same original_affines!') + print('Affines:') + print(original_affines) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not ' + 'having the same affine') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'nibabel_stuff': { + 'original_affine': original_affines[0], + }, + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # revert transpose + seg = seg.transpose((2, 1, 0)).astype(np.uint8) + seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['original_affine']) + nibabel.save(seg_nib, output_fname) + + +class NibabelIOWithReorient(BaseReaderWriter): + """ + Reorients images to RAS + + Nibabel loads the images in a different order than sitk. We convert the axes to the sitk order to be + consistent. This is of course considered properly in segmentation export as well. + + IMPORTANT: Run nnUNet_plot_dataset_pngs to verify that this did not destroy the alignment of data and seg! 
+ """ + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + original_affines = [] + reoriented_affines = [] + + spacings_for_nnunet = [] + for f in image_fnames: + nib_image = nibabel.load(f) + assert len(nib_image.shape) == 3, 'only 3d images are supported by NibabelIO' + original_affine = nib_image.affine + reoriented_image = nib_image.as_reoriented(io_orientation(original_affine)) + reoriented_affine = reoriented_image.affine + + original_affines.append(original_affine) + reoriented_affines.append(reoriented_affine) + + # spacing is taken in reverse order to be consistent with SimpleITK axis ordering (confusing, I know...) + spacings_for_nnunet.append( + [float(i) for i in reoriented_image.header.get_zooms()[::-1]] + ) + + # transpose image to be consistent with the way SimpleITk reads images. Yeah. Annoying. + images.append(reoriented_image.get_fdata().transpose((2, 1, 0))[None]) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same_array(reoriented_affines): + print('WARNING! Not all input images have the same reoriented_affines!') + print('Affines:') + print(reoriented_affines) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! This might be caused by them not ' + 'having the same affine') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'nibabel_stuff': { + 'original_affine': original_affines[0], + 'reoriented_affine': reoriented_affines[0], + }, + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # revert transpose + seg = seg.transpose((2, 1, 0)).astype(np.uint8) + + seg_nib = nibabel.Nifti1Image(seg, affine=properties['nibabel_stuff']['reoriented_affine']) + seg_nib_reoriented = seg_nib.as_reoriented(io_orientation(properties['nibabel_stuff']['original_affine'])) + assert np.allclose(properties['nibabel_stuff']['original_affine'], seg_nib_reoriented.affine), \ + 'restored affine does not match original affine' + nibabel.save(seg_nib_reoriented, output_fname) + + +if __name__ == '__main__': + img_file = 'patient028_frame01_0000.nii.gz' + seg_file = 'patient028_frame01.nii.gz' + + nibio = NibabelIO() + images, dct = nibio.read_images([img_file]) + seg, dctseg = nibio.read_seg(seg_file) + + nibio_r = NibabelIOWithReorient() + images_r, dct_r = nibio_r.read_images([img_file]) + seg_r, dctseg_r = nibio_r.read_seg(seg_file) + + nibio.write_seg(seg[0], '/home/isensee/seg_nibio.nii.gz', dctseg) + nibio_r.write_seg(seg_r[0], '/home/isensee/seg_nibio_r.nii.gz', dctseg_r) + + s_orig = nibabel.load(seg_file).get_fdata() + s_nibio = nibabel.load('/home/isensee/seg_nibio.nii.gz').get_fdata() + s_nibio_r = 
nibabel.load('/home/isensee/seg_nibio_r.nii.gz').get_fdata() diff --git a/nnUNet/nnunetv2/imageio/reader_writer_registry.py b/nnUNet/nnunetv2/imageio/reader_writer_registry.py new file mode 100644 index 0000000..bdbee5d --- /dev/null +++ b/nnUNet/nnunetv2/imageio/reader_writer_registry.py @@ -0,0 +1,79 @@ +import traceback +from typing import Type + +from batchgenerators.utilities.file_and_folder_operations import join + +import nnunetv2 +from nnunetv2.imageio.natural_image_reager_writer import NaturalImage2DIO +from nnunetv2.imageio.nibabel_reader_writer import NibabelIO, NibabelIOWithReorient +from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO +from nnunetv2.imageio.tif_reader_writer import Tiff3DIO +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + +LIST_OF_IO_CLASSES = [ + NaturalImage2DIO, + SimpleITKIO, + Tiff3DIO, + NibabelIO, + NibabelIOWithReorient +] + + +def determine_reader_writer_from_dataset_json(dataset_json_content: dict, example_file: str = None, + allow_nonmatching_filename: bool = False, verbose: bool = True + ) -> Type[BaseReaderWriter]: + if 'overwrite_image_reader_writer' in dataset_json_content.keys() and \ + dataset_json_content['overwrite_image_reader_writer'] != 'None': + ioclass_name = dataset_json_content['overwrite_image_reader_writer'] + # trying to find that class in the nnunetv2.imageio module + try: + ret = recursive_find_reader_writer_by_name(ioclass_name) + if verbose: print('Using %s reader/writer' % ret) + return ret + except RuntimeError: + if verbose: print('Warning: Unable to find ioclass specified in dataset.json: %s' % ioclass_name) + if verbose: print('Trying to automatically determine desired class') + return determine_reader_writer_from_file_ending(dataset_json_content['file_ending'], example_file, + allow_nonmatching_filename, verbose) + + +def determine_reader_writer_from_file_ending(file_ending: str, example_file: str = None, allow_nonmatching_filename: bool = False, + verbose: bool = True): + for rw in LIST_OF_IO_CLASSES: + if file_ending.lower() in rw.supported_file_endings: + if example_file is not None: + # if an example file is provided, try if we can actually read it. If not move on to the next reader + try: + tmp = rw() + _ = tmp.read_images((example_file,)) + if verbose: print('Using %s as reader/writer' % rw) + return rw + except: + if verbose: print(f'Failed to open file {example_file} with reader {rw}:') + traceback.print_exc() + pass + else: + if verbose: print('Using %s as reader/writer' % rw) + return rw + else: + if allow_nonmatching_filename and example_file is not None: + try: + tmp = rw() + _ = tmp.read_images((example_file,)) + if verbose: print('Using %s as reader/writer' % rw) + return rw + except: + if verbose: print(f'Failed to open file {example_file} with reader {rw}:') + if verbose: traceback.print_exc() + pass + raise RuntimeError("Unable to determine a reader for file ending %s and file %s (file None means no file provided)." % (file_ending, example_file)) + + +def recursive_find_reader_writer_by_name(rw_class_name: str) -> Type[BaseReaderWriter]: + ret = recursive_find_python_class(join(nnunetv2.__path__[0], "imageio"), rw_class_name, 'nnunetv2.imageio') + if ret is None: + raise RuntimeError("Unable to find reader writer class '%s'. Please make sure this class is located in the " + "nnunetv2.imageio module." 
% rw_class_name) + else: + return ret diff --git a/nnUNet/nnunetv2/imageio/readme.md b/nnUNet/nnunetv2/imageio/readme.md new file mode 100644 index 0000000..7819425 --- /dev/null +++ b/nnUNet/nnunetv2/imageio/readme.md @@ -0,0 +1,7 @@ +- Derive your adapter from `BaseReaderWriter`. +- Reimplement all abstractmethods. +- make sure to support 2d and 3d input images (or raise some error). +- place it in this folder or nnU-Net won't find it! +- add it to LIST_OF_IO_CLASSES in `reader_writer_registry.py` + +Bam, you're done! \ No newline at end of file diff --git a/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py b/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py new file mode 100644 index 0000000..2b9b168 --- /dev/null +++ b/nnUNet/nnunetv2/imageio/simpleitk_reader_writer.py @@ -0,0 +1,129 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import SimpleITK as sitk + + +class SimpleITKIO(BaseReaderWriter): + supported_file_endings = [ + '.nii.gz', + '.nrrd', + '.mha' + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + images = [] + spacings = [] + origins = [] + directions = [] + + spacings_for_nnunet = [] + for f in image_fnames: + itk_image = sitk.ReadImage(f) + spacings.append(itk_image.GetSpacing()) + origins.append(itk_image.GetOrigin()) + directions.append(itk_image.GetDirection()) + npy_image = sitk.GetArrayFromImage(itk_image) + if len(npy_image.shape) == 2: + # 2d + npy_image = npy_image[None, None] + max_spacing = max(spacings[-1]) + spacings_for_nnunet.append((max_spacing * 999, *list(spacings[-1])[::-1])) + elif len(npy_image.shape) == 3: + # 3d, as in original nnunet + npy_image = npy_image[None] + spacings_for_nnunet.append(list(spacings[-1])[::-1]) + elif len(npy_image.shape) == 4: + # 4d, multiple modalities in one file + spacings_for_nnunet.append(list(spacings[-1])[::-1][1:]) + pass + else: + raise RuntimeError("Unexpected number of dimensions: %d in file %s" % (len(npy_image.shape), f)) + + images.append(npy_image) + spacings_for_nnunet[-1] = list(np.abs(spacings_for_nnunet[-1])) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same(spacings): + print('ERROR! Not all input images have the same spacing!') + print('Spacings:') + print(spacings) + print('Image files:') + print(image_fnames) + raise RuntimeError() + if not self._check_all_same(origins): + print('WARNING! 
Not all input images have the same origin!') + print('Origins:') + print(origins) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(directions): + print('WARNING! Not all input images have the same direction!') + print('Directions:') + print(directions) + print('Image files:') + print(image_fnames) + print('It is up to you to decide whether that\'s a problem. You should run nnUNet_plot_dataset_pngs to verify ' + 'that segmentations and data overlap.') + if not self._check_all_same(spacings_for_nnunet): + print('ERROR! Not all input images have the same spacing_for_nnunet! (This should not happen and must be a ' + 'bug. Please report!') + print('spacings_for_nnunet:') + print(spacings_for_nnunet) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + stacked_images = np.vstack(images) + dict = { + 'sitk_stuff': { + # this saves the sitk geometry information. This part is NOT used by nnU-Net! + 'spacing': spacings[0], + 'origin': origins[0], + 'direction': directions[0] + }, + # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order lol. Image arrays + # are returned x,y,z but spacing is returned z,y,x. Duh. + 'spacing': spacings_for_nnunet[0] + } + return stacked_images.astype(np.float32), dict + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + return self.read_images((seg_fname, )) + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + assert len(seg.shape) == 3, 'segmentation must be 3d. If you are exporting a 2d segmentation, please provide it as shape 1,x,y' + output_dimension = len(properties['sitk_stuff']['spacing']) + assert 1 < output_dimension < 4 + if output_dimension == 2: + seg = seg[0] + + itk_image = sitk.GetImageFromArray(seg.astype(np.uint8)) + itk_image.SetSpacing(properties['sitk_stuff']['spacing']) + itk_image.SetOrigin(properties['sitk_stuff']['origin']) + itk_image.SetDirection(properties['sitk_stuff']['direction']) + + sitk.WriteImage(itk_image, output_fname) diff --git a/nnUNet/nnunetv2/imageio/tif_reader_writer.py b/nnUNet/nnunetv2/imageio/tif_reader_writer.py new file mode 100644 index 0000000..0aa5ff3 --- /dev/null +++ b/nnUNet/nnunetv2/imageio/tif_reader_writer.py @@ -0,0 +1,100 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from typing import Tuple, Union, List +import numpy as np +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +import tifffile +from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_json, split_path, join + + +class Tiff3DIO(BaseReaderWriter): + """ + reads and writes 3D tif(f) images. Uses tifffile package. Ignores metadata (for now)! 
+ + If you have 2D tiffs, use NaturalImage2DIO + + Supports the use of auxiliary files for spacing information. If used, the auxiliary files are expected to end + with .json and omit the channel identifier. So, for example, the corresponding of image image1_0000.tif is + expected to be image1.json)! + """ + supported_file_endings = [ + '.tif', + '.tiff', + ] + + def read_images(self, image_fnames: Union[List[str], Tuple[str, ...]]) -> Tuple[np.ndarray, dict]: + # figure out file ending used here + ending = '.' + image_fnames[0].split('.')[-1] + assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' + ending_length = len(ending) + truncate_length = ending_length + 5 # 5 comes from len(_0000) + + images = [] + for f in image_fnames: + image = tifffile.imread(f) + if len(image.shape) != 3: + raise RuntimeError("Only 3D images are supported! File: %s" % f) + images.append(image[None]) + + # see if aux file can be found + expected_aux_file = image_fnames[0][:-truncate_length] + '.json' + if isfile(expected_aux_file): + spacing = load_json(expected_aux_file)['spacing'] + assert len(spacing) == 3, 'spacing must have 3 entries, one for each dimension of the image. File: %s' % expected_aux_file + else: + print(f'WARNING no spacing file found for images {image_fnames}\nAssuming spacing (1, 1, 1).') + spacing = (1, 1, 1) + + if not self._check_all_same([i.shape for i in images]): + print('ERROR! Not all input images have the same shape!') + print('Shapes:') + print([i.shape for i in images]) + print('Image files:') + print(image_fnames) + raise RuntimeError() + + return np.vstack(images).astype(np.float32), {'spacing': spacing} + + def write_seg(self, seg: np.ndarray, output_fname: str, properties: dict) -> None: + # not ideal but I really have no clue how to set spacing/resolution information properly in tif files haha + tifffile.imwrite(output_fname, data=seg.astype(np.uint8), compression='zlib') + file = os.path.basename(output_fname) + out_dir = os.path.dirname(output_fname) + ending = file.split('.')[-1] + save_json({'spacing': properties['spacing']}, join(out_dir, file[:-(len(ending) + 1)] + '.json')) + + def read_seg(self, seg_fname: str) -> Tuple[np.ndarray, dict]: + # figure out file ending used here + ending = '.' + seg_fname.split('.')[-1] + assert ending.lower() in self.supported_file_endings, f'Ending {ending} not supported by {self.__class__.__name__}' + ending_length = len(ending) + + seg = tifffile.imread(seg_fname) + if len(seg.shape) != 3: + raise RuntimeError(f"Only 3D images are supported! File: {seg_fname}") + seg = seg[None] + + # see if aux file can be found + expected_aux_file = seg_fname[:-ending_length] + '.json' + if isfile(expected_aux_file): + spacing = load_json(expected_aux_file)['spacing'] + assert len(spacing) == 3, 'spacing must have 3 entries, one for each dimension of the image. 
File: %s' % expected_aux_file + assert all([i > 0 for i in spacing]), f"Spacing must be > 0, spacing: {spacing}" + else: + print(f'WARNING no spacing file found for segmentation {seg_fname}\nAssuming spacing (1, 1, 1).') + spacing = (1, 1, 1) + + return seg.astype(np.float32), {'spacing': spacing} \ No newline at end of file diff --git a/nnUNet/nnunetv2/inference/__init__.py b/nnUNet/nnunetv2/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/inference/data_iterators.py b/nnUNet/nnunetv2/inference/data_iterators.py new file mode 100644 index 0000000..37d7f07 --- /dev/null +++ b/nnUNet/nnunetv2/inference/data_iterators.py @@ -0,0 +1,318 @@ +import multiprocessing +import queue +from torch.multiprocessing import Event, Process, Queue, Manager + +from time import sleep +from typing import Union, List + +import numpy as np +import torch +from batchgenerators.dataloading.data_loader import DataLoader + +from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + target_queue: Queue, + done_event: Event, + abort_event: Event, + verbose: bool = False): + try: + label_manager = plans_manager.get_label_manager(dataset_json) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + for idx in range(len(list_of_lists)): + data, seg, data_properites = preprocessor.run_case(list_of_lists[idx], + list_of_segs_from_prev_stage_files[ + idx] if list_of_segs_from_prev_stage_files is not None else None, + plans_manager, + configuration_manager, + dataset_json) + if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data).contiguous().float() + + item = {'data': data, 'data_properites': data_properites, + 'ofile': output_filenames_truncated[idx] if output_filenames_truncated is not None else None} + success = False + while not success: + try: + if abort_event.is_set(): + return + target_queue.put(item, timeout=0.01) + success = True + except queue.Full: + pass + done_event.set() + except Exception as e: + abort_event.set() + raise e + + +def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_processes: int, + pin_memory: bool = False, + verbose: bool = False): + context = multiprocessing.get_context('spawn') + manager = Manager() + num_processes = min(len(list_of_lists), num_processes) + assert num_processes >= 1 + processes = [] + done_events = [] + target_queues = [] + abort_event = manager.Event() + for i in range(num_processes): + event = manager.Event() + queue = Manager().Queue(maxsize=1) + pr = context.Process(target=preprocess_fromfiles_save_to_queue, + args=( + list_of_lists[i::num_processes], + 
list_of_segs_from_prev_stage_files[ + i::num_processes] if list_of_segs_from_prev_stage_files is not None else None, + output_filenames_truncated[ + i::num_processes] if output_filenames_truncated is not None else None, + plans_manager, + dataset_json, + configuration_manager, + queue, + event, + abort_event, + verbose + ), daemon=True) + pr.start() + target_queues.append(queue) + done_events.append(event) + processes.append(pr) + + worker_ctr = 0 + # print(f"Type: {type(target_queues[worker_ctr])}") + while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): + # print(type(target_queues[worker_ctr])) + if not target_queues[worker_ctr].empty(): + item = target_queues[worker_ctr].get() + worker_ctr = (worker_ctr + 1) % num_processes + else: + all_ok = all( + [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() + if not all_ok: + raise RuntimeError('Background workers died. Look for the error message further up! If there is ' + 'none then your RAM was full and the worker was killed by the OS. Use fewer ' + 'workers or get more RAM in that case!') + sleep(0.01) + continue + if pin_memory: + [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] + yield item + [p.join() for p in processes] + +class PreprocessAdapter(DataLoader): + def __init__(self, list_of_lists: List[List[str]], + list_of_segs_from_prev_stage_files: Union[None, List[str]], + preprocessor: DefaultPreprocessor, + output_filenames_truncated: Union[None, List[str]], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_threads_in_multithreaded: int = 1): + self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json = \ + preprocessor, plans_manager, configuration_manager, dataset_json + + self.label_manager = plans_manager.get_label_manager(dataset_json) + + if list_of_segs_from_prev_stage_files is None: + list_of_segs_from_prev_stage_files = [None] * len(list_of_lists) + if output_filenames_truncated is None: + output_filenames_truncated = [None] * len(list_of_lists) + + super().__init__(list(zip(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated)), + 1, num_threads_in_multithreaded, + seed_for_shuffle=1, return_incomplete=True, + shuffle=False, infinite=False, sampling_probabilities=None) + + self.indices = list(range(len(list_of_lists))) + + def generate_train_batch(self): + idx = self.get_indices()[0] + files = self._data[idx][0] + seg_prev_stage = self._data[idx][1] + ofile = self._data[idx][2] + # if we have a segmentation from the previous stage we have to process it together with the images so that we + # can crop it appropriately (if needed). 
Otherwise it would just be resized to the shape of the data after + # preprocessing and then there might be misalignments + data, seg, data_properites = self.preprocessor.run_case(files, seg_prev_stage, self.plans_manager, + self.configuration_manager, + self.dataset_json) + if seg_prev_stage is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data) + + return {'data': data, 'data_properites': data_properites, 'ofile': ofile} + + +class PreprocessAdapterFromNpy(DataLoader): + def __init__(self, list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, dataset_json: dict, configuration_manager: ConfigurationManager, + num_threads_in_multithreaded: int = 1, verbose: bool = False): + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_ofnames = \ + preprocessor, plans_manager, configuration_manager, dataset_json, truncated_ofnames + + self.label_manager = plans_manager.get_label_manager(dataset_json) + + if list_of_segs_from_prev_stage is None: + list_of_segs_from_prev_stage = [None] * len(list_of_images) + if truncated_ofnames is None: + truncated_ofnames = [None] * len(list_of_images) + + super().__init__( + list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames)), + 1, num_threads_in_multithreaded, + seed_for_shuffle=1, return_incomplete=True, + shuffle=False, infinite=False, sampling_probabilities=None) + + self.indices = list(range(len(list_of_images))) + + def generate_train_batch(self): + idx = self.get_indices()[0] + image = self._data[idx][0] + seg_prev_stage = self._data[idx][1] + props = self._data[idx][2] + ofname = self._data[idx][3] + # if we have a segmentation from the previous stage we have to process it together with the images so that we + # can crop it appropriately (if needed). 
Otherwise it would just be resized to the shape of the data after + # preprocessing and then there might be misalignments + data, seg = self.preprocessor.run_case_npy(image, seg_prev_stage, props, + self.plans_manager, + self.configuration_manager, + self.dataset_json) + if seg_prev_stage is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], self.label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data) + + return {'data': data, 'data_properites': props, 'ofile': ofname} + + +def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + target_queue: Queue, + done_event: Event, + abort_event: Event, + verbose: bool = False): + try: + label_manager = plans_manager.get_label_manager(dataset_json) + preprocessor = configuration_manager.preprocessor_class(verbose=verbose) + for idx in range(len(list_of_images)): + data, seg = preprocessor.run_case_npy(list_of_images[idx], + list_of_segs_from_prev_stage[ + idx] if list_of_segs_from_prev_stage is not None else None, + list_of_image_properties[idx], + plans_manager, + configuration_manager, + dataset_json) + if list_of_segs_from_prev_stage is not None and list_of_segs_from_prev_stage[idx] is not None: + seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype) + data = np.vstack((data, seg_onehot)) + + data = torch.from_numpy(data).contiguous().float() + + item = {'data': data, 'data_properites': list_of_image_properties[idx], + 'ofile': truncated_ofnames[idx] if truncated_ofnames is not None else None} + success = False + while not success: + try: + if abort_event.is_set(): + return + target_queue.put(item, timeout=0.01) + success = True + except queue.Full: + pass + done_event.set() + except Exception as e: + abort_event.set() + raise e + + +def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray], + list_of_segs_from_prev_stage: Union[List[np.ndarray], None], + list_of_image_properties: List[dict], + truncated_ofnames: Union[List[str], None], + plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_processes: int, + pin_memory: bool = False, + verbose: bool = False): + context = multiprocessing.get_context('spawn') + manager = Manager() + num_processes = min(len(list_of_images), num_processes) + assert num_processes >= 1 + target_queues = [] + processes = [] + done_events = [] + abort_event = manager.Event() + for i in range(num_processes): + event = manager.Event() + queue = manager.Queue(maxsize=1) + pr = context.Process(target=preprocess_fromnpy_save_to_queue, + args=( + list_of_images[i::num_processes], + list_of_segs_from_prev_stage[ + i::num_processes] if list_of_segs_from_prev_stage is not None else None, + list_of_image_properties[i::num_processes], + truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None, + plans_manager, + dataset_json, + configuration_manager, + queue, + event, + abort_event, + verbose + ), daemon=True) + pr.start() + done_events.append(event) + processes.append(pr) + target_queues.append(queue) + + worker_ctr = 0 + while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()): + if not target_queues[worker_ctr].empty(): + item = 
target_queues[worker_ctr].get() + worker_ctr = (worker_ctr + 1) % num_processes + else: + all_ok = all( + [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set() + if not all_ok: + raise RuntimeError('Background workers died. Look for the error message further up! If there is ' + 'none then your RAM was full and the worker was killed by the OS. Use fewer ' + 'workers or get more RAM in that case!') + sleep(0.01) + continue + if pin_memory: + [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)] + yield item + [p.join() for p in processes] diff --git a/nnUNet/nnunetv2/inference/examples.py b/nnUNet/nnunetv2/inference/examples.py new file mode 100644 index 0000000..8e8f264 --- /dev/null +++ b/nnUNet/nnunetv2/inference/examples.py @@ -0,0 +1,102 @@ +if __name__ == '__main__': + from nnunetv2.paths import nnUNet_results, nnUNet_raw + import torch + from batchgenerators.utilities.file_and_folder_operations import join + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + + # nnUNetv2_predict -d 3 -f 0 -c 3d_lowres -i imagesTs -o imagesTs_predlowres --continue_prediction + + # instantiate the nnUNetPredictor + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_gpu=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + # initializes the network architecture, loads the checkpoint + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0,), + checkpoint_name='checkpoint_final.pth', + ) + # variant 1: give input and output folders + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # variant 2, use list of files as inputs. Note how we use nested lists!!! 
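+    # Each inner list holds all channel files of one case (for a hypothetical two-channel dataset this would
+    # look like ['case_0000.nii.gz', 'case_0001.nii.gz']). Dataset003_Liver is CT-only, so each case below is
+    # a single _0000 file.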
+ indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres') + predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + [join(outdir, 'liver_152.nii.gz'), + join(outdir, 'liver_142.nii.gz')], + save_probabilities=False, overwrite=True, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # variant 2.5, returns segmentations + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + None, + save_probabilities=True, overwrite=True, + num_processes_preprocessing=2, + num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, + part_id=0) + + # predict several npy images + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + # we do not set output files so that the segmentations will be returned. You can of course also specify output + # files instead (no return value on that case) + ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4], + None, + [props, props2, props3, props4], + None, 2, save_probabilities=False, + num_processes_segmentation_export=2) + + # predict a single numpy array + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, True) + + # custom iterator + + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')]) + img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')]) + img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')]) + img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')]) + + + # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys! + # If 'ofile' is None, the result will be returned instead of written to a file + # the iterator is responsible for performing the correct preprocessing! + # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread! + # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays + # (they both use predictor.predict_from_data_iterator) for inspiration! 
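+    # Note: preprocessor.run_case_npy expects the raw (c, x, y, z) array and the properties dict exactly as
+    # produced by the reader (SimpleITKIO above), which is why the img/props pairs are passed through unchanged.
+    # pin_memory() on the yielded tensor only pays off when inference actually runs on a CUDA device.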
+ def my_iterator(list_of_input_arrs, list_of_input_props): + preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) + for a, p in zip(list_of_input_arrs, list_of_input_props): + data, seg = preprocessor.run_case_npy(a, + None, + p, + predictor.plans_manager, + predictor.configuration_manager, + predictor.dataset_json) + yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properites': p, 'ofile': None} + + + ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), + save_probabilities=False, num_processes_segmentation_export=3) diff --git a/nnUNet/nnunetv2/inference/export_prediction.py b/nnUNet/nnunetv2/inference/export_prediction.py new file mode 100644 index 0000000..418a73d --- /dev/null +++ b/nnUNet/nnunetv2/inference/export_prediction.py @@ -0,0 +1,168 @@ +import os +from copy import deepcopy +from typing import Union, List + +from skimage.transform import resize +import numpy as np +import torch +from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice +from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle + +from nnunetv2.configuration import default_num_processes +from nnunetv2.utilities.label_handling.label_handling import LabelManager +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray], + plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + label_manager: LabelManager, + properties_dict: dict, + return_probabilities: bool = False, + num_threads_torch: int = default_num_processes): + # This makes the docker hang. Why, idk + #old_threads = torch.get_num_threads() + #torch.set_num_threads(num_threads_torch) + + # resample to original shape + current_spacing = configuration_manager.spacing if \ + len(configuration_manager.spacing) == \ + len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + predicted_logits = predicted_logits.cpu() + + # Ok so here the idea is that since output tensors will have 31 channels (30 classes + background), + # combined to the huge X and Y dimensions caused by the small median spacing, this makes for huge tensors, + # barely fitting in memory. + # So, instead of following the 'nonlin -> argmax -> resampling' scheme + # (which implies creation of np array along the way, like we have that kind of memory lying around?!); + # Instead we do 'argmax -> resampling' (who needs nonlin anyway? (unless you want pseudo-probabilities)) + # This divides the size in memory by 62 (!). Indeed, we go from 31 channels to 1 and go fromm float16 to uint8. 
+ # Results should be "fairly" similar + # segmentation = predicted_logits.squeeze() # Old hack for memory reduction + segmentation = torch.argmax(predicted_logits, 0, keepdim=False) + # If you get to this point, memory should not be a problem going forward + segmentation = segmentation.numpy().astype(np.uint8) + + # Copied the resample_data_or_seg function + segmentation = resize(segmentation, properties_dict['shape_after_cropping_and_before_resampling'], 0, + mode="edge", clip=True, anti_aliasing=False) + + #segmentation = configuration_manager.resampling_fn_seg(predicted_logits, + # properties_dict['shape_after_cropping_and_before_resampling'], + # current_spacing, + # properties_dict['spacing']) + # return value of resampling_fn_probabilities can be ndarray or Tensor but that doesnt matter because + # apply_inference_nonlin will covnert to torch + #predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits_r) + #segmentation = predicted_logits_r # label_manager.convert_probabilities_to_segmentation(predicted_probabilities) + + # segmentation may be torch.Tensor but we continue with numpy + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + + # put segmentation in bbox (revert cropping) + segmentation_reverted_cropping = np.zeros(properties_dict['shape_before_cropping'], + dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) + slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping']) + segmentation_reverted_cropping[slicer] = segmentation + del segmentation + + # revert transpose + segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward) + if return_probabilities: # We don't. So u should never set it to True + # revert cropping + predicted_probabilities = label_manager.revert_cropping_on_probabilities(predicted_probabilities, + properties_dict[ + 'bbox_used_for_cropping'], + properties_dict[ + 'shape_before_cropping']) + predicted_probabilities = predicted_probabilities.cpu().numpy() + # revert transpose + predicted_probabilities = predicted_probabilities.transpose([0] + [i + 1 for i in + plans_manager.transpose_backward]) + # Same reason as above, it makes the docker hang + # torch.set_num_threads(old_threads) + return segmentation_reverted_cropping, predicted_probabilities + else: + # torch.set_num_threads(old_threads) + return segmentation_reverted_cropping + + +def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict, + configuration_manager: ConfigurationManager, + plans_manager: PlansManager, + dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str, + save_probabilities: bool = False): + # if isinstance(predicted_array_or_file, str): + # tmp = deepcopy(predicted_array_or_file) + # if predicted_array_or_file.endswith('.npy'): + # predicted_array_or_file = np.load(predicted_array_or_file) + # elif predicted_array_or_file.endswith('.npz'): + # predicted_array_or_file = np.load(predicted_array_or_file)['softmax'] + # os.remove(tmp) + + if isinstance(dataset_json_dict_or_file, str): + dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) + + label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) + print("starting conversion", flush=True) + ret = convert_predicted_logits_to_segmentation_with_correct_shape( + predicted_array_or_file, plans_manager, configuration_manager, label_manager, properties_dict, + 
return_probabilities=save_probabilities + ) + del predicted_array_or_file + + # save + if save_probabilities: + segmentation_final, probabilities_final = ret + np.savez_compressed(output_file_truncated + '.npz', probabilities=probabilities_final) + save_pickle(properties_dict, output_file_truncated + '.pkl') + del probabilities_final, ret + else: + segmentation_final = ret + del ret + + rw = plans_manager.image_reader_writer_class() + print("saving", flush=True) + rw.write_seg(segmentation_final, output_file_truncated + dataset_json_dict_or_file['file_ending'], + properties_dict) + + +def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict, + dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \ + -> None: + # # needed for cascade + # if isinstance(predicted, str): + # assert isfile(predicted), "If isinstance(segmentation_softmax, str) then " \ + # "isfile(segmentation_softmax) must be True" + # del_file = deepcopy(predicted) + # predicted = np.load(predicted) + # os.remove(del_file) + old_threads = torch.get_num_threads() + torch.set_num_threads(num_threads_torch) + + if isinstance(dataset_json_dict_or_file, str): + dataset_json_dict_or_file = load_json(dataset_json_dict_or_file) + + # resample to original shape + current_spacing = configuration_manager.spacing if \ + len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + target_spacing = configuration_manager.spacing if len(configuration_manager.spacing) == \ + len(properties_dict['shape_after_cropping_and_before_resampling']) else \ + [properties_dict['spacing'][0], *configuration_manager.spacing] + predicted_array_or_file = configuration_manager.resampling_fn_probabilities(predicted, + target_shape, + current_spacing, + target_spacing) + + # create segmentation (argmax, regions, etc) + label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file) + segmentation = label_manager.convert_logits_to_segmentation(predicted_array_or_file) + # segmentation may be torch.Tensor but we continue with numpy + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + np.savez_compressed(output_file, seg=segmentation.astype(np.uint8)) + torch.set_num_threads(old_threads) diff --git a/nnUNet/nnunetv2/inference/predict_from_raw_data.py b/nnUNet/nnunetv2/inference/predict_from_raw_data.py new file mode 100644 index 0000000..299e61a --- /dev/null +++ b/nnUNet/nnunetv2/inference/predict_from_raw_data.py @@ -0,0 +1,938 @@ +import inspect +import multiprocessing +import os +import traceback +from copy import deepcopy +from time import sleep +from typing import Tuple, Union, List, Optional + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.padding import pad_nd_image +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \ + save_json +from torch import nn +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm + +import nnunetv2 +from nnunetv2.configuration import default_num_processes +from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, 
preprocessing_iterator_fromfiles, \ + preprocessing_iterator_fromnpy +from nnunetv2.inference.export_prediction import export_prediction_from_logits, \ + convert_predicted_logits_to_segmentation_with_correct_shape +from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ + compute_steps_for_sliding_window +from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder + + +class nnUNetPredictor(object): + def __init__(self, + tile_step_size: float = 0.5, + use_gaussian: bool = True, + use_mirroring: bool = True, + perform_everything_on_gpu: bool = True, + device: torch.device = torch.device('cuda'), + verbose: bool = False, + verbose_preprocessing: bool = False, + allow_tqdm: bool = True): + self.verbose = verbose + self.verbose_preprocessing = verbose_preprocessing + self.allow_tqdm = allow_tqdm + + self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \ + self.trainer_name, self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None + + self.tile_step_size = tile_step_size + self.use_gaussian = use_gaussian + self.use_mirroring = use_mirroring + if device.type == 'cuda': + # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES! + # why would I ever want to do that. Stupid dobby. This kills DDP inference... + pass + if device.type != 'cuda': + print(f'perform_everything_on_gpu=True is only supported for cuda devices! 
Setting this to False') + perform_everything_on_gpu = False + self.device = device + self.perform_everything_on_gpu = perform_everything_on_gpu + + def initialize_from_trained_model_folder(self, model_training_output_dir: str, + use_folds: Union[Tuple[Union[int, str]], None], + checkpoint_name: str = 'checkpoint_final.pth'): + """ + This is used when making predictions with a trained model + """ + if use_folds is None: + use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name) + + dataset_json = load_json(join(model_training_output_dir, 'dataset.json')) + plans = load_json(join(model_training_output_dir, 'plans.json')) + plans_manager = PlansManager(plans) + + if isinstance(use_folds, str): + use_folds = [use_folds] + + parameters = [] + for i, f in enumerate(use_folds): + f = int(f) if f != 'all' else f + checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name), + map_location=torch.device('cpu')) + checkpoint_name = join(model_training_output_dir, f'fold_{f}', checkpoint_name) + + if i == 0: + trainer_name = checkpoint['trainer_name'] + configuration_name = checkpoint['init_args']['configuration'] + inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \ + 'inference_allowed_mirroring_axes' in checkpoint.keys() else None + del checkpoint + # parameters.append(checkpoint['network_weights']) + parameters.append(checkpoint_name) + + configuration_manager = plans_manager.get_configuration(configuration_name) + # restore network + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, 'nnunetv2.training.nnUNetTrainer') + network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager, + num_input_channels, enable_deep_supervision=False) + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ + and not isinstance(self.network, OptimizedModule): + print('compiling network') + self.network = torch.compile(self.network) + + def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, + configuration_manager: ConfigurationManager, parameters: Optional[List[dict]], + dataset_json: dict, trainer_name: str, + inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]): + """ + This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation + """ + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + allow_compile = True + allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) + if 
isinstance(self.network, DistributedDataParallel): + allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) + if allow_compile: + print('compiling network') + self.network = torch.compile(self.network) + + @staticmethod + def auto_detect_available_folds(model_training_output_dir, checkpoint_name): + print('use_folds is None, attempting to auto detect available folds') + fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False) + fold_folders = [i for i in fold_folders if i != 'fold_all'] + fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))] + use_folds = [int(i.split('_')[-1]) for i in fold_folders] + print(f'found the following folds: {use_folds}') + return use_folds + + def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]], + folder_with_segs_from_prev_stage: str = None, + overwrite: bool = True, + part_id: int = 0, + num_parts: int = 1, + save_probabilities: bool = False): + if isinstance(list_of_lists_or_source_folder, str): + list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder, + self.dataset_json['file_ending']) + print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder') + list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts] + caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in + list_of_lists_or_source_folder] + print( + f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)') + print(f'There are {len(caseids)} cases that I would like to predict') + + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids] + else: + output_filename_truncated = output_folder_or_list_of_truncated_output_files + + seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if + folder_with_segs_from_prev_stage is not None else None for i in caseids] + # remove already predicted files form the lists + if not overwrite and output_filename_truncated is not None: + tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated] + if save_probabilities: + tmp2 = [isfile(i + '.npz') for i in output_filename_truncated] + tmp = [i and j for i, j in zip(tmp, tmp2)] + not_existing_indices = [i for i, j in enumerate(tmp) if not j] + + output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices] + list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices] + seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices] + print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. 
' + f'That\'s {len(not_existing_indices)} cases.') + return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files + + def predict_from_files(self, + list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]], + save_probabilities: bool = False, + overwrite: bool = True, + num_processes_preprocessing: int = default_num_processes, + num_processes_segmentation_export: int = default_num_processes, + folder_with_segs_from_prev_stage: str = None, + num_parts: int = 1, + part_id: int = 0): + """ + This is nnU-Net's default function for making predictions. It works best for batch predictions + (predicting many images at once). + """ + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_folder = output_folder_or_list_of_truncated_output_files + elif isinstance(output_folder_or_list_of_truncated_output_files, list): + output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0]) + else: + output_folder = None + + ######################## + # let's store the input arguments so that its clear what was used to generate the prediction + if output_folder is not None: + my_init_kwargs = {} + for k in inspect.signature(self.predict_from_files).parameters.keys(): + my_init_kwargs[k] = locals()[k] + my_init_kwargs = deepcopy( + my_init_kwargs) # let's not unintentionally change anything in-place. Take this as a + recursive_fix_for_json_export(my_init_kwargs) + maybe_mkdir_p(output_folder) + save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json')) + + # we need these two if we want to do things with the predictions like for example apply postprocessing + save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) + save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False) + ####################### + + # check if we need a prediction from the previous stage + if self.configuration_manager.previous_stage_name is not None: + assert folder_with_segs_from_prev_stage is not None, \ + f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \ + f'stage ({self.configuration_manager.previous_stage_name}) as input. 
Please provide the folder where' \ + f' they are located via folder_with_segs_from_prev_stage' + + # sort out input and output filenames + list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \ + self._manage_input_and_output_lists(list_of_lists_or_source_folder, + output_folder_or_list_of_truncated_output_files, + folder_with_segs_from_prev_stage, overwrite, part_id, num_parts, + save_probabilities) + if len(list_of_lists_or_source_folder) == 0: + return + + data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder, + seg_from_prev_stage_files, + output_filename_truncated, + num_processes_preprocessing) + + return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export) + + def _internal_get_data_iterator_from_lists_of_filenames(self, + input_list_of_lists: List[List[str]], + seg_from_prev_stage_files: Union[List[str], None], + output_filenames_truncated: Union[List[str], None], + num_processes: int): + return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files, + output_filenames_truncated, self.plans_manager, self.dataset_json, + self.configuration_manager, num_processes, self.device.type == 'cuda', + self.verbose_preprocessing) + # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing) + # # hijack batchgenerators, yo + # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This + # # way we don't have to reinvent the wheel here. + # num_processes = max(1, min(num_processes, len(input_list_of_lists))) + # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor, + # output_filenames_truncated, self.plans_manager, self.dataset_json, + # self.configuration_manager, num_processes) + # if num_processes == 0: + # mta = SingleThreadedAugmenter(ppa, None) + # else: + # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory) + # return mta + + def get_data_iterator_from_raw_npy_data(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3): + + list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \ + image_or_list_of_images + + if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray): + segs_from_prev_stage_or_list_of_segs_from_prev_stage = [ + segs_from_prev_stage_or_list_of_segs_from_prev_stage] + + if isinstance(truncated_ofname, str): + truncated_ofname = [truncated_ofname] + + if isinstance(properties_or_list_of_properties, dict): + properties_or_list_of_properties = [properties_or_list_of_properties] + + num_processes = min(num_processes, len(list_of_images)) + pp = preprocessing_iterator_fromnpy( + list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + self.plans_manager, + self.dataset_json, + self.configuration_manager, + num_processes, + self.device.type == 'cuda', + self.verbose_preprocessing + ) + + return pp + + def predict_from_list_of_npy_arrays(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + 
np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + num_processes) + return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export) + + def predict_from_data_iterator(self, + data_iterator, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + """ + each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys! + If 'ofile' is None, the result will be returned instead of written to a file + """ + with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool: + worker_list = [i for i in export_pool._pool] + r = [] + for preprocessed in data_iterator: + data = preprocessed['data'] + if isinstance(data, str): + delfile = data + data = torch.from_numpy(np.load(data)) + os.remove(delfile) + + ofile = preprocessed['ofile'] + if ofile is not None: + print(f'\nPredicting {os.path.basename(ofile)}:') + else: + print(f'\nPredicting image of shape {data.shape}:') + + print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}') + + properties = preprocessed['data_properites'] + + # let's not get into a runaway situation where the GPU predicts so fast that the disk has to b swamped with + # npy files + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + while not proceed: + # print('sleeping') + sleep(0.1) + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + + prediction = self.predict_logits_from_preprocessed_data(data).cpu() + + if ofile is not None: + # this needs to go into background processes + # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager, + # dataset_json, ofile, save_probabilities) + print('sending off prediction to background worker for resampling and export') + r.append( + export_pool.starmap_async( + export_prediction_from_logits, + ((prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, ofile, save_probabilities),) + ) + ) + else: + # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager, + # configuration_manager, label_manager, + # properties, + # save_probabilities) + print('sending off prediction to background worker for resampling') + r.append( + export_pool.starmap_async( + convert_predicted_logits_to_segmentation_with_correct_shape, ( + (prediction, self.plans_manager, + self.configuration_manager, self.label_manager, + properties, + save_probabilities),) + ) + ) + if ofile is not None: + print(f'done with {os.path.basename(ofile)}') + else: + print(f'\nDone with image of shape {data.shape}:') + ret = [i.get()[0] for i in r] + + if isinstance(data_iterator, MultiThreadedAugmenter): + data_iterator._finish() + + # clear lru cache + compute_gaussian.cache_clear() + # clear device cache + empty_cache(self.device) + return ret + + def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict, + segmentation_previous_stage: np.ndarray = None, + output_file_truncated: str = None, + 
save_or_return_probabilities: bool = False): + """ + image_properties must only have a 'spacing' key! + """ + ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties], + [output_file_truncated], + self.plans_manager, self.dataset_json, self.configuration_manager, + num_threads_in_multithreaded=1, verbose=self.verbose) + if self.verbose: + print('preprocessing') + dct = next(ppa) + + if self.verbose: + print('predicting') + predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']) + print(predicted_logits.dtype) + if self.verbose: + print('resampling to original shape') + if output_file_truncated is not None: + print("Starting export", flush = True) + export_prediction_from_logits(predicted_logits, dct['data_properites'], self.configuration_manager, + self.plans_manager, self.dataset_json, output_file_truncated, + save_or_return_probabilities) + else: + del input_image + ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager, + self.configuration_manager, + self.label_manager, + dct['data_properites'], + return_probabilities= + save_or_return_probabilities) + if save_or_return_probabilities: + return ret[0], ret[1] + else: + return ret + + def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor: + """ + IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON + TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE! + + RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. + SEE convert_predicted_logits_to_segmentation_with_correct_shape + """ + # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as + # default and not have the entire program crash in case of GPU out of memory. Neat. That should make + # things a lot faster for some datasets. + original_perform_everything_on_gpu = self.perform_everything_on_gpu + with torch.no_grad(): + prediction = None + if self.perform_everything_on_gpu: + try: + for params in self.list_of_parameters: + + # messing with state dict names... + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(torch.load(params, + map_location=torch.device('cpu'))["network_weights"]) + else: + self.network._orig_mod.load_state_dict(torch.load(params, + map_location=torch.device('cpu'))["network_weights"]) + + if prediction is None: + #prediction = torch.argmax(self.predict_sliding_window_return_logits(data), 0, + # keepdim=True).type(torch.uint8) + prediction = self.predict_sliding_window_return_logits(data) + else: + #prediction = torch.cat([prediction, + # torch.argmax(self.predict_sliding_window_return_logits(data), 0, + # keepdim=True).type(torch.uint8)]) + prediction += self.predict_sliding_window_return_logits(data) + if len(self.list_of_parameters) > 1: + # prediction /= len(self.list_of_parameters) + prediction = torch.mode(prediction, dim=0, keepdim=True).values + + except RuntimeError: + print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. ' + 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...') + print('Error:') + traceback.print_exc() + prediction = None + self.perform_everything_on_gpu = False + + if prediction is None: + for params in self.list_of_parameters: + # messing with state dict names... 
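+                    # Note: in this modified predictor, self.list_of_parameters holds checkpoint
+                    # *paths* (initialize_from_trained_model_folder appends checkpoint_name and
+                    # deletes the loaded checkpoint), so each fold's weights are re-read from disk
+                    # here instead of being kept in RAM; strict=False tolerates key mismatches.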
+ if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(torch.load(params, + map_location=torch.device('cpu'))["network_weights"], + strict=False) # bc of old experiments.. + else: + self.network._orig_mod.load_state_dict(torch.load(params, + map_location=torch.device('cpu'))["network_weights"], + strict=False) # bc of old experiments.. + + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data) + # prediction = torch.argmax(self.predict_sliding_window_return_logits(data),0,keepdim=True).type(torch.uint8) + else: + prediction += self.predict_sliding_window_return_logits(data) + #prediction = torch.cat([prediction, + # torch.argmax(self.predict_sliding_window_return_logits(data), 0, + # keepdim=True).type(torch.uint8)]) + if len(self.list_of_parameters) > 1: + prediction /= len(self.list_of_parameters) + # prediction = torch.mode(prediction, dim=0, keepdim=True).values + self.perform_everything_on_gpu = original_perform_everything_on_gpu + return prediction + + def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): + slicers = [] + if len(self.configuration_manager.patch_size) < len(image_size): + assert len(self.configuration_manager.patch_size) == len( + image_size) - 1, 'if tile_size has less entries than image_size, ' \ + 'len(tile_size) ' \ + 'must be one shorter than len(image_size) ' \ + '(only dimension ' \ + 'discrepancy of 1 allowed).' + steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is' + f' {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for d in range(image_size[0]): + for sx in steps[0]: + for sy in steps[1]: + slicers.append( + tuple([slice(None), d, *[slice(si, si + ti) for si, ti in + zip((sx, sy), self.configuration_manager.patch_size)]])) + else: + steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print( + f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for sx in steps[0]: + for sy in steps[1]: + for sz in steps[2]: + slicers.append( + tuple([slice(None), *[slice(si, si + ti) for si, ti in + zip((sx, sy, sz), self.configuration_manager.patch_size)]])) + return slicers + + def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: + mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None + prediction = self.network(x) + + if mirror_axes is not None: + # check for invalid numbers in mirror_axes + # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 + assert max(mirror_axes) <= len(x.shape) - 3, 'mirror_axes does not match the dimension of the input!' 
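+            # Mirroring-based test-time augmentation: for each enabled axis the input is flipped,
+            # passed through the network, flipped back and added to the running sum of logits.
+            # The divisor below (len(mirror_axes)) no longer matches the number of summed
+            # predictions, but since only the argmax of the logits is used downstream this
+            # uniform rescaling should not change the final segmentation.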
+ + # Here I removed the combined flips, bc they ended up blurring the left/right dichotomy + # num_predictons = 2 ** len(mirror_axes) + num_predictons = len(mirror_axes) + if 0 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,)) + if 1 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,)) + if 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,)) + # if 0 in mirror_axes and 1 in mirror_axes: + # prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3)) + # if 0 in mirror_axes and 2 in mirror_axes: + # prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4)) + # if 1 in mirror_axes and 2 in mirror_axes: + # prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4)) + # if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4)) + prediction /= num_predictons + return prediction + + def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ + -> Union[np.ndarray, torch.Tensor]: + assert isinstance(input_image, torch.Tensor) + self.network = self.network.to(self.device) + self.network.eval() + + empty_cache(self.device) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) + # and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False + # is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with torch.no_grad(): + with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + assert len(input_image.shape) == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)' + + if self.verbose: print(f'Input shape: {input_image.shape}') + if self.verbose: print("step_size:", self.tile_step_size) + if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None) + + # if input_image is smaller than tile_size we need to pad it to tile_size. + data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size, + 'constant', {'value': 0}, True, + None) + + slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) + + # preallocate results and num_predictions + results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu') + if self.verbose: print('preallocating arrays') + try: + data = data.to(self.device) + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, + device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=1000, + device=results_device) + except RuntimeError: + # sometimes the stuff is too large for GPUs. 
In that case fall back to CPU + results_device = torch.device('cpu') + data = data.to(results_device) + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, + device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=1000, + device=results_device) + finally: + empty_cache(self.device) + + if self.verbose: print('running prediction') + for sl in tqdm(slicers, disable=not self.allow_tqdm): + workon = data[sl][None] + workon = workon.to(self.device, non_blocking=False) + + prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device) + + predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction) + n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1) + + predicted_logits /= n_predictions + empty_cache(self.device) + return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])] + + +def predict_entry_point_modelfolder(): + import argparse + parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' + 'you want to manually specify a folder containing a trained nnU-Net ' + 'model. This is useful when the nnunet environment variables ' + '(nnUNet_results) are not set.') + parser.add_argument('-i', type=str, required=True, + help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' + 'File endings must be the same as the training dataset!') + parser.add_argument('-o', type=str, required=True, + help='Output folder. If it does not exist it will be created. Predicted segmentations will ' + 'have the same name as their source images.') + parser.add_argument('-m', type=str, required=True, + help='Folder in which the trained model is. Must have subfolders fold_X for the different ' + 'folds you trained') + parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), + help='Specify the folds of the trained model that should be used for prediction. ' + 'Default: (0, 1, 2, 3, 4)') + parser.add_argument('-step_size', type=float, required=False, default=0.5, + help='Step size for sliding window prediction. The larger it is the faster but less accurate ' + 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') + parser.add_argument('--disable_tta', action='store_true', required=False, default=False, + help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' + 'but less accurate inference. Not recommended.') + parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " + "to be a good listener/reader.") + parser.add_argument('--save_probabilities', action='store_true', + help='Set this to export predicted class "probabilities". Required if you want to ensemble ' + 'multiple configurations.') + parser.add_argument('--continue_prediction', '--c', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files)') + parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') + parser.add_argument('-npp', type=int, required=False, default=3, + help='Number of processes used for preprocessing. 
More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-nps', type=int, required=False, default=3, + help='Number of processes used for segmentation export. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models.') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_gpu=True, + device=device, + verbose=args.verbose) + predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=1, part_id=0) + + +def predict_entry_point(): + import argparse + parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' + 'you want to manually specify a folder containing a trained nnU-Net ' + 'model. This is useful when the nnunet environment variables ' + '(nnUNet_results) are not set.') + parser.add_argument('-i', type=str, required=True, + help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' + 'File endings must be the same as the training dataset!') + parser.add_argument('-o', type=str, required=True, + help='Output folder. If it does not exist it will be created. Predicted segmentations will ' + 'have the same name as their source images.') + parser.add_argument('-d', type=str, required=True, + help='Dataset with which you would like to predict. You can specify either dataset name or id') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='Plans identifier. 
Specify the plans in which the desired configuration is located. ' + 'Default: nnUNetPlans') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer') + parser.add_argument('-c', type=str, required=True, + help='nnU-Net configuration that should be used for prediction. Config must be located ' + 'in the plans specified with -p') + parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), + help='Specify the folds of the trained model that should be used for prediction. ' + 'Default: (0, 1, 2, 3, 4)') + parser.add_argument('-step_size', type=float, required=False, default=0.5, + help='Step size for sliding window prediction. The larger it is the faster but less accurate ' + 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') + parser.add_argument('--disable_tta', action='store_true', required=False, default=False, + help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' + 'but less accurate inference. Not recommended.') + parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " + "to be a good listener/reader.") + parser.add_argument('--save_probabilities', action='store_true', + help='Set this to export predicted class "probabilities". Required if you want to ensemble ' + 'multiple configurations.') + parser.add_argument('--continue_prediction', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files)') + parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') + parser.add_argument('-npp', type=int, required=False, default=3, + help='Number of processes used for preprocessing. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-nps', type=int, required=False, default=3, + help='Number of processes used for segmentation export. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models.') + parser.add_argument('-num_parts', type=int, required=False, default=1, + help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one ' + 'call predicts everything)') + parser.add_argument('-part_id', type=int, required=False, default=0, + help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with ' + 'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts ' + '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible ' + 'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. 
A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + model_folder = get_output_folder(args.d, args.tr, args.p, args.c) + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + # slightly passive agressive haha + assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.' + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_gpu=True, + device=device, + verbose=args.verbose, + verbose_preprocessing=False) + predictor.initialize_from_trained_model_folder( + model_folder, + args.f, + checkpoint_name=args.chk + ) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=args.num_parts, + part_id=args.part_id) + # r = predict_from_raw_data(args.i, + # args.o, + # model_folder, + # args.f, + # args.step_size, + # use_gaussian=True, + # use_mirroring=not args.disable_tta, + # perform_everything_on_gpu=True, + # verbose=args.verbose, + # save_probabilities=args.save_probabilities, + # overwrite=not args.continue_prediction, + # checkpoint_name=args.chk, + # num_processes_preprocessing=args.npp, + # num_processes_segmentation_export=args.nps, + # folder_with_segs_from_prev_stage=args.prev_stage_predictions, + # num_parts=args.num_parts, + # part_id=args.part_id, + # device=device) + + +if __name__ == '__main__': + # predict a bunch of files + from nnunetv2.paths import nnUNet_results, nnUNet_raw + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_gpu=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0, ), + checkpoint_name='checkpoint_final.pth', + ) + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # predict a numpy array + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, 
False) + + iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) + ret = predictor.predict_from_data_iterator(iterator, False, 1) + + + # predictor = nnUNetPredictor( + # tile_step_size=0.5, + # use_gaussian=True, + # use_mirroring=True, + # perform_everything_on_gpu=True, + # device=torch.device('cuda', 0), + # verbose=False, + # allow_tqdm=True + # ) + # predictor.initialize_from_trained_model_folder( + # join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'), + # use_folds=(0,), + # checkpoint_name='checkpoint_final.pth', + # ) + # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + # join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'), + # save_probabilities=False, overwrite=False, + # num_processes_preprocessing=2, num_processes_segmentation_export=2, + # folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres', + # num_parts=1, part_id=0) + diff --git a/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save b/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save new file mode 100644 index 0000000..a2329a2 --- /dev/null +++ b/nnUNet/nnunetv2/inference/predict_from_raw_data.py.save @@ -0,0 +1,924 @@ +import inspect +import multiprocessing +import os +import traceback +from copy import deepcopy +from time import sleep +from typing import Tuple, Union, List, Optional + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.padding import pad_nd_image +from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter +from batchgenerators.utilities.file_and_folder_operations import load_json, join, isfile, maybe_mkdir_p, isdir, subdirs, \ + save_json +from torch import nn +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm + +import nnunetv2 +from nnunetv2.configuration import default_num_processes +from nnunetv2.inference.data_iterators import PreprocessAdapterFromNpy, preprocessing_iterator_fromfiles, \ + preprocessing_iterator_fromnpy +from nnunetv2.inference.export_prediction import export_prediction_from_logits, \ + convert_predicted_logits_to_segmentation_with_correct_shape +from nnunetv2.inference.sliding_window_prediction import compute_gaussian, \ + compute_steps_for_sliding_window +from nnunetv2.utilities.file_path_utilities import get_output_folder, check_workers_alive_and_busy +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import create_lists_from_splitted_dataset_folder + + +class nnUNetPredictor(object): + def __init__(self, + tile_step_size: float = 0.5, + use_gaussian: bool = True, + use_mirroring: bool = True, + perform_everything_on_gpu: bool = True, + device: torch.device = torch.device('cuda'), + verbose: bool = False, + verbose_preprocessing: bool = False, + allow_tqdm: bool = True): + self.verbose = verbose + self.verbose_preprocessing = verbose_preprocessing + self.allow_tqdm = allow_tqdm + + self.plans_manager, self.configuration_manager, self.list_of_parameters, self.network, self.dataset_json, \ + self.trainer_name, 
self.allowed_mirroring_axes, self.label_manager = None, None, None, None, None, None, None, None + + self.tile_step_size = tile_step_size + self.use_gaussian = use_gaussian + self.use_mirroring = use_mirroring + if device.type == 'cuda': + # device = torch.device(type='cuda', index=0) # set the desired GPU with CUDA_VISIBLE_DEVICES! + # why would I ever want to do that. Stupid dobby. This kills DDP inference... + pass + if device.type != 'cuda': + print(f'perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False') + perform_everything_on_gpu = False + self.device = device + self.perform_everything_on_gpu = perform_everything_on_gpu + + def initialize_from_trained_model_folder(self, model_training_output_dir: str, + use_folds: Union[Tuple[Union[int, str]], None], + checkpoint_name: str = 'checkpoint_final.pth'): + """ + This is used when making predictions with a trained model + """ + if use_folds is None: + use_folds = nnUNetPredictor.auto_detect_available_folds(model_training_output_dir, checkpoint_name) + + dataset_json = load_json(join(model_training_output_dir, 'dataset.json')) + plans = load_json(join(model_training_output_dir, 'plans.json')) + plans_manager = PlansManager(plans) + + if isinstance(use_folds, str): + use_folds = [use_folds] + + parameters = [] + for i, f in enumerate(use_folds): + f = int(f) if f != 'all' else f + checkpoint = torch.load(join(model_training_output_dir, f'fold_{f}', checkpoint_name), + map_location=torch.device('cpu')) + if i == 0: + trainer_name = checkpoint['trainer_name'] + configuration_name = checkpoint['init_args']['configuration'] + inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \ + 'inference_allowed_mirroring_axes' in checkpoint.keys() else None + + parameters.append(checkpoint['network_weights']) + + configuration_manager = plans_manager.get_configuration(configuration_name) + # restore network + num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json) + trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, 'nnunetv2.training.nnUNetTrainer') + network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager, + num_input_channels, enable_deep_supervision=False) + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \ + and not isinstance(self.network, OptimizedModule): + print('compiling network') + self.network = torch.compile(self.network) + + def manual_initialization(self, network: nn.Module, plans_manager: PlansManager, + configuration_manager: ConfigurationManager, parameters: Optional[List[dict]], + dataset_json: dict, trainer_name: str, + inference_allowed_mirroring_axes: Optional[Tuple[int, ...]]): + """ + This is used by the nnUNetTrainer to initialize nnUNetPredictor for the final validation + """ + self.plans_manager = plans_manager + self.configuration_manager = configuration_manager + self.list_of_parameters = parameters + self.network = network + self.dataset_json = dataset_json + self.trainer_name = trainer_name + 
self.allowed_mirroring_axes = inference_allowed_mirroring_axes + self.label_manager = plans_manager.get_label_manager(dataset_json) + allow_compile = True + allow_compile = allow_compile and ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) + allow_compile = allow_compile and not isinstance(self.network, OptimizedModule) + if isinstance(self.network, DistributedDataParallel): + allow_compile = allow_compile and isinstance(self.network.module, OptimizedModule) + if allow_compile: + print('compiling network') + self.network = torch.compile(self.network) + + @staticmethod + def auto_detect_available_folds(model_training_output_dir, checkpoint_name): + print('use_folds is None, attempting to auto detect available folds') + fold_folders = subdirs(model_training_output_dir, prefix='fold_', join=False) + fold_folders = [i for i in fold_folders if i != 'fold_all'] + fold_folders = [i for i in fold_folders if isfile(join(model_training_output_dir, i, checkpoint_name))] + use_folds = [int(i.split('_')[-1]) for i in fold_folders] + print(f'found the following folds: {use_folds}') + return use_folds + + def _manage_input_and_output_lists(self, list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[None, str, List[str]], + folder_with_segs_from_prev_stage: str = None, + overwrite: bool = True, + part_id: int = 0, + num_parts: int = 1, + save_probabilities: bool = False): + if isinstance(list_of_lists_or_source_folder, str): + list_of_lists_or_source_folder = create_lists_from_splitted_dataset_folder(list_of_lists_or_source_folder, + self.dataset_json['file_ending']) + print(f'There are {len(list_of_lists_or_source_folder)} cases in the source folder') + list_of_lists_or_source_folder = list_of_lists_or_source_folder[part_id::num_parts] + caseids = [os.path.basename(i[0])[:-(len(self.dataset_json['file_ending']) + 5)] for i in + list_of_lists_or_source_folder] + print( + f'I am process {part_id} out of {num_parts} (max process ID is {num_parts - 1}, we start counting with 0!)') + print(f'There are {len(caseids)} cases that I would like to predict') + + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_filename_truncated = [join(output_folder_or_list_of_truncated_output_files, i) for i in caseids] + else: + output_filename_truncated = output_folder_or_list_of_truncated_output_files + + seg_from_prev_stage_files = [join(folder_with_segs_from_prev_stage, i + self.dataset_json['file_ending']) if + folder_with_segs_from_prev_stage is not None else None for i in caseids] + # remove already predicted files form the lists + if not overwrite and output_filename_truncated is not None: + tmp = [isfile(i + self.dataset_json['file_ending']) for i in output_filename_truncated] + if save_probabilities: + tmp2 = [isfile(i + '.npz') for i in output_filename_truncated] + tmp = [i and j for i, j in zip(tmp, tmp2)] + not_existing_indices = [i for i, j in enumerate(tmp) if not j] + + output_filename_truncated = [output_filename_truncated[i] for i in not_existing_indices] + list_of_lists_or_source_folder = [list_of_lists_or_source_folder[i] for i in not_existing_indices] + seg_from_prev_stage_files = [seg_from_prev_stage_files[i] for i in not_existing_indices] + print(f'overwrite was set to {overwrite}, so I am only working on cases that haven\'t been predicted yet. 
' + f'That\'s {len(not_existing_indices)} cases.') + return list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files + + def predict_from_files(self, + list_of_lists_or_source_folder: Union[str, List[List[str]]], + output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]], + save_probabilities: bool = False, + overwrite: bool = True, + num_processes_preprocessing: int = default_num_processes, + num_processes_segmentation_export: int = default_num_processes, + folder_with_segs_from_prev_stage: str = None, + num_parts: int = 1, + part_id: int = 0): + """ + This is nnU-Net's default function for making predictions. It works best for batch predictions + (predicting many images at once). + """ + if isinstance(output_folder_or_list_of_truncated_output_files, str): + output_folder = output_folder_or_list_of_truncated_output_files + elif isinstance(output_folder_or_list_of_truncated_output_files, list): + output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0]) + else: + output_folder = None + + ######################## + # let's store the input arguments so that its clear what was used to generate the prediction + if output_folder is not None: + my_init_kwargs = {} + for k in inspect.signature(self.predict_from_files).parameters.keys(): + my_init_kwargs[k] = locals()[k] + my_init_kwargs = deepcopy( + my_init_kwargs) # let's not unintentionally change anything in-place. Take this as a + recursive_fix_for_json_export(my_init_kwargs) + maybe_mkdir_p(output_folder) + save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json')) + + # we need these two if we want to do things with the predictions like for example apply postprocessing + save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False) + save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False) + ####################### + + # check if we need a prediction from the previous stage + if self.configuration_manager.previous_stage_name is not None: + assert folder_with_segs_from_prev_stage is not None, \ + f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \ + f'stage ({self.configuration_manager.previous_stage_name}) as input. 
Please provide the folder where' \ + f' they are located via folder_with_segs_from_prev_stage' + + # sort out input and output filenames + list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \ + self._manage_input_and_output_lists(list_of_lists_or_source_folder, + output_folder_or_list_of_truncated_output_files, + folder_with_segs_from_prev_stage, overwrite, part_id, num_parts, + save_probabilities) + if len(list_of_lists_or_source_folder) == 0: + return + + data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder, + seg_from_prev_stage_files, + output_filename_truncated, + num_processes_preprocessing) + + return self.predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export) + + def _internal_get_data_iterator_from_lists_of_filenames(self, + input_list_of_lists: List[List[str]], + seg_from_prev_stage_files: Union[List[str], None], + output_filenames_truncated: Union[List[str], None], + num_processes: int): + return preprocessing_iterator_fromfiles(input_list_of_lists, seg_from_prev_stage_files, + output_filenames_truncated, self.plans_manager, self.dataset_json, + self.configuration_manager, num_processes, self.device.type == 'cuda', + self.verbose_preprocessing) + # preprocessor = self.configuration_manager.preprocessor_class(verbose=self.verbose_preprocessing) + # # hijack batchgenerators, yo + # # we use the multiprocessing of the batchgenerators dataloader to handle all the background worker stuff. This + # # way we don't have to reinvent the wheel here. + # num_processes = max(1, min(num_processes, len(input_list_of_lists))) + # ppa = PreprocessAdapter(input_list_of_lists, seg_from_prev_stage_files, preprocessor, + # output_filenames_truncated, self.plans_manager, self.dataset_json, + # self.configuration_manager, num_processes) + # if num_processes == 0: + # mta = SingleThreadedAugmenter(ppa, None) + # else: + # mta = MultiThreadedAugmenter(ppa, None, num_processes, 1, None, pin_memory=pin_memory) + # return mta + + def get_data_iterator_from_raw_npy_data(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3): + + list_of_images = [image_or_list_of_images] if not isinstance(image_or_list_of_images, list) else \ + image_or_list_of_images + + if isinstance(segs_from_prev_stage_or_list_of_segs_from_prev_stage, np.ndarray): + segs_from_prev_stage_or_list_of_segs_from_prev_stage = [ + segs_from_prev_stage_or_list_of_segs_from_prev_stage] + + if isinstance(truncated_ofname, str): + truncated_ofname = [truncated_ofname] + + if isinstance(properties_or_list_of_properties, dict): + properties_or_list_of_properties = [properties_or_list_of_properties] + + num_processes = min(num_processes, len(list_of_images)) + pp = preprocessing_iterator_fromnpy( + list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + self.plans_manager, + self.dataset_json, + self.configuration_manager, + num_processes, + self.device.type == 'cuda', + self.verbose_preprocessing + ) + + return pp + + def predict_from_list_of_npy_arrays(self, + image_or_list_of_images: Union[np.ndarray, List[np.ndarray]], + segs_from_prev_stage_or_list_of_segs_from_prev_stage: Union[None, + 
np.ndarray, + List[ + np.ndarray]], + properties_or_list_of_properties: Union[dict, List[dict]], + truncated_ofname: Union[str, List[str], None], + num_processes: int = 3, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + iterator = self.get_data_iterator_from_raw_npy_data(image_or_list_of_images, + segs_from_prev_stage_or_list_of_segs_from_prev_stage, + properties_or_list_of_properties, + truncated_ofname, + num_processes) + return self.predict_from_data_iterator(iterator, save_probabilities, num_processes_segmentation_export) + + def predict_from_data_iterator(self, + data_iterator, + save_probabilities: bool = False, + num_processes_segmentation_export: int = default_num_processes): + """ + each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys! + If 'ofile' is None, the result will be returned instead of written to a file + """ + with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool: + worker_list = [i for i in export_pool._pool] + r = [] + for preprocessed in data_iterator: + data = preprocessed['data'] + if isinstance(data, str): + delfile = data + data = torch.from_numpy(np.load(data)) + os.remove(delfile) + + ofile = preprocessed['ofile'] + if ofile is not None: + print(f'\nPredicting {os.path.basename(ofile)}:') + else: + print(f'\nPredicting image of shape {data.shape}:') + + print(f'perform_everything_on_gpu: {self.perform_everything_on_gpu}') + + properties = preprocessed['data_properites'] + + # let's not get into a runaway situation where the GPU predicts so fast that the disk has to b swamped with + # npy files + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + while not proceed: + # print('sleeping') + sleep(0.1) + proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2) + + prediction = self.predict_logits_from_preprocessed_data(data).cpu() + + if ofile is not None: + # this needs to go into background processes + # export_prediction_from_logits(prediction, properties, configuration_manager, plans_manager, + # dataset_json, ofile, save_probabilities) + print('sending off prediction to background worker for resampling and export') + r.append( + export_pool.starmap_async( + export_prediction_from_logits, + ((prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, ofile, save_probabilities),) + ) + ) + else: + # convert_predicted_logits_to_segmentation_with_correct_shape(prediction, plans_manager, + # configuration_manager, label_manager, + # properties, + # save_probabilities) + print('sending off prediction to background worker for resampling') + r.append( + export_pool.starmap_async( + convert_predicted_logits_to_segmentation_with_correct_shape, ( + (prediction, self.plans_manager, + self.configuration_manager, self.label_manager, + properties, + save_probabilities),) + ) + ) + if ofile is not None: + print(f'done with {os.path.basename(ofile)}') + else: + print(f'\nDone with image of shape {data.shape}:') + ret = [i.get()[0] for i in r] + + if isinstance(data_iterator, MultiThreadedAugmenter): + data_iterator._finish() + + # clear lru cache + compute_gaussian.cache_clear() + # clear device cache + empty_cache(self.device) + return ret + + def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict, + segmentation_previous_stage: np.ndarray = None, + output_file_truncated: str = None, + 
save_or_return_probabilities: bool = False): + """ + image_properties must only have a 'spacing' key! + """ + ppa = PreprocessAdapterFromNpy([input_image], [segmentation_previous_stage], [image_properties], + [output_file_truncated], + self.plans_manager, self.dataset_json, self.configuration_manager, + num_threads_in_multithreaded=1, verbose=self.verbose) + if self.verbose: + print('preprocessing') + dct = next(ppa) + + if self.verbose: + print('predicting') + predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']) + print(predicted_logits.dtype) + #print("resampling", flush = True) + if self.verbose: + print('resampling to original shape') + if output_file_truncated is not None: + print("Starting export", flush = True) + export_prediction_from_logits(predicted_logits, dct['data_properites'], self.configuration_manager, + self.plans_manager, self.dataset_json, output_file_truncated, + save_or_return_probabilities) + else: + del input_image + ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager, + self.configuration_manager, + self.label_manager, + dct['data_properites'], + return_probabilities= + save_or_return_probabilities) + #print("Done converting", flush=True) + if save_or_return_probabilities: + return ret[0], ret[1] + else: + return ret + + def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor: + """ + IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON + TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE! + + RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE. + SEE convert_predicted_logits_to_segmentation_with_correct_shape + """ + # we have some code duplication here but this allows us to run with perform_everything_on_gpu=True as + # default and not have the entire program crash in case of GPU out of memory. Neat. That should make + # things a lot faster for some datasets. + original_perform_everything_on_gpu = self.perform_everything_on_gpu + with torch.no_grad(): + prediction = None + if self.perform_everything_on_gpu: + try: + for params in self.list_of_parameters: + + # messing with state dict names... + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) + else: + self.network._orig_mod.load_state_dict(params) + + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data) + else: + prediction += self.predict_sliding_window_return_logits(data) + # print("done sli win") + if len(self.list_of_parameters) > 1: + # print("dividing hold on") + prediction /= len(self.list_of_parameters) + + except RuntimeError: + print('Prediction with perform_everything_on_gpu=True failed due to insufficient GPU memory. ' + 'Falling back to perform_everything_on_gpu=False. Not a big deal, just slower...') + print('Error:') + traceback.print_exc() + prediction = None + self.perform_everything_on_gpu = False + + if prediction is None: + for params in self.list_of_parameters: + # messing with state dict names... 
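+                    # note: torch.compile wraps the network in an OptimizedModule, so for a compiled network the
+                    # checkpoint weights have to be loaded into self.network._orig_mod rather than into the wrapper itself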
+ if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) + else: + self.network._orig_mod.load_state_dict(params) + + if prediction is None: + prediction = self.predict_sliding_window_return_logits(data) + else: + prediction += self.predict_sliding_window_return_logits(data) + if len(self.list_of_parameters) > 1: + prediction /= len(self.list_of_parameters) + + # prediction = prediction.to('cpu') + self.perform_everything_on_gpu = original_perform_everything_on_gpu + return prediction + + def _internal_get_sliding_window_slicers(self, image_size: Tuple[int, ...]): + slicers = [] + if len(self.configuration_manager.patch_size) < len(image_size): + assert len(self.configuration_manager.patch_size) == len( + image_size) - 1, 'if tile_size has less entries than image_size, ' \ + 'len(tile_size) ' \ + 'must be one shorter than len(image_size) ' \ + '(only dimension ' \ + 'discrepancy of 1 allowed).' + steps = compute_steps_for_sliding_window(image_size[1:], self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print(f'n_steps {image_size[0] * len(steps[0]) * len(steps[1])}, image size is' + f' {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for d in range(image_size[0]): + for sx in steps[0]: + for sy in steps[1]: + slicers.append( + tuple([slice(None), d, *[slice(si, si + ti) for si, ti in + zip((sx, sy), self.configuration_manager.patch_size)]])) + else: + steps = compute_steps_for_sliding_window(image_size, self.configuration_manager.patch_size, + self.tile_step_size) + if self.verbose: print( + f'n_steps {np.prod([len(i) for i in steps])}, image size is {image_size}, tile_size {self.configuration_manager.patch_size}, ' + f'tile_step_size {self.tile_step_size}\nsteps:\n{steps}') + for sx in steps[0]: + for sy in steps[1]: + for sz in steps[2]: + slicers.append( + tuple([slice(None), *[slice(si, si + ti) for si, ti in + zip((sx, sy, sz), self.configuration_manager.patch_size)]])) + return slicers + + def _internal_maybe_mirror_and_predict(self, x: torch.Tensor) -> torch.Tensor: + mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None + prediction = self.network(x) + + if mirror_axes is not None: + # check for invalid numbers in mirror_axes + # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3 + assert max(mirror_axes) <= len(x.shape) - 3, 'mirror_axes does not match the dimension of the input!' 
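+            # note: test-time mirroring below runs the network on every combination of flips along the allowed axes,
+            # flips the logits back, sums them up and finally averages over the 2 ** len(mirror_axes) passes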
+ + num_predictons = 2 ** len(mirror_axes) + if 0 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2,))), (2,)) + if 1 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (3,))), (3,)) + if 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (4,))), (4,)) + if 0 in mirror_axes and 1 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2, 3))), (2, 3)) + if 0 in mirror_axes and 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2, 4))), (2, 4)) + if 1 in mirror_axes and 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (3, 4))), (3, 4)) + if 0 in mirror_axes and 1 in mirror_axes and 2 in mirror_axes: + prediction += torch.flip(self.network(torch.flip(x, (2, 3, 4))), (2, 3, 4)) + prediction /= num_predictons + return prediction + + def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \ + -> Union[np.ndarray, torch.Tensor]: + assert isinstance(input_image, torch.Tensor) + self.network = self.network.to(self.device) + self.network.eval() + + empty_cache(self.device) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) + # and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False + # is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with torch.no_grad(): + with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + assert len(input_image.shape) == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)' + + if self.verbose: print(f'Input shape: {input_image.shape}') + if self.verbose: print("step_size:", self.tile_step_size) + if self.verbose: print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None) + + # if input_image is smaller than tile_size we need to pad it to tile_size. + data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size, + 'constant', {'value': 0}, True, + None) + + slicers = self._internal_get_sliding_window_slicers(data.shape[1:]) + + # preallocate results and num_predictions + results_device = self.device if self.perform_everything_on_gpu else torch.device('cpu') + if self.verbose: print('preallocating arrays') + try: + data = data.to(self.device) + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, + device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8, + value_scaling_factor=1000, + device=results_device) + except RuntimeError: + # sometimes the stuff is too large for GPUs. In that case fall back to CPU + print("Switching to CPU", flush=True) + results_device = torch.device('cpu') + data = data.to(results_device) + predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]), + dtype=torch.half, + device=results_device) + n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, + device=results_device) + if self.use_gaussian: + gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. 
/ 8,
+                                                    value_scaling_factor=1000,
+                                                    device=results_device)
+                finally:
+                    empty_cache(self.device)
+
+                if self.verbose: print('running prediction')
+                for sl in tqdm(slicers, disable=not self.allow_tqdm):
+                    print("go", flush=True)
+                    workon = data[sl][None]
+                    workon = workon.to(self.device, non_blocking=False)
+
+                    prediction = self._internal_maybe_mirror_and_predict(workon)[0].to(results_device)
+
+                    predicted_logits[sl] += (prediction * gaussian if self.use_gaussian else prediction)
+                    n_predictions[sl[1:]] += (gaussian if self.use_gaussian else 1)
+
+                predicted_logits /= n_predictions
+        empty_cache(self.device)
+        return predicted_logits[tuple([slice(None), *slicer_revert_padding[1:]])]
+
+
+def predict_entry_point_modelfolder():
+    import argparse
+    parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when '
+                                                 'you want to manually specify a folder containing a trained nnU-Net '
+                                                 'model. This is useful when the nnunet environment variables '
+                                                 '(nnUNet_results) are not set.')
+    parser.add_argument('-i', type=str, required=True,
+                        help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). '
+                             'File endings must be the same as the training dataset!')
+    parser.add_argument('-o', type=str, required=True,
+                        help='Output folder. If it does not exist it will be created. Predicted segmentations will '
+                             'have the same name as their source images.')
+    parser.add_argument('-m', type=str, required=True,
+                        help='Folder in which the trained model is. Must have subfolders fold_X for the different '
+                             'folds you trained')
+    parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4),
+                        help='Specify the folds of the trained model that should be used for prediction. '
+                             'Default: (0, 1, 2, 3, 4)')
+    parser.add_argument('-step_size', type=float, required=False, default=0.5,
+                        help='Step size for sliding window prediction. The larger it is the faster but less accurate '
+                             'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.')
+    parser.add_argument('--disable_tta', action='store_true', required=False, default=False,
+                        help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, '
+                             'but less accurate inference. Not recommended.')
+    parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have "
+                                                               "to be a good listener/reader.")
+    parser.add_argument('--save_probabilities', action='store_true',
+                        help='Set this to export predicted class "probabilities". Required if you want to ensemble '
+                             'multiple configurations.')
+    parser.add_argument('--continue_prediction', '--c', action='store_true',
+                        help='Continue an aborted previous prediction (will not overwrite existing files)')
+    parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
+                        help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
+    parser.add_argument('-npp', type=int, required=False, default=3,
+                        help='Number of processes used for preprocessing. More is not always better. Beware of '
+                             'out-of-RAM issues. Default: 3')
+    parser.add_argument('-nps', type=int, required=False, default=3,
+                        help='Number of processes used for segmentation export. More is not always better. Beware of '
+                             'out-of-RAM issues. Default: 3')
+    parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
+                        help='Folder containing the predictions of the previous stage. 
Required for cascaded models.') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_gpu=True, + device=device, + verbose=args.verbose) + predictor.initialize_from_trained_model_folder(args.m, args.f, args.chk) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=1, part_id=0) + + +def predict_entry_point(): + import argparse + parser = argparse.ArgumentParser(description='Use this to run inference with nnU-Net. This function is used when ' + 'you want to manually specify a folder containing a trained nnU-Net ' + 'model. This is useful when the nnunet environment variables ' + '(nnUNet_results) are not set.') + parser.add_argument('-i', type=str, required=True, + help='input folder. Remember to use the correct channel numberings for your files (_0000 etc). ' + 'File endings must be the same as the training dataset!') + parser.add_argument('-o', type=str, required=True, + help='Output folder. If it does not exist it will be created. Predicted segmentations will ' + 'have the same name as their source images.') + parser.add_argument('-d', type=str, required=True, + help='Dataset with which you would like to predict. You can specify either dataset name or id') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='Plans identifier. Specify the plans in which the desired configuration is located. ' + 'Default: nnUNetPlans') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='What nnU-Net trainer class was used for training? Default: nnUNetTrainer') + parser.add_argument('-c', type=str, required=True, + help='nnU-Net configuration that should be used for prediction. 
Config must be located ' + 'in the plans specified with -p') + parser.add_argument('-f', nargs='+', type=str, required=False, default=(0, 1, 2, 3, 4), + help='Specify the folds of the trained model that should be used for prediction. ' + 'Default: (0, 1, 2, 3, 4)') + parser.add_argument('-step_size', type=float, required=False, default=0.5, + help='Step size for sliding window prediction. The larger it is the faster but less accurate ' + 'the prediction. Default: 0.5. Cannot be larger than 1. We recommend the default.') + parser.add_argument('--disable_tta', action='store_true', required=False, default=False, + help='Set this flag to disable test time data augmentation in the form of mirroring. Faster, ' + 'but less accurate inference. Not recommended.') + parser.add_argument('--verbose', action='store_true', help="Set this if you like being talked to. You will have " + "to be a good listener/reader.") + parser.add_argument('--save_probabilities', action='store_true', + help='Set this to export predicted class "probabilities". Required if you want to ensemble ' + 'multiple configurations.') + parser.add_argument('--continue_prediction', action='store_true', + help='Continue an aborted previous prediction (will not overwrite existing files)') + parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth', + help='Name of the checkpoint you want to use. Default: checkpoint_final.pth') + parser.add_argument('-npp', type=int, required=False, default=3, + help='Number of processes used for preprocessing. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-nps', type=int, required=False, default=3, + help='Number of processes used for segmentation export. More is not always better. Beware of ' + 'out-of-RAM issues. Default: 3') + parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None, + help='Folder containing the predictions of the previous stage. Required for cascaded models.') + parser.add_argument('-num_parts', type=int, required=False, default=1, + help='Number of separate nnUNetv2_predict call that you will be making. Default: 1 (= this one ' + 'call predicts everything)') + parser.add_argument('-part_id', type=int, required=False, default=0, + help='If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 can end with ' + 'num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set -num_parts ' + '5 and use -part_id 0, 1, 2, 3 and 4. Simple, right? Note: You are yourself responsible ' + 'to make these run on separate GPUs! Use CUDA_VISIBLE_DEVICES (google, yo!)') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the inference should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + + print( + "\n#######################################################################\nPlease cite the following paper " + "when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
" + "Nature methods, 18(2), 203-211.\n#######################################################################\n") + + args = parser.parse_args() + args.f = [i if i == 'all' else int(i) for i in args.f] + + model_folder = get_output_folder(args.d, args.tr, args.p, args.c) + + if not isdir(args.o): + maybe_mkdir_p(args.o) + + # slightly passive agressive haha + assert args.part_id < args.num_parts, 'Do you even read the documentation? See nnUNetv2_predict -h.' + + assert args.device in ['cpu', 'cuda', + 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + predictor = nnUNetPredictor(tile_step_size=args.step_size, + use_gaussian=True, + use_mirroring=not args.disable_tta, + perform_everything_on_gpu=True, + device=device, + verbose=args.verbose, + verbose_preprocessing=False) + predictor.initialize_from_trained_model_folder( + model_folder, + args.f, + checkpoint_name=args.chk + ) + predictor.predict_from_files(args.i, args.o, save_probabilities=args.save_probabilities, + overwrite=not args.continue_prediction, + num_processes_preprocessing=args.npp, + num_processes_segmentation_export=args.nps, + folder_with_segs_from_prev_stage=args.prev_stage_predictions, + num_parts=args.num_parts, + part_id=args.part_id) + # r = predict_from_raw_data(args.i, + # args.o, + # model_folder, + # args.f, + # args.step_size, + # use_gaussian=True, + # use_mirroring=not args.disable_tta, + # perform_everything_on_gpu=True, + # verbose=args.verbose, + # save_probabilities=args.save_probabilities, + # overwrite=not args.continue_prediction, + # checkpoint_name=args.chk, + # num_processes_preprocessing=args.npp, + # num_processes_segmentation_export=args.nps, + # folder_with_segs_from_prev_stage=args.prev_stage_predictions, + # num_parts=args.num_parts, + # part_id=args.part_id, + # device=device) + + +if __name__ == '__main__': + # predict a bunch of files + from nnunetv2.paths import nnUNet_results, nnUNet_raw + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_gpu=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0, ), + checkpoint_name='checkpoint_final.pth', + ) + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) + + # predict a numpy array + from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO + img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')]) + ret = predictor.predict_single_npy_array(img, props, None, None, False) + + iterator = predictor.get_data_iterator_from_raw_npy_data([img], None, [props], None, 1) + ret = 
predictor.predict_from_data_iterator(iterator, False, 1)
+
+
+    # predictor = nnUNetPredictor(
+    #     tile_step_size=0.5,
+    #     use_gaussian=True,
+    #     use_mirroring=True,
+    #     perform_everything_on_gpu=True,
+    #     device=torch.device('cuda', 0),
+    #     verbose=False,
+    #     allow_tqdm=True
+    #     )
+    # predictor.initialize_from_trained_model_folder(
+    #     join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_cascade_fullres'),
+    #     use_folds=(0,),
+    #     checkpoint_name='checkpoint_final.pth',
+    # )
+    # predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'),
+    #                              join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predCascade'),
+    #                              save_probabilities=False, overwrite=False,
+    #                              num_processes_preprocessing=2, num_processes_segmentation_export=2,
+    #                              folder_with_segs_from_prev_stage='/media/isensee/data/nnUNet_raw/Dataset003_Liver/imagesTs_predlowres',
+    #                              num_parts=1, part_id=0)
+
diff --git a/nnUNet/nnunetv2/inference/readme.md b/nnUNet/nnunetv2/inference/readme.md
new file mode 100644
index 0000000..984b4a9
--- /dev/null
+++ b/nnUNet/nnunetv2/inference/readme.md
@@ -0,0 +1,205 @@
+The nnU-Net inference is now much more dynamic than before, allowing you to more seamlessly integrate nnU-Net into
+your existing workflows.
+This readme will give you a quick rundown of your options. This is not a complete guide. Look into the code to learn
+all the details!
+
+# Preface
+In terms of speed, the most efficient inference strategy is the one done by the nnU-Net defaults! Images are read on
+the fly and preprocessed in background workers. The main process takes the preprocessed images, predicts them and
+sends the prediction off to another set of background workers which will resize the resulting logits, convert
+them to a segmentation and export the segmentation.
+
+The default setup is the best option because
+
+1) loading and preprocessing as well as segmentation export are interlaced with the prediction. The main process can
+focus on communicating with the compute device (i.e. your GPU) and does not have to do any other processing.
+This uses your resources as well as possible!
+2) only the images and segmentations that are currently needed are stored in RAM! Imagine predicting many images
+and having to store all of them + the results in your system memory
+
+# nnUNetPredictor
+The new nnUNetPredictor class encapsulates the inference code and makes it simple to switch between modes. Your
+code can hold a nnUNetPredictor instance and perform prediction on the fly. Previously this was not possible and each
+new prediction request resulted in reloading the parameters and reinstantiating the network architecture. Not ideal.
+
+The nnUNetPredictor must be initialized manually! You will want to use the
+`predictor.initialize_from_trained_model_folder` function for 99% of use cases!
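+
+As a minimal sketch (the model folder path below is a hypothetical placeholder): judging from
+`auto_detect_available_folds`, passing `use_folds=None` should make the predictor auto-detect the folds from the
+`fold_X` subfolders that contain the requested checkpoint:
+```python
+    predictor.initialize_from_trained_model_folder(
+        '/path/to/nnUNet_results/DatasetXXX_Name/nnUNetTrainer__nnUNetPlans__3d_fullres',  # hypothetical path
+        use_folds=None,  # None -> folds are auto-detected via auto_detect_available_folds
+        checkpoint_name='checkpoint_final.pth',
+    )
+```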
+ +New feature: If you do not specify an output folder / output files then the predicted segmentations will be +returned + + +## Recommended nnU-Net default: predict from source files + +tldr: +- loads images on the fly +- performs preprocessing in background workers +- main process focuses only on making predictions +- results are again given to background workers for resampling and (optional) export + +pros: +- best suited for predicting a large number of images +- nicer to your RAM + +cons: +- not ideal when single images are to be predicted +- requires images to be present as files + +Example: +```python + from nnunetv2.paths import nnUNet_results, nnUNet_raw + import torch + from batchgenerators.utilities.file_and_folder_operations import join + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + + # instantiate the nnUNetPredictor + predictor = nnUNetPredictor( + tile_step_size=0.5, + use_gaussian=True, + use_mirroring=True, + perform_everything_on_gpu=True, + device=torch.device('cuda', 0), + verbose=False, + verbose_preprocessing=False, + allow_tqdm=True + ) + # initializes the network architecture, loads the checkpoint + predictor.initialize_from_trained_model_folder( + join(nnUNet_results, 'Dataset003_Liver/nnUNetTrainer__nnUNetPlans__3d_lowres'), + use_folds=(0,), + checkpoint_name='checkpoint_final.pth', + ) + # variant 1: give input and output folders + predictor.predict_from_files(join(nnUNet_raw, 'Dataset003_Liver/imagesTs'), + join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres'), + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) +``` + +Instead if giving input and output folders you can also give concrete files. If you give concrete files, there is no +need for the _0000 suffix anymore! This can be useful in situations where you have no control over the filenames! +Remember that the files must be given as 'list of lists' where each entry in the outer list is a case to be predicted +and the inner list contains all the files belonging to that case. There is just one file for datasets with just one +input modality (such as CT) but may be more files for others (such as MRI where there is sometimes T1, T2, Flair etc). +IMPORTANT: the order in wich the files for each case are given must match the order of the channels as defined in the +dataset.json! + +If you give files as input, you need to give individual output files as output! + +```python + # variant 2, use list of files as inputs. Note how we use nested lists!!! + indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs') + outdir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs_predlowres') + predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')], + [join(indir, 'liver_142_0000.nii.gz')]], + [join(outdir, 'liver_152.nii.gz'), + join(outdir, 'liver_142.nii.gz')], + save_probabilities=False, overwrite=False, + num_processes_preprocessing=2, num_processes_segmentation_export=2, + folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0) +``` + +Did you know? 
If you do not specify output files, the predicted segmentations will be returned:
+```python
+    # variant 2.5, returns segmentations
+    indir = join(nnUNet_raw, 'Dataset003_Liver/imagesTs')
+    predicted_segmentations = predictor.predict_from_files([[join(indir, 'liver_152_0000.nii.gz')],
+                                                             [join(indir, 'liver_142_0000.nii.gz')]],
+                                                            None,
+                                                            save_probabilities=False, overwrite=True,
+                                                            num_processes_preprocessing=2, num_processes_segmentation_export=2,
+                                                            folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0)
+```
+
+## Prediction from npy arrays
+tldr:
+- you give images as a list of npy arrays
+- performs preprocessing in background workers
+- main process focuses only on making predictions
+- results are again given to background workers for resampling and (optional) export
+
+pros:
+- the correct variant for when you have images in RAM already
+- well suited for predicting multiple images
+
+cons:
+- uses more RAM than the default
+- unsuited for large numbers of images as all images must be held in RAM
+
+```python
+    from nnunetv2.imageio.simpleitk_reader_writer import SimpleITKIO
+
+    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+    # we do not set output files so that the segmentations will be returned. You can of course also specify output
+    # files instead (no return value in that case)
+    ret = predictor.predict_from_list_of_npy_arrays([img, img2, img3, img4],
+                                                    None,
+                                                    [props, props2, props3, props4],
+                                                    None, 2, save_probabilities=False,
+                                                    num_processes_segmentation_export=2)
+```
+
+## Predicting a single npy array
+
+tldr:
+- you give one image as npy array
+- everything is done in the main process: preprocessing, prediction, resampling, (export)
+- no interlacing, slowest variant!
+- ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON
+
+pros:
+- no messing with multiprocessing
+- no messing with data iterator blabla
+
+cons:
+- slow as heck, yo
+- never the right choice unless you can only give a single image at a time to nnU-Net
+
+```python
+    # predict a single numpy array
+    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTr/liver_63_0000.nii.gz')])
+    ret = predictor.predict_single_npy_array(img, props, None, None, False)
+```
+
+## Predicting with a custom data iterator
+tldr:
+- highly flexible
+- not for newbies
+
+pros:
+- you can do everything yourself
+- you have all the freedom you want
+- really fast if you remember to use multiprocessing in your iterator
+
+cons:
+- you need to do everything yourself
+- harder than you might think
+
+```python
+    img, props = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_147_0000.nii.gz')])
+    img2, props2 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_146_0000.nii.gz')])
+    img3, props3 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_145_0000.nii.gz')])
+    img4, props4 = SimpleITKIO().read_images([join(nnUNet_raw, 'Dataset003_Liver/imagesTs/liver_144_0000.nii.gz')])
+    # each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properites' keys! 
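+    # note: 'data_properites' (sic) is the exact key that predict_from_data_iterator reads, so keep that spelling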
+ # If 'ofile' is None, the result will be returned instead of written to a file + # the iterator is responsible for performing the correct preprocessing! + # note how the iterator here does not use multiprocessing -> preprocessing will be done in the main thread! + # take a look at the default iterators for predict_from_files and predict_from_list_of_npy_arrays + # (they both use predictor.predict_from_data_iterator) for inspiration! + def my_iterator(list_of_input_arrs, list_of_input_props): + preprocessor = predictor.configuration_manager.preprocessor_class(verbose=predictor.verbose) + for a, p in zip(list_of_input_arrs, list_of_input_props): + data, seg = preprocessor.run_case_npy(a, + None, + p, + predictor.plans_manager, + predictor.configuration_manager, + predictor.dataset_json) + yield {'data': torch.from_numpy(data).contiguous().pin_memory(), 'data_properites': p, 'ofile': None} + ret = predictor.predict_from_data_iterator(my_iterator([img, img2, img3, img4], [props, props2, props3, props4]), + save_probabilities=False, num_processes_segmentation_export=3) +``` \ No newline at end of file diff --git a/nnUNet/nnunetv2/inference/sliding_window_prediction.py b/nnUNet/nnunetv2/inference/sliding_window_prediction.py new file mode 100644 index 0000000..07316cf --- /dev/null +++ b/nnUNet/nnunetv2/inference/sliding_window_prediction.py @@ -0,0 +1,67 @@ +from functools import lru_cache + +import numpy as np +import torch +from typing import Union, Tuple, List +from acvl_utils.cropping_and_padding.padding import pad_nd_image +from scipy.ndimage import gaussian_filter + + +@lru_cache(maxsize=2) +def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]], sigma_scale: float = 1. / 8, + value_scaling_factor: float = 1, dtype=torch.float16, device=torch.device('cuda', 0)) \ + -> torch.Tensor: + tmp = np.zeros(tile_size) + center_coords = [i // 2 for i in tile_size] + sigmas = [i * sigma_scale for i in tile_size] + tmp[tuple(center_coords)] = 1 + gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) + + gaussian_importance_map = torch.from_numpy(gaussian_importance_map).type(dtype).to(device) + + gaussian_importance_map = gaussian_importance_map / torch.max(gaussian_importance_map) * value_scaling_factor + gaussian_importance_map = gaussian_importance_map.type(dtype) + + # gaussian_importance_map cannot be 0, otherwise we may end up with nans! + gaussian_importance_map[gaussian_importance_map == 0] = torch.min( + gaussian_importance_map[gaussian_importance_map != 0]) + + return gaussian_importance_map + + +def compute_steps_for_sliding_window(image_size: Tuple[int, ...], tile_size: Tuple[int, ...], tile_step_size: float) -> \ + List[List[int]]: + assert [i >= j for i, j in zip(image_size, tile_size)], "image size must be as large or larger than patch_size" + assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1' + + # our step width is patch_size*step_size at most, but can be narrower. 
For example if we have image size of + # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46 + target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size] + + num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)] + + steps = [] + for dim in range(len(tile_size)): + # the highest step value for this dimension is + max_step_value = image_size[dim] - tile_size[dim] + if num_steps[dim] > 1: + actual_step_size = max_step_value / (num_steps[dim] - 1) + else: + actual_step_size = 99999999999 # does not matter because there is only one step at 0 + + steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])] + + steps.append(steps_here) + + return steps + + +if __name__ == '__main__': + a = torch.rand((4, 2, 32, 23)) + a_npy = a.numpy() + + a_padded = pad_nd_image(a, new_shape=(48, 27)) + a_npy_padded = pad_nd_image(a_npy, new_shape=(48, 27)) + assert all([i == j for i, j in zip(a_padded.shape, (4, 2, 48, 27))]) + assert all([i == j for i, j in zip(a_npy_padded.shape, (4, 2, 48, 27))]) + assert np.all(a_padded.numpy() == a_npy_padded) diff --git a/nnUNet/nnunetv2/model_sharing/__init__.py b/nnUNet/nnunetv2/model_sharing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/model_sharing/entry_points.py b/nnUNet/nnunetv2/model_sharing/entry_points.py new file mode 100644 index 0000000..1ab7c93 --- /dev/null +++ b/nnUNet/nnunetv2/model_sharing/entry_points.py @@ -0,0 +1,61 @@ +from nnunetv2.model_sharing.model_download import download_and_install_from_url +from nnunetv2.model_sharing.model_export import export_pretrained_model +from nnunetv2.model_sharing.model_import import install_model_from_zip_file + + +def print_license_warning(): + print('') + print('######################################################') + print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!') + print('######################################################') + print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some " + "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use " + "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!") + print('######################################################') + print('') + + +def download_by_url(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to download pretrained models. This script is intended to download models via url only. 
" + "CAREFUL: This script will overwrite " + "existing models (if they share the same trainer class and plans as " + "the pretrained model.") + parser.add_argument("url", type=str, help='URL of the pretrained model') + args = parser.parse_args() + url = args.url + download_and_install_from_url(url) + + +def install_from_zip_entry_point(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to install a zip file containing a pretrained model.") + parser.add_argument("zip", type=str, help='zip file') + args = parser.parse_args() + zip = args.zip + install_model_from_zip_file(zip) + + +def export_pretrained_model_entry(): + import argparse + parser = argparse.ArgumentParser( + description="Use this to export a trained model as a zip file.") + parser.add_argument('-d', type=str, required=True, help='Dataset name or id') + parser.add_argument('-o', type=str, required=True, help='Output file name') + parser.add_argument('-c', nargs='+', type=str, required=False, + default=('3d_lowres', '3d_fullres', '2d', '3d_cascade_fullres'), + help="List of configuration names") + parser.add_argument('-tr', required=False, type=str, default='nnUNetTrainer', help='Trainer class') + parser.add_argument('-p', required=False, type=str, default='nnUNetPlans', help='plans identifier') + parser.add_argument('-f', required=False, nargs='+', type=str, default=(0, 1, 2, 3, 4), help='list of fold ids') + parser.add_argument('-chk', required=False, nargs='+', type=str, default=('checkpoint_final.pth', ), + help='Lis tof checkpoint names to export. Default: checkpoint_final.pth') + parser.add_argument('--not_strict', action='store_false', default=False, required=False, help='Set this to allow missing folds and/or configurations') + parser.add_argument('--exp_cv_preds', action='store_true', required=False, help='Set this to export the cross-validation predictions as well') + args = parser.parse_args() + + export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr, + plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk, + export_crossval_predictions=args.exp_cv_preds) diff --git a/nnUNet/nnunetv2/model_sharing/model_download.py b/nnUNet/nnunetv2/model_sharing/model_download.py new file mode 100644 index 0000000..d845ab5 --- /dev/null +++ b/nnUNet/nnunetv2/model_sharing/model_download.py @@ -0,0 +1,47 @@ +from typing import Optional + +import requests +from batchgenerators.utilities.file_and_folder_operations import * +from time import time +from nnunetv2.model_sharing.model_import import install_model_from_zip_file +from nnunetv2.paths import nnUNet_results +from tqdm import tqdm + + +def download_and_install_from_url(url): + assert nnUNet_results is not None, "Cannot install model because network_training_output_dir is not " \ + "set (RESULTS_FOLDER missing as environment variable, see " \ + "Installation instructions)" + print('Downloading pretrained model from url:', url) + import http.client + http.client.HTTPConnection._http_vsn = 10 + http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0' + + import os + home = os.path.expanduser('~') + random_number = int(time() * 1e7) + tempfile = join(home, '.nnunetdownload_%s' % str(random_number)) + + try: + download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16) + print("Download finished. 
Extracting...") + install_model_from_zip_file(tempfile) + print("Done") + except Exception as e: + raise e + finally: + if isfile(tempfile): + os.remove(tempfile) + + +def download_file(url: str, local_filename: str, chunk_size: Optional[int] = 8192 * 16) -> str: + # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests + # NOTE the stream=True parameter below + with requests.get(url, stream=True, timeout=100) as r: + r.raise_for_status() + with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=int(r.headers.get("Content-Length"))) as f: + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + return local_filename + + diff --git a/nnUNet/nnunetv2/model_sharing/model_export.py b/nnUNet/nnunetv2/model_sharing/model_export.py new file mode 100644 index 0000000..2db8e24 --- /dev/null +++ b/nnUNet/nnunetv2/model_sharing/model_export.py @@ -0,0 +1,124 @@ +import zipfile + +from nnunetv2.utilities.file_path_utilities import * + + +def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: str, + configurations: Tuple[str] = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), + trainer: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + folds: Tuple[int, ...] = (0, 1, 2, 3, 4), + strict: bool = True, + save_checkpoints: Tuple[str, ...] = ('checkpoint_final.pth',), + export_crossval_predictions: bool = False) -> None: + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + with(zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED)) as zipf: + for c in configurations: + print(f"Configuration {c}") + trainer_output_dir = get_output_folder(dataset_name, trainer, plans_identifier, c) + + if not isdir(trainer_output_dir): + if strict: + raise RuntimeError(f"{dataset_name} is missing the trained model of configuration {c}") + else: + continue + + expected_fold_folder = ["fold_%s" % i if i != 'all' else 'fold_all' for i in folds] + assert all([isdir(join(trainer_output_dir, i)) for i in expected_fold_folder]), \ + f"not all requested folds are present; {dataset_name} {c}; requested folds: {folds}" + + assert isfile(join(trainer_output_dir, "plans.json")), f"plans.json missing, {dataset_name} {c}" + + for fold_folder in expected_fold_folder: + print(f"Exporting {fold_folder}") + # debug.json, does not exist yet + source_file = join(trainer_output_dir, fold_folder, "debug.json") + if isfile(source_file): + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # all requested checkpoints + for chk in save_checkpoints: + source_file = join(trainer_output_dir, fold_folder, chk) + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # progress.png + source_file = join(trainer_output_dir, fold_folder, "progress.png") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # if it exists, network architecture.png + source_file = join(trainer_output_dir, fold_folder, "network_architecture.pdf") + if isfile(source_file): + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + # validation folder with all predicted segmentations etc + if export_crossval_predictions: + source_folder = join(trainer_output_dir, fold_folder, "validation") + files = [i for i in subfiles(source_folder, join=False) if not i.endswith('.npz') and not i.endswith('.pkl')] + for f in files: + zipf.write(join(source_folder, f), os.path.relpath(join(source_folder, f), nnUNet_results)) + # just the summary.json file from the validation + else: 
+ source_file = join(trainer_output_dir, fold_folder, "validation", "summary.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + source_folder = join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + if isdir(source_folder): + if export_crossval_predictions: + source_files = subfiles(source_folder, join=True) + else: + source_files = [ + join(trainer_output_dir, f'crossval_results_folds_{folds_tuple_to_string(folds)}', i) for i in + ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] + ] + for s in source_files: + if isfile(s): + zipf.write(s, os.path.relpath(s, nnUNet_results)) + # plans + source_file = join(trainer_output_dir, "plans.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + # fingerprint + source_file = join(trainer_output_dir, "dataset_fingerprint.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + # dataset + source_file = join(trainer_output_dir, "dataset.json") + zipf.write(source_file, os.path.relpath(source_file, nnUNet_results)) + + ensemble_dir = join(nnUNet_results, dataset_name, 'ensembles') + + if not isdir(ensemble_dir): + print("No ensemble directory found for task", dataset_name_or_id) + return + subd = subdirs(ensemble_dir, join=False) + # figure out whether the models in the ensemble are all within the exported models here + for ens in subd: + identifiers, folds = convert_ensemble_folder_to_model_identifiers_and_folds(ens) + ok = True + for i in identifiers: + tr, pl, c = convert_identifier_to_trainer_plans_config(i) + if tr == trainer and pl == plans_identifier and c in configurations: + pass + else: + ok = False + if ok: + print(f'found matching ensemble: {ens}') + source_folder = join(ensemble_dir, ens) + if export_crossval_predictions: + source_files = subfiles(source_folder, join=True) + else: + source_files = [ + join(source_folder, i) for i in + ['summary.json', 'postprocessing.pkl', 'postprocessing.json'] if isfile(join(source_folder, i)) + ] + for s in source_files: + zipf.write(s, os.path.relpath(s, nnUNet_results)) + inference_information_file = join(nnUNet_results, dataset_name, 'inference_information.json') + if isfile(inference_information_file): + zipf.write(inference_information_file, os.path.relpath(inference_information_file, nnUNet_results)) + inference_information_txt_file = join(nnUNet_results, dataset_name, 'inference_information.txt') + if isfile(inference_information_txt_file): + zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results)) + print('Done') + + +if __name__ == '__main__': + export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, )) diff --git a/nnUNet/nnunetv2/model_sharing/model_import.py b/nnUNet/nnunetv2/model_sharing/model_import.py new file mode 100644 index 0000000..0356e90 --- /dev/null +++ b/nnUNet/nnunetv2/model_sharing/model_import.py @@ -0,0 +1,8 @@ +import zipfile + +from nnunetv2.paths import nnUNet_results + + +def install_model_from_zip_file(zip_file: str): + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(nnUNet_results) \ No newline at end of file diff --git a/nnUNet/nnunetv2/paths.py b/nnUNet/nnunetv2/paths.py new file mode 100644 index 0000000..f2220f9 --- /dev/null +++ b/nnUNet/nnunetv2/paths.py @@ -0,0 +1,39 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +""" +PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP +""" + +nnUNet_raw = os.environ.get('nnUNet_raw') +nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed') +nnUNet_results = os.environ.get('nnUNet_results') + +if nnUNet_raw is None: + print("nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files " + "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like " + "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set " + "this up properly.") + +if nnUNet_preprocessed is None: + print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing " + "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how " + "to set this up.") + +if nnUNet_results is None: + print("nnUNet_results is not defined and nnU-Net cannot be used for training or " + "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information " + "on how to set this up.") diff --git a/nnUNet/nnunetv2/postprocessing/__init__.py b/nnUNet/nnunetv2/postprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/postprocessing/remove_connected_components.py b/nnUNet/nnunetv2/postprocessing/remove_connected_components.py new file mode 100644 index 0000000..94724fc --- /dev/null +++ b/nnUNet/nnunetv2/postprocessing/remove_connected_components.py @@ -0,0 +1,362 @@ +import argparse +import multiprocessing +import shutil +from multiprocessing import Pool +from typing import Union, Tuple, List, Callable + +import numpy as np +from acvl_utils.morphology.morphology_helper import remove_all_but_largest_component +from batchgenerators.utilities.file_and_folder_operations import load_json, subfiles, maybe_mkdir_p, join, isfile, \ + isdir, save_pickle, load_pickle, save_json +from nnunetv2.configuration import default_num_processes +from nnunetv2.evaluation.accumulate_cv_results import accumulate_cv_results +from nnunetv2.evaluation.evaluate_predictions import region_or_label_to_mask, compute_metrics_on_folder, \ + load_summary_json, label_or_region_to_key +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.paths import nnUNet_raw +from nnunetv2.utilities.file_path_utilities import folds_tuple_to_string +from nnunetv2.utilities.json_export import recursive_fix_for_json_export +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager + + +def remove_all_but_largest_component_from_segmentation(segmentation: np.ndarray, + labels_or_regions: Union[int, Tuple[int, ...], + List[Union[int, Tuple[int, ...]]]], + background_label: int = 0) -> np.ndarray: + mask = np.zeros_like(segmentation, dtype=bool) + if not isinstance(labels_or_regions, list): + labels_or_regions = [labels_or_regions] + for l_or_r in labels_or_regions: + mask |= region_or_label_to_mask(segmentation, 
l_or_r) + mask_keep = remove_all_but_largest_component(mask) + ret = np.copy(segmentation) # do not modify the input! + ret[mask & ~mask_keep] = background_label + return ret + + +def apply_postprocessing(segmentation: np.ndarray, pp_fns: List[Callable], pp_fn_kwargs: List[dict]): + for fn, kwargs in zip(pp_fns, pp_fn_kwargs): + segmentation = fn(segmentation, **kwargs) + return segmentation + + +def load_postprocess_save(segmentation_file: str, + output_fname: str, + image_reader_writer: BaseReaderWriter, + pp_fns: List[Callable], + pp_fn_kwargs: List[dict]): + seg, props = image_reader_writer.read_seg(segmentation_file) + seg = apply_postprocessing(seg[0], pp_fns, pp_fn_kwargs) + image_reader_writer.write_seg(seg, output_fname, props) + + +def determine_postprocessing(folder_predictions: str, + folder_ref: str, + plans_file_or_dict: Union[str, dict], + dataset_json_file_or_dict: Union[str, dict], + num_processes: int = default_num_processes, + keep_postprocessed_files: bool = True): + """ + Determines nnUNet postprocessing. Its output is a postprocessing.pkl file in folder_predictions which can be + used with apply_postprocessing_to_folder. + + Postprocessed files are saved in folder_predictions/postprocessed. Set + keep_postprocessed_files=False to delete these files after this function is done (temp files will eb created + and deleted regardless). + + If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder + """ + output_folder = join(folder_predictions, 'postprocessed') + + if plans_file_or_dict is None: + expected_plans_file = join(folder_predictions, 'plans.json') + if not isfile(expected_plans_file): + raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans fils should have been " + f"created while running nnUNetv2_predict. Sadge.") + plans_file_or_dict = load_json(expected_plans_file) + plans_manager = PlansManager(plans_file_or_dict) + + if dataset_json_file_or_dict is None: + expected_dataset_json_file = join(folder_predictions, 'dataset.json') + if not isfile(expected_dataset_json_file): + raise RuntimeError( + f"Expected plans file missing: {expected_dataset_json_file}. The plans fils should have been " + f"created while running nnUNetv2_predict. Sadge.") + dataset_json_file_or_dict = load_json(expected_dataset_json_file) + + if not isinstance(dataset_json_file_or_dict, dict): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + + rw = plans_manager.image_reader_writer_class() + label_manager = plans_manager.get_label_manager(dataset_json) + labels_or_regions = label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels + + predicted_files = subfiles(folder_predictions, suffix=dataset_json['file_ending'], join=False) + ref_files = subfiles(folder_ref, suffix=dataset_json['file_ending'], join=False) + # we should print a warning if not all files from folder_ref are present in folder_predictions + if not all([i in predicted_files for i in ref_files]): + print(f'WARNING: Not all files in folder_ref were found in folder_predictions. 
Determining postprocessing ' + f'should always be done on the entire dataset!') + + # before we start we should evaluate the imaegs in the source folder + if not isfile(join(folder_predictions, 'summary.json')): + compute_metrics_on_folder(folder_ref, + folder_predictions, + join(folder_predictions, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + + # we save the postprocessing functions in here + pp_fns = [] + pp_fn_kwargs = [] + + # pool party! + with multiprocessing.get_context("spawn").Pool(num_processes) as pool: + # now let's see whether removing all but the largest foreground region improves the scores + output_here = join(output_folder, 'temp', 'keep_largest_fg') + maybe_mkdir_p(output_here) + pp_fn = remove_all_but_largest_component_from_segmentation + kwargs = { + 'labels_or_regions': label_manager.foreground_labels, + } + + pool.starmap( + load_postprocess_save, + zip( + [join(folder_predictions, i) for i in predicted_files], + [join(output_here, i) for i in predicted_files], + [rw] * len(predicted_files), + [[pp_fn]] * len(predicted_files), + [[kwargs]] * len(predicted_files) + ) + ) + compute_metrics_on_folder(folder_ref, + output_here, + join(output_here, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + # now we need to figure out if doing this improved the dice scores. We will implement that defensively in so far + # that if a single class got worse as a result we won't do this. We can change this in the future but right now I + # prefer to do it this way + baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) + pp_results = load_summary_json(join(output_here, 'summary.json')) + do_this = pp_results['foreground_mean']['Dice'] > baseline_results['foreground_mean']['Dice'] + if do_this: + for class_id in pp_results['mean'].keys(): + if pp_results['mean'][class_id]['Dice'] < baseline_results['mean'][class_id]['Dice']: + do_this = False + break + if do_this: + print(f'Results were improved by removing all but the largest foreground region. ' + f'Mean dice before: {round(baseline_results["foreground_mean"]["Dice"], 5)} ' + f'after: {round(pp_results["foreground_mean"]["Dice"], 5)}') + source = output_here + pp_fns.append(pp_fn) + pp_fn_kwargs.append(kwargs) + else: + print(f'Removing all but the largest foreground region did not improve results!') + source = folder_predictions + + # in the old nnU-Net we could just apply all-but-largest component removal to all classes at the same time and + # then evaluate for each class whether this improved results. This is no longer possible because we now support + # region-based predictions and regions can overlap, causing interactions + # in principle the order with which the postprocessing is applied to the regions matter as well and should be + # investigated, but due to some things that I am too lazy to explain right now it's going to be alright (I think) + # to stick to the order in which they are declared in dataset.json (if you want to think about it then think about + # region_class_order) + # 2023_02_06: I hate myself for the comment above. 
Thanks past me + if len(labels_or_regions) > 1: + for label_or_region in labels_or_regions: + pp_fn = remove_all_but_largest_component_from_segmentation + kwargs = { + 'labels_or_regions': label_or_region, + } + + output_here = join(output_folder, 'temp', 'keep_largest_perClassOrRegion') + maybe_mkdir_p(output_here) + + pool.starmap( + load_postprocess_save, + zip( + [join(source, i) for i in predicted_files], + [join(output_here, i) for i in predicted_files], + [rw] * len(predicted_files), + [[pp_fn]] * len(predicted_files), + [[kwargs]] * len(predicted_files) + ) + ) + compute_metrics_on_folder(folder_ref, + output_here, + join(output_here, 'summary.json'), + rw, + dataset_json['file_ending'], + labels_or_regions, + label_manager.ignore_label, + num_processes) + baseline_results = load_summary_json(join(source, 'summary.json')) + pp_results = load_summary_json(join(output_here, 'summary.json')) + do_this = pp_results['mean'][label_or_region]['Dice'] > baseline_results['mean'][label_or_region]['Dice'] + if do_this: + print(f'Results were improved by removing all but the largest component for {label_or_region}. ' + f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' + f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') + if isdir(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')): + shutil.rmtree(join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest')) + shutil.move(output_here, join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest'), ) + source = join(output_folder, 'temp', 'keep_largest_perClassOrRegion_currentBest') + pp_fns.append(pp_fn) + pp_fn_kwargs.append(kwargs) + else: + print(f'Removing all but the largest component for {label_or_region} did not improve results! ' + f'Dice before: {round(baseline_results["mean"][label_or_region]["Dice"], 5)} ' + f'after: {round(pp_results["mean"][label_or_region]["Dice"], 5)}') + [shutil.copy(join(source, i), join(output_folder, i)) for i in subfiles(source, join=False)] + save_pickle((pp_fns, pp_fn_kwargs), join(folder_predictions, 'postprocessing.pkl')) + + baseline_results = load_summary_json(join(folder_predictions, 'summary.json')) + final_results = load_summary_json(join(output_folder, 'summary.json')) + tmp = { + 'input_folder': {i: baseline_results[i] for i in ['foreground_mean', 'mean']}, + 'postprocessed': {i: final_results[i] for i in ['foreground_mean', 'mean']}, + 'postprocessing_fns': [i.__name__ for i in pp_fns], + 'postprocessing_kwargs': pp_fn_kwargs, + } + # json is a very annoying little bi###. Can't handle tuples as dict keys. + tmp['input_folder']['mean'] = {label_or_region_to_key(k): tmp['input_folder']['mean'][k] for k in + tmp['input_folder']['mean'].keys()} + tmp['postprocessed']['mean'] = {label_or_region_to_key(k): tmp['postprocessed']['mean'][k] for k in + tmp['postprocessed']['mean'].keys()} + # did I already say that I hate json? "TypeError: Object of type int64 is not JSON serializable" You retarded bro? 
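+    # tuple keys were already converted to strings above; recursive_fix_for_json_export additionally converts
+    # numpy scalar types (e.g. np.int64) into plain Python types so the dict can be written as postprocessing.json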
+ recursive_fix_for_json_export(tmp) + save_json(tmp, join(folder_predictions, 'postprocessing.json')) + + shutil.rmtree(join(output_folder, 'temp')) + + if not keep_postprocessed_files: + shutil.rmtree(output_folder) + return pp_fns, pp_fn_kwargs + + +def apply_postprocessing_to_folder(input_folder: str, + output_folder: str, + pp_fns: List[Callable], + pp_fn_kwargs: List[dict], + plans_file_or_dict: Union[str, dict] = None, + dataset_json_file_or_dict: Union[str, dict] = None, + num_processes=8) -> None: + """ + If plans_file_or_dict or dataset_json_file_or_dict are None, we will look for them in input_folder + """ + if plans_file_or_dict is None: + expected_plans_file = join(input_folder, 'plans.json') + if not isfile(expected_plans_file): + raise RuntimeError(f"Expected plans file missing: {expected_plans_file}. The plans file should have been " + f"created while running nnUNetv2_predict. Sadge. If the folder you want to apply " + f"postprocessing to was create from an ensemble then just specify one of the " + f"plans files of the ensemble members in plans_file_or_dict") + plans_file_or_dict = load_json(expected_plans_file) + plans_manager = PlansManager(plans_file_or_dict) + + if dataset_json_file_or_dict is None: + expected_dataset_json_file = join(input_folder, 'dataset.json') + if not isfile(expected_dataset_json_file): + raise RuntimeError( + f"Expected plans file missing: {expected_dataset_json_file}. The dataset.json should have been " + f"copied while running nnUNetv2_predict/nnUNetv2_ensemble. Sadge.") + dataset_json_file_or_dict = load_json(expected_dataset_json_file) + + if not isinstance(dataset_json_file_or_dict, dict): + dataset_json = load_json(dataset_json_file_or_dict) + else: + dataset_json = dataset_json_file_or_dict + + rw = plans_manager.image_reader_writer_class() + + maybe_mkdir_p(output_folder) + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + files = subfiles(input_folder, suffix=dataset_json['file_ending'], join=False) + + _ = p.starmap(load_postprocess_save, + zip( + [join(input_folder, i) for i in files], + [join(output_folder, i) for i in files], + [rw] * len(files), + [pp_fns] * len(files), + [pp_fn_kwargs] * len(files) + ) + ) + + +def entry_point_determine_postprocessing_folder(): + parser = argparse.ArgumentParser('Writes postprocessing.pkl and postprocessing.json in input_folder.') + parser.add_argument('-i', type=str, required=True, help='Input folder') + parser.add_argument('-ref', type=str, required=True, help='Folder with gt labels') + parser.add_argument('-plans_json', type=str, required=False, default=None, + help="plans file to use. If not specified we will look for the plans.json file in the " + "input folder (input_folder/plans.json)") + parser.add_argument('-dataset_json', type=str, required=False, default=None, + help="dataset.json file to use. If not specified we will look for the dataset.json file in the " + "input folder (input_folder/dataset.json)") + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"number of processes to use. 
Default: {default_num_processes}") + parser.add_argument('--remove_postprocessed', action='store_true', required=False, + help='set this is you don\'t want to keep the postprocessed files') + + args = parser.parse_args() + determine_postprocessing(args.i, args.ref, args.plans_json, args.dataset_json, args.np, + not args.remove_postprocessed) + + +def entry_point_apply_postprocessing(): + parser = argparse.ArgumentParser('Apples postprocessing specified in pp_pkl_file to input folder.') + parser.add_argument('-i', type=str, required=True, help='Input folder') + parser.add_argument('-o', type=str, required=True, help='Output folder') + parser.add_argument('-pp_pkl_file', type=str, required=True, help='postprocessing.pkl file') + parser.add_argument('-np', type=int, required=False, default=default_num_processes, + help=f"number of processes to use. Default: {default_num_processes}") + parser.add_argument('-plans_json', type=str, required=False, default=None, + help="plans file to use. If not specified we will look for the plans.json file in the " + "input folder (input_folder/plans.json)") + parser.add_argument('-dataset_json', type=str, required=False, default=None, + help="dataset.json file to use. If not specified we will look for the dataset.json file in the " + "input folder (input_folder/dataset.json)") + args = parser.parse_args() + pp_fns, pp_fn_kwargs = load_pickle(args.pp_pkl_file) + apply_postprocessing_to_folder(args.i, args.o, pp_fns, pp_fn_kwargs, args.plans_json, args.dataset_json, args.np) + + +if __name__ == '__main__': + trained_model_folder = '/home/fabian/results/nnUNet_remake/Dataset004_Hippocampus/nnUNetTrainer__nnUNetPlans__3d_fullres' + labelstr = join(nnUNet_raw, 'Dataset004_Hippocampus', 'labelsTr') + plans_manager = PlansManager(join(trained_model_folder, 'plans.json')) + dataset_json = load_json(join(trained_model_folder, 'dataset.json')) + folds = (0, 1, 2, 3, 4) + label_manager = plans_manager.get_label_manager(dataset_json) + + merged_output_folder = join(trained_model_folder, f'crossval_results_folds_{folds_tuple_to_string(folds)}') + accumulate_cv_results(trained_model_folder, merged_output_folder, folds, 8, False) + + fns, kwargs = determine_postprocessing(merged_output_folder, labelstr, plans_manager.plans, + dataset_json, 8, keep_postprocessed_files=True) + save_pickle((fns, kwargs), join(trained_model_folder, 'postprocessing.pkl')) + fns, kwargs = load_pickle(join(trained_model_folder, 'postprocessing.pkl')) + + apply_postprocessing_to_folder(merged_output_folder, merged_output_folder + '_pp', fns, kwargs, + plans_manager.plans, dataset_json, + 8) + compute_metrics_on_folder(labelstr, + merged_output_folder + '_pp', + join(merged_output_folder + '_pp', 'summary.json'), + plans_manager.image_reader_writer_class(), + dataset_json['file_ending'], + label_manager.foreground_regions if label_manager.has_regions else label_manager.foreground_labels, + label_manager.ignore_label, + 8) diff --git a/nnUNet/nnunetv2/preprocessing/__init__.py b/nnUNet/nnunetv2/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/preprocessing/cropping/__init__.py b/nnUNet/nnunetv2/preprocessing/cropping/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/preprocessing/cropping/cropping.py b/nnUNet/nnunetv2/preprocessing/cropping/cropping.py new file mode 100644 index 0000000..33421c3 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/cropping/cropping.py @@ -0,0 +1,51 @@ +import numpy as np + + +# Hello! 
crop_to_nonzero is the function you are looking for. Ignore the rest. +from acvl_utils.cropping_and_padding.bounding_boxes import get_bbox_from_mask, crop_to_bbox, bounding_box_to_slice + + +def create_nonzero_mask(data): + """ + + :param data: + :return: the mask is True where the data is nonzero + """ + from scipy.ndimage import binary_fill_holes + assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)" + nonzero_mask = np.ones(data.shape[1:], dtype=bool) + for c in range(data.shape[0]): + this_mask = data[c] != 0 + nonzero_mask = nonzero_mask | this_mask + nonzero_mask = binary_fill_holes(nonzero_mask) + return nonzero_mask + + +def crop_to_nonzero(data, seg=None, nonzero_label=-1): + """ + + :param data: + :param seg: + :param nonzero_label: this will be written into the segmentation map + :return: + """ + nonzero_mask = create_nonzero_mask(data) + bbox = get_bbox_from_mask(nonzero_mask) + + slicer = bounding_box_to_slice(bbox) + data = data[tuple([slice(None), *slicer])] + + if seg is not None: + seg = seg[tuple([slice(None), *slicer])] + + nonzero_mask = nonzero_mask[slicer][None] + if seg is not None: + seg[(seg == 0) & (~nonzero_mask)] = nonzero_label + else: + nonzero_mask = nonzero_mask.astype(np.int8) + nonzero_mask[nonzero_mask == 0] = nonzero_label + nonzero_mask[nonzero_mask > 0] = 0 + seg = nonzero_mask + return data, seg, bbox + + diff --git a/nnUNet/nnunetv2/preprocessing/normalization/__init__.py b/nnUNet/nnunetv2/preprocessing/normalization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py b/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py new file mode 100644 index 0000000..3c90a91 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/normalization/default_normalization_schemes.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import Type + +import numpy as np +from numpy import number + + +class ImageNormalization(ABC): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = None + + def __init__(self, use_mask_for_norm: bool = None, intensityproperties: dict = None, + target_dtype: Type[number] = np.float32): + assert use_mask_for_norm is None or isinstance(use_mask_for_norm, bool) + self.use_mask_for_norm = use_mask_for_norm + assert isinstance(intensityproperties, dict) + self.intensityproperties = intensityproperties + self.target_dtype = target_dtype + + @abstractmethod + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + """ + Image and seg must have the same shape. Seg is not always used + """ + pass + + +class ZScoreNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = True + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + """ + here seg is used to store the zero valued region. The value for that region in the segmentation is -1 by + default. + """ + image = image.astype(self.target_dtype) + if self.use_mask_for_norm is not None and self.use_mask_for_norm: + # negative values in the segmentation encode the 'outside' region (think zero values around the brain as + # in BraTS). We want to run the normalization only in the brain region, so we need to mask the image. + # The default nnU-net sets use_mask_for_norm to True if cropping to the nonzero region substantially + # reduced the image size. 
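+            # crop_to_nonzero marks voxels outside the nonzero region with -1 in seg, so seg >= 0 selects exactly
+            # the voxels inside the nonzero mask; mean/std and the normalization are then restricted to that region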
+ mask = seg >= 0 + mean = image[mask].mean() + std = image[mask].std() + image[mask] = (image[mask] - mean) / (max(std, 1e-8)) + else: + mean = image.mean() + std = image.std() + image = (image - mean) / (max(std, 1e-8)) + return image + + +class CTNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + assert self.intensityproperties is not None, "CTNormalization requires intensity properties" + image = image.astype(self.target_dtype) + mean_intensity = self.intensityproperties['mean'] + std_intensity = self.intensityproperties['std'] + lower_bound = self.intensityproperties['percentile_00_5'] + upper_bound = self.intensityproperties['percentile_99_5'] + image = np.clip(image, lower_bound, upper_bound) + image = (image - mean_intensity) / max(std_intensity, 1e-8) + return image + + +class NoNormalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + return image.astype(self.target_dtype) + + +class RescaleTo01Normalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + image = image.astype(self.target_dtype) + image = image - image.min() + image = image / np.clip(image.max(), a_min=1e-8, a_max=None) + return image + + +class RGBTo01Normalization(ImageNormalization): + leaves_pixels_outside_mask_at_zero_if_use_mask_for_norm_is_true = False + + def run(self, image: np.ndarray, seg: np.ndarray = None) -> np.ndarray: + assert image.min() >= 0, "RGB images are uint 8, for whatever reason I found pixel values smaller than 0. " \ + "Your images do not seem to be RGB images" + assert image.max() <= 255, "RGB images are uint 8, for whatever reason I found pixel values greater than 255" \ + ". Your images do not seem to be RGB images" + image = image.astype(self.target_dtype) + image = image / 255. + return image + diff --git a/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py b/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py new file mode 100644 index 0000000..e821650 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/normalization/map_channel_name_to_normalization.py @@ -0,0 +1,24 @@ +from typing import Type + +from nnunetv2.preprocessing.normalization.default_normalization_schemes import CTNormalization, NoNormalization, \ + ZScoreNormalization, RescaleTo01Normalization, RGBTo01Normalization, ImageNormalization + +channel_name_to_normalization_mapping = { + 'CT': CTNormalization, + 'noNorm': NoNormalization, + 'zscore': ZScoreNormalization, + 'rescale_0_1': RescaleTo01Normalization, + 'rgb_to_0_1': RGBTo01Normalization +} + + +def get_normalization_scheme(channel_name: str) -> Type[ImageNormalization]: + """ + If we find the channel_name in channel_name_to_normalization_mapping return the corresponding normalization. 
If it is + not found, use the default (ZScoreNormalization) + """ + norm_scheme = channel_name_to_normalization_mapping.get(channel_name) + if norm_scheme is None: + norm_scheme = ZScoreNormalization + # print('Using %s for image normalization' % norm_scheme.__name__) + return norm_scheme diff --git a/nnUNet/nnunetv2/preprocessing/normalization/readme.md b/nnUNet/nnunetv2/preprocessing/normalization/readme.md new file mode 100644 index 0000000..7b54396 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/normalization/readme.md @@ -0,0 +1,5 @@ +The channel_names entry in dataset.json only determines the normlaization scheme. So if you want to use something different +then you can just +- create a new subclass of ImageNormalization +- map your custom channel identifier to that subclass in channel_name_to_normalization_mapping +- run plan and preprocess again with your custom normlaization scheme \ No newline at end of file diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/__init__.py b/nnUNet/nnunetv2/preprocessing/preprocessors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py new file mode 100644 index 0000000..685ffcd --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py @@ -0,0 +1,302 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import shutil +from time import sleep +from typing import Union, Tuple + +import nnunetv2 +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero +from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ + create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets +from tqdm import tqdm + + +class DefaultPreprocessor(object): + def __init__(self, verbose: bool = True): + self.verbose = verbose + """ + Everything we need is in the plans. Those are given when run() is called + """ + + def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + # let's not mess up the inputs! 
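+        # Work on copies so the caller's arrays are left untouched. The steps below run transpose -> crop to
+        # nonzero -> normalize -> resample (normalization must happen before resampling, see comment further down);
+        # export reverses this order.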
+ data = np.copy(data) + if seg is not None: + seg = np.copy(seg) + + has_seg = seg is not None + + # apply transpose_forward, this also needs to be applied to the spacing! + data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + if seg is not None: + seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward] + + # crop, remember to store size before cropping! + shape_before_cropping = data.shape[1:] + properties['shape_before_cropping'] = shape_before_cropping + # this command will generate a segmentation. This is important because of the nonzero mask which we may need + data, seg, bbox = crop_to_nonzero(data, seg) + properties['bbox_used_for_cropping'] = bbox + # print(data.shape, seg.shape) + properties['shape_after_cropping_and_before_resampling'] = data.shape[1:] + + # resample + target_spacing = configuration_manager.spacing # this should already be transposed + + if len(target_spacing) < len(data.shape[1:]): + # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d + # in 3d we do not change the spacing between slices + target_spacing = [original_spacing[0]] + target_spacing + new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing) + + # normalize + # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no + # longer fitting the images perfectly! + data = self._normalize(data, seg, configuration_manager, + plans_manager.foreground_intensity_properties_per_channel) + + # print('current shape', data.shape[1:], 'current_spacing', original_spacing, + # '\ntarget shape', new_shape, 'target_spacing', target_spacing) + old_shape = data.shape[1:] + data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing) + print(f"Actual new_shape: {data.shape}", flush=True) + print(type(seg)) + if has_seg: + seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing) + else: + seg = None + if self.verbose: + print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, ' + f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}') + + # if we have a segmentation, sample foreground locations for oversampling and add those to properties + if has_seg: + # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument + # with a LabelManager Instance in this function because that's all its used for. Dunno what's better. + # LabelManager is pretty light computation-wise. + label_manager = plans_manager.get_label_manager(dataset_json) + collect_for_this = label_manager.foreground_regions if label_manager.has_regions \ + else label_manager.foreground_labels + + # when using the ignore label we want to sample only from annotated regions. 
Therefore we also need to + # collect samples uniformly from all classes (incl background) + if label_manager.has_ignore_label: + collect_for_this.append(label_manager.all_labels) + + # no need to filter background in regions because it is already filtered in handle_labels + # print(all_labels, regions) + properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this, + verbose=self.verbose) + seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager) + + if np.max(seg) > 127: + seg = seg.astype(np.int16) + else: + seg = seg.astype(np.int8) + # data = data.astype(np.float16) + return data, seg + + def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + """ + seg file can be none (test cases) + + order of operations is: transpose -> crop -> resample + so when we export we need to run the following order: resample -> crop -> transpose (we could also run + transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner) + """ + if isinstance(dataset_json, str): + dataset_json = load_json(dataset_json) + + rw = plans_manager.image_reader_writer_class() + + # load image(s) + data, data_properites = rw.read_images(image_files) + + # if possible, load seg + if seg_file is not None: + seg, _ = rw.read_seg(seg_file) + else: + seg = None + + data, seg = self.run_case_npy(data, seg, data_properites, plans_manager, configuration_manager, + dataset_json) + return data, seg, data_properites + + def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json) + # print('dtypes', data.dtype, seg.dtype) + np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg) + write_pickle(properties, output_filename_truncated + '.pkl') + + @staticmethod + def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]], + seed: int = 1234, verbose: bool = False): + num_samples = 10000 + min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too + # sparse + rndst = np.random.RandomState(seed) + class_locs = {} + for c in classes_or_regions: + k = c if not isinstance(c, list) else tuple(c) + if isinstance(c, (tuple, list)): + mask = seg == c[0] + for cc in c[1:]: + mask = mask | (seg == cc) + all_locs = np.argwhere(mask) + else: + all_locs = np.argwhere(seg == c) + if len(all_locs) == 0: + class_locs[k] = [] + continue + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)] + class_locs[k] = selected + if verbose: + print(c, target_num_samples) + return class_locs + + def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager, + foreground_intensity_properties_per_channel: dict) -> np.ndarray: + for c in range(data.shape[0]): + scheme = configuration_manager.normalization_schemes[c] + normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"), + scheme, + 
'nnunetv2.preprocessing.normalization') + if normalizer_class is None: + raise RuntimeError('Unable to locate class \'%s\' for normalization' % scheme) + normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c], + intensityproperties=foreground_intensity_properties_per_channel[str(c)]) + data[c] = normalizer.run(data[c], seg[0]) + return data + + def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str, + num_processes: int): + """ + data identifier = configuration name in plans. EZ. + """ + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + + assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw" + + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') + assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \ + "first." % plans_file + plans = load_json(plans_file) + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration(configuration_name) + + if self.verbose: + print(f'Preprocessing the following configuration: {configuration_name}') + if self.verbose: + print(configuration_manager) + + dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json') + dataset_json = load_json(dataset_json_file) + + output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier) + + if isdir(output_directory): + shutil.rmtree(output_directory) + + maybe_mkdir_p(output_directory) + + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) + + # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames] + # output_filenames_truncated = [join(output_directory, i) for i in identifiers] + + # multiprocessing magic. + r = [] + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + for k in dataset.keys(): + r.append(p.starmap_async(self.run_case_save, + ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'], + plans_manager, configuration_manager, + dataset_json),))) + remaining = list(range(len(dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar: + while len(remaining) > 0: + all_alive = all([j.is_alive() for j in workers]) + if not all_alive: + raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' + 'OK jokes aside.\n' + 'One of your background processes is missing. This could be because of ' + 'an error (look for an error message) or because it was killed ' + 'by your OS due to running out of RAM. If you don\'t see ' + 'an error message, out of RAM is likely the problem. In that case ' + 'reducing the number of workers might help') + done = [i for i in remaining if r[i].ready()] + for _ in done: + pbar.update() + remaining = [i for i in remaining if i not in done] + sleep(0.1) + + def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict, + configuration_manager: ConfigurationManager) -> np.ndarray: + # this function will be called at the end of self.run_case. Can be used to change the segmentation + # after resampling. 
Useful for experimenting with sparse annotations: I can introduce sparsity after resampling + # and don't have to create a new dataset each time I modify my experiments + return seg + + +def example_test_case_preprocessing(): + # (paths to files may need adaptations) + plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json' + dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json' + input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz'] + + configuration = '3d_fullres' + pp = DefaultPreprocessor() + + # _ because this position would be the segmentation if seg_file was not None (training case) + # even if you have the segmentation, don't put the file there! You should always evaluate in the original + # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as + # specified by plans) + plans_manager = PlansManager(plans_file) + data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager, + configuration_manager=plans_manager.get_configuration(configuration), + dataset_json=dataset_json_file) + + # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO) + return data + + +if __name__ == '__main__': + example_test_case_preprocessing() + # pp = DefaultPreprocessor() + # pp.run(2, '2d', 'nnUNetPlans', 8) + + ########################################################################################################### + # how to process a test cases? This is an example: + # example_test_case_preprocessing() diff --git a/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save new file mode 100644 index 0000000..ee6c9e6 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/preprocessors/default_preprocessor.py.save @@ -0,0 +1,301 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import multiprocessing +import shutil +from time import sleep +from typing import Union, Tuple + +import nnunetv2 +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw +from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero +from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ + create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets +from tqdm import tqdm + + +class DefaultPreprocessor(object): + def __init__(self, verbose: bool = True): + self.verbose = verbose + """ + Everything we need is in the plans. Those are given when run() is called + """ + + def run_case_npy(self, data: np.ndarray, seg: Union[np.ndarray, None], properties: dict, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + # let's not mess up the inputs! + data = np.copy(data) + if seg is not None: + seg = np.copy(seg) + + has_seg = seg is not None + + # apply transpose_forward, this also needs to be applied to the spacing! + data = data.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + if seg is not None: + seg = seg.transpose([0, *[i + 1 for i in plans_manager.transpose_forward]]) + original_spacing = [properties['spacing'][i] for i in plans_manager.transpose_forward] + + # crop, remember to store size before cropping! + shape_before_cropping = data.shape[1:] + properties['shape_before_cropping'] = shape_before_cropping + # this command will generate a segmentation. This is important because of the nonzero mask which we may need + data, seg, bbox = crop_to_nonzero(data, seg) + properties['bbox_used_for_cropping'] = bbox + # print(data.shape, seg.shape) + properties['shape_after_cropping_and_before_resampling'] = data.shape[1:] + + # resample + target_spacing = configuration_manager.spacing # this should already be transposed + + if len(target_spacing) < len(data.shape[1:]): + # target spacing for 2d has 2 entries but the data and original_spacing have three because everything is 3d + # in 3d we do not change the spacing between slices + target_spacing = [original_spacing[0]] + target_spacing + new_shape = compute_new_shape(data.shape[1:], original_spacing, target_spacing) + + # normalize + # normalization MUST happen before resampling or we get huge problems with resampled nonzero masks no + # longer fitting the images perfectly! 
+ data = self._normalize(data, seg, configuration_manager, + plans_manager.foreground_intensity_properties_per_channel) + + # print('current shape', data.shape[1:], 'current_spacing', original_spacing, + # '\ntarget shape', new_shape, 'target_spacing', target_spacing) + old_shape = data.shape[1:] + data = configuration_manager.resampling_fn_data(data, new_shape, original_spacing, target_spacing) + print(f"Actual new_shape: {data.shape}", flush=True) + print(type(seg)) + if has_seg: + seg = configuration_manager.resampling_fn_seg(seg, new_shape, original_spacing, target_spacing) + if self.verbose: + print(f'old shape: {old_shape}, new_shape: {new_shape}, old_spacing: {original_spacing}, ' + f'new_spacing: {target_spacing}, fn_data: {configuration_manager.resampling_fn_data}') + + # if we have a segmentation, sample foreground locations for oversampling and add those to properties + if has_seg: + # reinstantiating LabelManager for each case is not ideal. We could replace the dataset_json argument + # with a LabelManager Instance in this function because that's all its used for. Dunno what's better. + # LabelManager is pretty light computation-wise. + label_manager = plans_manager.get_label_manager(dataset_json) + collect_for_this = label_manager.foreground_regions if label_manager.has_regions \ + else label_manager.foreground_labels + + # when using the ignore label we want to sample only from annotated regions. Therefore we also need to + # collect samples uniformly from all classes (incl background) + if label_manager.has_ignore_label: + collect_for_this.append(label_manager.all_labels) + + # no need to filter background in regions because it is already filtered in handle_labels + # print(all_labels, regions) + properties['class_locations'] = self._sample_foreground_locations(seg, collect_for_this, + verbose=self.verbose) + seg = self.modify_seg_fn(seg, plans_manager, dataset_json, configuration_manager) + if has_seg: + if np.max(seg) > 127: + seg = seg.astype(np.int16) + else: + seg = seg.astype(np.int8) + else: + + return data, seg + + def run_case(self, image_files: List[str], seg_file: Union[str, None], plans_manager: PlansManager, + configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + """ + seg file can be none (test cases) + + order of operations is: transpose -> crop -> resample + so when we export we need to run the following order: resample -> crop -> transpose (we could also run + transpose at a different place, but reverting the order of operations done during preprocessing seems cleaner) + """ + if isinstance(dataset_json, str): + dataset_json = load_json(dataset_json) + + rw = plans_manager.image_reader_writer_class() + + # load image(s) + data, data_properites = rw.read_images(image_files) + + # if possible, load seg + if seg_file is not None: + seg, _ = rw.read_seg(seg_file) + else: + seg = None + + data, seg = self.run_case_npy(data, seg, data_properites, plans_manager, configuration_manager, + dataset_json) + return data, seg, data_properites + + def run_case_save(self, output_filename_truncated: str, image_files: List[str], seg_file: str, + plans_manager: PlansManager, configuration_manager: ConfigurationManager, + dataset_json: Union[dict, str]): + data, seg, properties = self.run_case(image_files, seg_file, plans_manager, configuration_manager, dataset_json) + # print('dtypes', data.dtype, seg.dtype) + np.savez_compressed(output_filename_truncated + '.npz', data=data, seg=seg) + write_pickle(properties, output_filename_truncated + '.pkl') + 
+ @staticmethod + def _sample_foreground_locations(seg: np.ndarray, classes_or_regions: Union[List[int], List[Tuple[int, ...]]], + seed: int = 1234, verbose: bool = False): + num_samples = 10000 + min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too + # sparse + rndst = np.random.RandomState(seed) + class_locs = {} + for c in classes_or_regions: + k = c if not isinstance(c, list) else tuple(c) + if isinstance(c, (tuple, list)): + mask = seg == c[0] + for cc in c[1:]: + mask = mask | (seg == cc) + all_locs = np.argwhere(mask) + else: + all_locs = np.argwhere(seg == c) + if len(all_locs) == 0: + class_locs[k] = [] + continue + target_num_samples = min(num_samples, len(all_locs)) + target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage))) + + selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)] + class_locs[k] = selected + if verbose: + print(c, target_num_samples) + return class_locs + + def _normalize(self, data: np.ndarray, seg: np.ndarray, configuration_manager: ConfigurationManager, + foreground_intensity_properties_per_channel: dict) -> np.ndarray: + for c in range(data.shape[0]): + scheme = configuration_manager.normalization_schemes[c] + normalizer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "normalization"), + scheme, + 'nnunetv2.preprocessing.normalization') + if normalizer_class is None: + raise RuntimeError('Unable to locate class \'%s\' for normalization' % scheme) + normalizer = normalizer_class(use_mask_for_norm=configuration_manager.use_mask_for_norm[c], + intensityproperties=foreground_intensity_properties_per_channel[str(c)]) + data[c] = normalizer.run(data[c], seg[0]) + return data + + def run(self, dataset_name_or_id: Union[int, str], configuration_name: str, plans_identifier: str, + num_processes: int): + """ + data identifier = configuration name in plans. EZ. + """ + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + + assert isdir(join(nnUNet_raw, dataset_name)), "The requested dataset could not be found in nnUNet_raw" + + plans_file = join(nnUNet_preprocessed, dataset_name, plans_identifier + '.json') + assert isfile(plans_file), "Expected plans file (%s) not found. Run corresponding nnUNet_plan_experiment " \ + "first." % plans_file + plans = load_json(plans_file) + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration(configuration_name) + + if self.verbose: + print(f'Preprocessing the following configuration: {configuration_name}') + if self.verbose: + print(configuration_manager) + + dataset_json_file = join(nnUNet_preprocessed, dataset_name, 'dataset.json') + dataset_json = load_json(dataset_json_file) + + output_directory = join(nnUNet_preprocessed, dataset_name, configuration_manager.data_identifier) + + if isdir(output_directory): + shutil.rmtree(output_directory) + + maybe_mkdir_p(output_directory) + + dataset = get_filenames_of_train_images_and_targets(join(nnUNet_raw, dataset_name), dataset_json) + + # identifiers = [os.path.basename(i[:-len(dataset_json['file_ending'])]) for i in seg_fnames] + # output_filenames_truncated = [join(output_directory, i) for i in identifiers] + + # multiprocessing magic. 
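+        # Dispatch one async job per training case and keep handles to the original worker processes so that
+        # crashed or killed workers (e.g. terminated by the OS when running out of RAM) can be detected below
+        # while waiting for results.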
+ r = [] + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + for k in dataset.keys(): + r.append(p.starmap_async(self.run_case_save, + ((join(output_directory, k), dataset[k]['images'], dataset[k]['label'], + plans_manager, configuration_manager, + dataset_json),))) + remaining = list(range(len(dataset))) + # p is pretty nifti. If we kill workers they just respawn but don't do any work. + # So we need to store the original pool of workers. + workers = [j for j in p._pool] + with tqdm(desc=None, total=len(dataset), disable=self.verbose) as pbar: + while len(remaining) > 0: + all_alive = all([j.is_alive() for j in workers]) + if not all_alive: + raise RuntimeError('Some background worker is 6 feet under. Yuck. \n' + 'OK jokes aside.\n' + 'One of your background processes is missing. This could be because of ' + 'an error (look for an error message) or because it was killed ' + 'by your OS due to running out of RAM. If you don\'t see ' + 'an error message, out of RAM is likely the problem. In that case ' + 'reducing the number of workers might help') + done = [i for i in remaining if r[i].ready()] + for _ in done: + pbar.update() + remaining = [i for i in remaining if i not in done] + sleep(0.1) + + def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict, + configuration_manager: ConfigurationManager) -> np.ndarray: + # this function will be called at the end of self.run_case. Can be used to change the segmentation + # after resampling. Useful for experimenting with sparse annotations: I can introduce sparsity after resampling + # and don't have to create a new dataset each time I modify my experiments + return seg + + +def example_test_case_preprocessing(): + # (paths to files may need adaptations) + plans_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/nnUNetPlans.json' + dataset_json_file = '/home/isensee/drives/gpu_data/nnUNet_preprocessed/Dataset219_AMOS2022_postChallenge_task2/dataset.json' + input_images = ['/home/isensee/drives/e132-rohdaten/nnUNetv2/Dataset219_AMOS2022_postChallenge_task2/imagesTr/amos_0600_0000.nii.gz', ] # if you only have one channel, you still need a list: ['case000_0000.nii.gz'] + + configuration = '3d_fullres' + pp = DefaultPreprocessor() + + # _ because this position would be the segmentation if seg_file was not None (training case) + # even if you have the segmentation, don't put the file there! You should always evaluate in the original + # resolution. What comes out of the preprocessor might have been resampled to some other image resolution (as + # specified by plans) + plans_manager = PlansManager(plans_file) + data, _, properties = pp.run_case(input_images, seg_file=None, plans_manager=plans_manager, + configuration_manager=plans_manager.get_configuration(configuration), + dataset_json=dataset_json_file) + + # voila. Now plug data into your prediction function of choice. We of course recommend nnU-Net's default (TODO) + return data + + +if __name__ == '__main__': + example_test_case_preprocessing() + # pp = DefaultPreprocessor() + # pp.run(2, '2d', 'nnUNetPlans', 8) + + ########################################################################################################### + # how to process a test cases? 
This is an example: + # example_test_case_preprocessing() diff --git a/nnUNet/nnunetv2/preprocessing/resampling/__init__.py b/nnUNet/nnunetv2/preprocessing/resampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py b/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py new file mode 100644 index 0000000..0079739 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py @@ -0,0 +1,222 @@ +from collections import OrderedDict +from typing import Union, Tuple, List + +import numpy as np +import pandas as pd +import torch +from batchgenerators.augmentations.utils import resize_segmentation +from scipy.ndimage.interpolation import map_coordinates +from skimage.transform import resize +from nnunetv2.configuration import ANISO_THRESHOLD + + +def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD): + do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold + print(f"{np.max(spacing)/np.min(spacing)} > {anisotropy_threshold}", flush = True) + return do_separate_z + + +def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]): + axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic + return axis + + +def compute_new_shape(old_shape: Union[Tuple[int, ...], List[int], np.ndarray], + old_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray: + assert len(old_spacing) == len(old_shape) + assert len(old_shape) == len(new_spacing) + new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)]) + return new_shape + + +def resample_data_or_seg_to_spacing(data: np.ndarray, + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + if force_separate_z is not None: + do_separate_z = force_separate_z + if force_separate_z: + axis = get_lowres_axis(current_spacing) + else: + axis = None + else: + if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(current_spacing) + elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(new_spacing) + else: + do_separate_z = False + axis = None + + if axis is not None: + if len(axis) == 3: + # every axis has the same spacing, this should never happen, why is this code here? + do_separate_z = False + elif len(axis) == 2: + # this happens for spacings like (0.24, 1.25, 1.25) for example. 
In that case we do not want to resample + # separately in the out of plane axis + do_separate_z = False + else: + pass + + if data is not None: + assert len(data.shape) == 4, "data must be c x y z" + + shape = np.array(data[0].shape) + new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing) + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], + new_shape: Union[Tuple[int, ...], List[int], np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + """ + needed for segmentation export. Stupid, I know. Maybe we can fix that with Leos new resampling functions + """ + if isinstance(data, torch.Tensor): + data = data.cpu().numpy() + if force_separate_z is not None: + do_separate_z = force_separate_z + if force_separate_z: + axis = get_lowres_axis(current_spacing) + else: + axis = None + else: + if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(current_spacing) + elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(new_spacing) + else: + do_separate_z = False + axis = None + + if axis is not None: + if len(axis) == 3: + # every axis has the same spacing, this should never happen, why is this code here? + do_separate_z = False + elif len(axis) == 2: + # this happens for spacings like (0.24, 1.25, 1.25) for example. 
In that case we do not want to resample + # separately in the out of plane axis + do_separate_z = False + else: + pass + + if data is not None: + assert len(data.shape) == 4, "data must be c x y z" + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, axis: Union[None, int] = None, order: int = 3, + do_separate_z: bool = False, order_z: int = 0): + """ + separate_z=True will resample with order 0 along z + :param data: + :param new_shape: + :param is_seg: + :param axis: + :param order: + :param do_separate_z: + :param order_z: only applies if do_separate_z is True + :return: + """ + assert len(data.shape) == 4, "data must be (c, x, y, z)" + assert len(new_shape) == len(data.shape) - 1 + + if is_seg: + resize_fn = resize_segmentation + kwargs = OrderedDict() + else: + resize_fn = resize + kwargs = {'mode': 'edge', 'anti_aliasing': False} + dtype_data = data.dtype + shape = np.array(data[0].shape) + new_shape = np.array(new_shape) + if np.any(shape != new_shape): + data = data.astype(float) + if do_separate_z: + print("separate z, order in z is", order_z, "order inplane is", order) + assert len(axis) == 1, "only one anisotropic axis supported" + axis = axis[0] + if axis == 0: + new_shape_2d = new_shape[1:] + elif axis == 1: + new_shape_2d = new_shape[[0, 2]] + else: + new_shape_2d = new_shape[:-1] + + reshaped_final_data = [] + for c in range(data.shape[0]): + reshaped_data = [] + for slice_id in range(shape[axis]): + if axis == 0: + reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)) + elif axis == 1: + reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)) + else: + reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)) + reshaped_data = np.stack(reshaped_data, axis) + if shape[axis] != new_shape[axis]: + + # The following few lines are blatantly copied and modified from sklearn's resize() + rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] + orig_rows, orig_cols, orig_dim = reshaped_data.shape + + row_scale = float(orig_rows) / rows + col_scale = float(orig_cols) / cols + dim_scale = float(orig_dim) / dim + + map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] + map_rows = row_scale * (map_rows + 0.5) - 0.5 + map_cols = col_scale * (map_cols + 0.5) - 0.5 + map_dims = dim_scale * (map_dims + 0.5) - 0.5 + map_rows = map_rows.astype(np.float32) + map_cols = map_cols.astype(np.float32) + map_dims = map_dims.astype(np.float32) + print(map_rows.dtype, flush=True) + print(map_cols.dtype, flush=True) + print(map_dims.dtype, flush=True) + coord_map = np.array([map_rows, map_cols, map_dims]) + if not is_seg or order_z == 0: + reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z, + mode='nearest')[None]) + else: + unique_labels = np.sort(pd.unique(reshaped_data.ravel())) # np.unique(reshaped_data) + reshaped = np.zeros(new_shape, dtype=dtype_data) + + for i, cl in enumerate(unique_labels): + reshaped_multihot = np.round( + map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z, + mode='nearest')) + reshaped[reshaped_multihot > 0.5] = cl + reshaped_final_data.append(reshaped[None]) + else: + reshaped_final_data.append(reshaped_data[None]) + reshaped_final_data = np.vstack(reshaped_final_data) + else: + # print("no separate z, 
order", order) + reshaped = [] + for c in range(data.shape[0]): + reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None]) + reshaped_final_data = np.vstack(reshaped) + return reshaped_final_data.astype(dtype_data) + else: + # print("no resampling necessary") + return data diff --git a/nnUNet/nnunetv2/preprocessing/resampling/utils.py b/nnUNet/nnunetv2/preprocessing/resampling/utils.py new file mode 100644 index 0000000..0bff719 --- /dev/null +++ b/nnUNet/nnunetv2/preprocessing/resampling/utils.py @@ -0,0 +1,15 @@ +from typing import Callable + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import join +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class + + +def recursive_find_resampling_fn_by_name(resampling_fn: str) -> Callable: + ret = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing", "resampling"), resampling_fn, + 'nnunetv2.preprocessing.resampling') + if ret is None: + raise RuntimeError("Unable to find resampling function named '%s'. Please make sure this fn is located in the " + "nnunetv2.preprocessing.resampling module." % resampling_fn) + else: + return ret diff --git a/nnUNet/nnunetv2/run/__init__.py b/nnUNet/nnunetv2/run/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/run/load_pretrained_weights.py b/nnUNet/nnunetv2/run/load_pretrained_weights.py new file mode 100644 index 0000000..b5c51bf --- /dev/null +++ b/nnUNet/nnunetv2/run/load_pretrained_weights.py @@ -0,0 +1,66 @@ +import torch +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel as DDP + + +def load_pretrained_weights(network, fname, verbose=False): + """ + Transfers all weights between matching keys in state_dicts. matching is done by name and we only transfer if the + shape is also the same. Segmentation layers (the 1x1(x1) layers that produce the segmentation maps) + identified by keys ending with '.seg_layers') are not transferred! + + If the pretrained weights were optained with a training outside nnU-Net and DDP or torch.optimize was used, + you need to change the keys of the pretrained state_dict. DDP adds a 'module.' prefix and torch.optim adds + '_orig_mod'. You DO NOT need to worry about this if pretraining was done with nnU-Net as + nnUNetTrainer.save_checkpoint takes care of that! + + """ + saved_model = torch.load(fname) + pretrained_dict = saved_model['network_weights'] + + skip_strings_in_pretrained = [ + '.seg_layers.', + ] + + if isinstance(network, DDP): + mod = network.module + else: + mod = network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + model_dict = mod.state_dict() + # verify that all but the segmentation layers have the same shape + for key, _ in model_dict.items(): + if all([i not in key for i in skip_strings_in_pretrained]): + assert key in pretrained_dict, \ + f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ + f"compatible with your network." + assert model_dict[key].shape == pretrained_dict[key].shape, \ + f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ + f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \ + f"does not seem to be compatible with your network." + + # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained + # encoders. 
Not supported by this function though (see assertions above) + + # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' + # pretrained_dict = {'module.' + k if is_ddp else k: v + # for k, v in pretrained_dict.items() + # if (('module.' + k if is_ddp else k) in model_dict) and + # all([i not in k for i in skip_strings_in_pretrained])} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} + + model_dict.update(pretrained_dict) + + print("################### Loading pretrained weights from file ", fname, '###################') + if verbose: + print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") + for key, value in pretrained_dict.items(): + print(key, 'shape', value.shape) + print("################### Done ###################") + mod.load_state_dict(model_dict) + + diff --git a/nnUNet/nnunetv2/run/run_finetuning_stunet.py b/nnUNet/nnunetv2/run/run_finetuning_stunet.py new file mode 100644 index 0000000..f295194 --- /dev/null +++ b/nnUNet/nnunetv2/run/run_finetuning_stunet.py @@ -0,0 +1,66 @@ +from unittest.mock import patch +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel as DDP +import torch + +def load_stunet_pretrained_weights(network, fname, verbose=False): + + saved_model = torch.load(fname) + print(saved_model.keys()) + if fname.endswith('pth'): + pretrained_dict = saved_model['network_weights'] + elif fname.endswith('model'): + pretrained_dict = saved_model['state_dict'] + + skip_strings_in_pretrained = [ + 'seg_outputs', + 'conv_blocks_context.0', + ] + + if isinstance(network, DDP): + mod = network.module + else: + mod = network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + model_dict = mod.state_dict() + # verify that all but the segmentation layers have the same shape + for key, _ in model_dict.items(): + if all([i not in key for i in skip_strings_in_pretrained]): + assert key in pretrained_dict, \ + f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ + f"compatible with your network." + assert model_dict[key].shape == pretrained_dict[key].shape, \ + f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ + f"{pretrained_dict[key].shape}; your network: {model_dict[key].shape}. The pretrained model " \ + f"does not seem to be compatible with your network." + + # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained + # encoders. Not supported by this function though (see assertions above) + + # commenting out this abomination of a dict comprehension for preservation in the archives of 'what not to do' + # pretrained_dict = {'module.' + k if is_ddp else k: v + # for k, v in pretrained_dict.items() + # if (('module.' 
+ k if is_ddp else k) in model_dict) and + # all([i not in k for i in skip_strings_in_pretrained])} + + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} + + model_dict.update(pretrained_dict) + + print("################### Loading pretrained weights from file ", fname, '###################') + if verbose: + print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:") + for key, value in pretrained_dict.items(): + print(key, 'shape', value.shape) + print("################### Done ###################") + mod.load_state_dict(model_dict) + + + +if __name__ == '__main__': + from run_training import run_training_entry + # with patch("run_training.load_pretrained_weights", load_stunet_pretrained_weights): + run_training_entry() diff --git a/nnUNet/nnunetv2/run/run_training.py b/nnUNet/nnunetv2/run/run_training.py new file mode 100644 index 0000000..7903590 --- /dev/null +++ b/nnUNet/nnunetv2/run/run_training.py @@ -0,0 +1,260 @@ +import os +import socket +from typing import Union, Optional + +import nnunetv2 +import torch.cuda +import torch.distributed as dist +import torch.multiprocessing as mp +from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json +from nnunetv2.paths import nnUNet_preprocessed +# from nnunetv2.run.load_pretrained_weights import load_pretrained_weights +from nnunetv2.run.run_finetuning_stunet import load_stunet_pretrained_weights as load_pretrained_weights +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from torch.backends import cudnn + + +def find_free_network_port() -> int: + """Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. + """ + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + + +def get_trainer_from_args(dataset_name_or_id: Union[int, str], + configuration: str, + fold: int, + trainer_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + use_compressed: bool = False, + device: torch.device = torch.device('cuda')): + # load nnunet class and do sanity checks + nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), + trainer_name, 'nnunetv2.training.nnUNetTrainer') + if nnunet_trainer is None: + raise RuntimeError(f'Could not find requested nnunet trainer {trainer_name} in ' + f'nnunetv2.training.nnUNetTrainer (' + f'{join(nnunetv2.__path__[0], "training", "nnUNetTrainer")}). If it is located somewhere ' + f'else, please move it there.') + assert issubclass(nnunet_trainer, nnUNetTrainer), 'The requested nnunet trainer class must inherit from ' \ + 'nnUNetTrainer' + + # handle dataset input. If it's an ID we need to convert to int from string + if dataset_name_or_id.startswith('Dataset'): + pass + else: + try: + dataset_name_or_id = int(dataset_name_or_id) + except ValueError: + raise ValueError(f'dataset_name_or_id must either be an integer or a valid dataset name with the pattern ' + f'DatasetXXX_YYY where XXX are the three(!) task ID digits. 
Your ' + f'input: {dataset_name_or_id}') + + # initialize nnunet trainer + preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id)) + plans_file = join(preprocessed_dataset_folder_base, plans_identifier + '.json') + plans = load_json(plans_file) + dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json')) + nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold, + dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device) + return nnunet_trainer + + +def maybe_load_checkpoint(nnunet_trainer: nnUNetTrainer, continue_training: bool, validation_only: bool, + pretrained_weights_file: str = None): + if continue_training and pretrained_weights_file is not None: + raise RuntimeError('Cannot both continue a training AND load pretrained weights. Pretrained weights can only ' + 'be used at the beginning of the training.') + if continue_training: + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') + if not isfile(expected_checkpoint_file): + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_latest.pth') + # special case where --c is used to run a previously aborted validation + if not isfile(expected_checkpoint_file): + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_best.pth') + if not isfile(expected_checkpoint_file): + print(f"WARNING: Cannot continue training because there seems to be no checkpoint available to " + f"continue from. Starting a new training...") + expected_checkpoint_file = None + elif validation_only: + expected_checkpoint_file = join(nnunet_trainer.output_folder, 'checkpoint_final.pth') + if not isfile(expected_checkpoint_file): + raise RuntimeError(f"Cannot run validation because the training is not finished yet!") + else: + if pretrained_weights_file is not None: + if not nnunet_trainer.was_initialized: + nnunet_trainer.initialize() + load_pretrained_weights(nnunet_trainer.network, pretrained_weights_file, verbose=True) + expected_checkpoint_file = None + + if expected_checkpoint_file is not None: + nnunet_trainer.load_checkpoint(expected_checkpoint_file) + + +def setup_ddp(rank, world_size): + # initialize the process group + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + +def cleanup_ddp(): + dist.destroy_process_group() + + +def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val, pretrained_weights, npz, world_size): + setup_ddp(rank, world_size) + torch.cuda.set_device(torch.device('cuda', dist.get_rank())) + + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, + use_compressed) + + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (c and val), f'Cannot set --c and --val flag at the same time. Dummy.' 
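For reference, the dataset identifier handling in `get_trainer_from_args` above accepts either a plain integer ID or a full name following the `DatasetXXX_YYY` convention. A minimal sketch of that rule, using a hypothetical helper (not part of the nnU-Net API):

```python
def normalize_dataset_id(dataset_name_or_id: str):
    # names that already follow the DatasetXXX_YYY pattern are passed through unchanged
    if dataset_name_or_id.startswith('Dataset'):
        return dataset_name_or_id
    # everything else must be a plain numeric ID, e.g. '999' -> 999; otherwise int() raises ValueError
    return int(dataset_name_or_id)

print(normalize_dataset_id('Dataset999_IntegrationTest_Hippocampus'))  # passed through as-is
print(normalize_dataset_id('999'))                                     # 999
```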
+ + maybe_load_checkpoint(nnunet_trainer, c, val, pretrained_weights) + + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if not val: + nnunet_trainer.run_training() + + nnunet_trainer.perform_actual_validation(npz) + cleanup_ddp() + + +def run_training(dataset_name_or_id: Union[str, int], + configuration: str, fold: Union[int, str], + trainer_class_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', + pretrained_weights: Optional[str] = None, + num_gpus: int = 1, + use_compressed_data: bool = False, + export_validation_probabilities: bool = False, + continue_training: bool = False, + only_run_validation: bool = False, + disable_checkpointing: bool = False, + device: torch.device = torch.device('cuda')): + if isinstance(fold, str): + if fold != 'all': + try: + fold = int(fold) + except ValueError as e: + print(f'Unable to convert given value for fold to int: {fold}. fold must bei either "all" or an integer!') + raise e + + if num_gpus > 1: + assert device.type == 'cuda', f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}" + + os.environ['MASTER_ADDR'] = 'localhost' + if 'MASTER_PORT' not in os.environ.keys(): + port = str(find_free_network_port()) + print(f"using port {port}") + os.environ['MASTER_PORT'] = port # str(port) + + mp.spawn(run_ddp, + args=( + dataset_name_or_id, + configuration, + fold, + trainer_class_name, + plans_identifier, + use_compressed_data, + disable_checkpointing, + continue_training, + only_run_validation, + pretrained_weights, + export_validation_probabilities, + num_gpus), + nprocs=num_gpus, + join=True) + else: + nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, + plans_identifier, use_compressed_data, device=device) + + if disable_checkpointing: + nnunet_trainer.disable_checkpointing = disable_checkpointing + + assert not (continue_training and only_run_validation), f'Cannot set --c and --val flag at the same time. Dummy.' + + maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) + + if torch.cuda.is_available(): + cudnn.deterministic = False + cudnn.benchmark = True + + if not only_run_validation: + nnunet_trainer.run_training() + + nnunet_trainer.perform_actual_validation(export_validation_probabilities) + + +def run_training_entry(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('dataset_name_or_id', type=str, + help="Dataset name or ID to train with") + parser.add_argument('configuration', type=str, + help="Configuration that should be trained") + parser.add_argument('fold', type=str, + help='Fold of the 5-fold cross-validation. Should be an int between 0 and 4.') + parser.add_argument('-tr', type=str, required=False, default='nnUNetTrainer', + help='[OPTIONAL] Use this flag to specify a custom trainer. Default: nnUNetTrainer') + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='[OPTIONAL] Use this flag to specify a custom plans identifier. Default: nnUNetPlans') + parser.add_argument('-pretrained_weights', type=str, required=False, default=None, + help='[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only ' + 'be used when actually training. Beta. 
Use with caution.') + parser.add_argument('-num_gpus', type=int, default=1, required=False, + help='Specify the number of GPUs to use for training') + parser.add_argument("--use_compressed", default=False, action="store_true", required=False, + help="[OPTIONAL] If you set this flag the training cases will not be decompressed. Reading compressed " + "data is much more CPU and (potentially) RAM intensive and should only be used if you " + "know what you are doing") + parser.add_argument('--npz', action='store_true', required=False, + help='[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted ' + 'segmentations). Needed for finding the best ensemble.') + parser.add_argument('--c', action='store_true', required=False, + help='[OPTIONAL] Continue training from latest checkpoint') + parser.add_argument('--val', action='store_true', required=False, + help='[OPTIONAL] Set this flag to only run the validation. Requires training to have finished.') + parser.add_argument('--disable_checkpointing', action='store_true', required=False, + help='[OPTIONAL] Set this flag to disable checkpointing. Ideal for testing things out and ' + 'you dont want to flood your hard drive with checkpoints.') + parser.add_argument('-device', type=str, default='cuda', required=False, + help="Use this to set the device the training should run with. Available options are 'cuda' " + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!") + args = parser.parse_args() + + assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' + if args.device == 'cpu': + # let's allow torch to use hella threads + import multiprocessing + torch.set_num_threads(multiprocessing.cpu_count()) + device = torch.device('cpu') + elif args.device == 'cuda': + # multithreading in torch doesn't help nnU-Net if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device('cuda') + else: + device = torch.device('mps') + + run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights, + args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, + device=device) + + +if __name__ == '__main__': + run_training_entry() diff --git a/nnUNet/nnunetv2/tests/__init__.py b/nnUNet/nnunetv2/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/tests/integration_tests/__init__.py b/nnUNet/nnunetv2/tests/integration_tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py b/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py new file mode 100644 index 0000000..a1b4df1 --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/add_lowres_and_cascade.py @@ -0,0 +1,33 @@ +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.paths import nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('-d', nargs='+', type=int, help='List of dataset ids') + args = parser.parse_args() + + for d in args.d: + dataset_name = maybe_convert_to_dataset_name(d) + plans = load_json(join(nnUNet_preprocessed, dataset_name, 
'nnUNetPlans.json')) + plans['configurations']['3d_lowres'] = { + "data_identifier": "nnUNetPlans_3d_lowres", # do not be a dumbo and forget this. I was a dumbo. And I paid dearly with ~10 min debugging time + 'inherits_from': '3d_fullres', + "patch_size": [20, 28, 20], + "median_image_size_in_voxels": [18.0, 25.0, 18.0], + "spacing": [2.0, 2.0, 2.0], + "n_conv_per_stage_encoder": [2, 2, 2], + "n_conv_per_stage_decoder": [2, 2], + "num_pool_per_axis": [2, 2, 2], + "pool_op_kernel_sizes": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], + "conv_kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], + "next_stage": "3d_cascade_fullres" + } + plans['configurations']['3d_cascade_fullres'] = { + 'inherits_from': '3d_fullres', + "previous_stage": "3d_lowres" + } + save_json(plans, join(nnUNet_preprocessed, dataset_name, 'nnUNetPlans.json'), sort_keys=False) \ No newline at end of file diff --git a/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py b/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py new file mode 100644 index 0000000..c9fca95 --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/cleanup_integration_test.py @@ -0,0 +1,19 @@ +import shutil + +from batchgenerators.utilities.file_and_folder_operations import isdir, join + +from nnunetv2.paths import nnUNet_raw, nnUNet_results, nnUNet_preprocessed + +if __name__ == '__main__': + # deletes everything! + dataset_names = [ + 'Dataset996_IntegrationTest_Hippocampus_regions_ignore', + 'Dataset997_IntegrationTest_Hippocampus_regions', + 'Dataset998_IntegrationTest_Hippocampus_ignore', + 'Dataset999_IntegrationTest_Hippocampus', + ] + for fld in [nnUNet_raw, nnUNet_preprocessed, nnUNet_results]: + for d in dataset_names: + if isdir(join(fld, d)): + shutil.rmtree(join(fld, d)) + diff --git a/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh b/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh new file mode 100644 index 0000000..3888c1a --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/lsf_commands.sh @@ -0,0 +1,10 @@ +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 996" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 997" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 998" +bsub -q gpu.legacy -gpu num=1:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test.sh 999" + + +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 996" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . 
nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 997" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 998" +bsub -q gpu.legacy -gpu num=2:j_exclusive=yes:gmem=1G -L /bin/bash ". /home/isensee/load_env_cluster4.sh && cd /home/isensee/git_repos/nnunet_remake && export nnUNet_keep_files_open=True && . nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh 999" diff --git a/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh b/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh new file mode 100644 index 0000000..b5dda42 --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/prepare_integration_tests.sh @@ -0,0 +1,18 @@ +# assumes you are in the nnunet repo! + +# prepare raw datasets +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset999_IntegrationTest_Hippocampus.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset998_IntegrationTest_Hippocampus_ignore.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset997_IntegrationTest_Hippocampus_regions.py +python nnunetv2/dataset_conversion/datasets_for_integration_tests/Dataset996_IntegrationTest_Hippocampus_regions_ignore.py + +# now run experiment planning without preprocessing +nnUNetv2_plan_and_preprocess -d 996 997 998 999 --no_pp + +# now add 3d lowres and cascade +python nnunetv2/tests/integration_tests/add_lowres_and_cascade.py -d 996 997 998 999 + +# now preprocess everything +nnUNetv2_preprocess -d 996 997 998 999 -c 2d 3d_lowres 3d_fullres -np 8 8 8 # no need to preprocess cascade as its the same data as 3d_fullres + +# done \ No newline at end of file diff --git a/nnUNet/nnunetv2/tests/integration_tests/readme.md b/nnUNet/nnunetv2/tests/integration_tests/readme.md new file mode 100644 index 0000000..2a44f13 --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/readme.md @@ -0,0 +1,58 @@ +# Preface + +I am just a mortal with many tasks and limited time. Aint nobody got time for unittests. + +HOWEVER, at least some integration tests should be performed testing nnU-Net from start to finish. + +# Introduction - What the heck is happening? +This test covers all possible labeling scenarios (standard labels, regions, ignore labels and regions with +ignore labels). It runs the entire nnU-Net pipeline from start to finish: + +- fingerprint extraction +- experiment planning +- preprocessing +- train all 4 configurations (2d, 3d_lowres, 3d_fullres, 3d_cascade_fullres) as 5-fold CV +- automatically find the best model or ensemble +- determine the postprocessing used for this +- predict some test set +- apply postprocessing to the test set + +To speed things up, we do the following: +- pick Dataset004_Hippocampus because it is quadratisch praktisch gut. MNIST of medical image segmentation +- by default this dataset does not have 3d_lowres or cascade. We just manually add them (cool new feature, eh?). See `add_lowres_and_cascade.py` to learn more! +- we use nnUNetTrainer_5epochs for a short training + +# How to run it? + +Set your pwd to be the nnunet repo folder (the one where the `nnunetv2` folder and the `setup.py` are located!) + +Now generate the 4 dummy datasets (ids 996, 997, 998, 999) from dataset 4. This will crash if you don't have Dataset004! 
+```commandline +bash nnunetv2/tests/integration_tests/prepare_integration_tests.sh +``` + +Now you can run the integration test for each of the datasets: +```commandline +bash nnunetv2/tests/integration_tests/run_integration_test.sh DATSET_ID +``` +use DATSET_ID 996, 997, 998 and 999. You can run these independently on different GPUs/systems to speed things up. +This will take i dunno like 10-30 Minutes!? + +Also run +```commandline +bash nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh DATSET_ID +``` +to verify DDP is working (needs 2 GPUs!) + +# How to check if the test was successful? +If I was not as lazy as I am I would have programmed some automatism that checks if Dice scores etc are in an acceptable range. +So you need to do the following: +1) check that none of your runs crashed (duh) +2) for each run, navigate to `nnUNet_results/DATASET_NAME` and take a look at the `inference_information.json` file. +Does it make sense? If so: NICE! + +Once the integration test is completed you can delete all the temporary files associated with it by running: + +```commandline +python nnunetv2/tests/integration_tests/cleanup_integration_test.py +``` \ No newline at end of file diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh new file mode 100644 index 0000000..ff0426c --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test.sh @@ -0,0 +1,27 @@ + + +nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_fullres 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 2d 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 2d 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 3d_lowres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_lowres 4 -tr nnUNetTrainer_5epochs --npz + +nnUNetv2_train $1 3d_cascade_fullres 0 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 1 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 2 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 3 -tr nnUNetTrainer_5epochs --npz +nnUNetv2_train $1 3d_cascade_fullres 4 -tr nnUNetTrainer_5epochs --npz + +python nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py -d $1 \ No newline at end of file diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py new file mode 100644 index 0000000..89e783e --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_bestconfig_inference.py @@ -0,0 +1,75 @@ +import argparse + +import torch +from batchgenerators.utilities.file_and_folder_operations import join, load_pickle + +from nnunetv2.ensembling.ensemble import ensemble_folders +from nnunetv2.evaluation.find_best_configuration import find_best_configuration, \ + 
dumb_trainer_config_plans_to_trained_models_dict +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.paths import nnUNet_raw, nnUNet_results +from nnunetv2.postprocessing.remove_connected_components import apply_postprocessing_to_folder +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.file_path_utilities import get_output_folder + + +if __name__ == '__main__': + """ + Predicts the imagesTs folder with the best configuration and applies postprocessing + """ + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + + parser = argparse.ArgumentParser() + parser.add_argument('-d', type=int, help='dataset id') + args = parser.parse_args() + d = args.d + + dataset_name = maybe_convert_to_dataset_name(d) + source_dir = join(nnUNet_raw, dataset_name, 'imagesTs') + target_dir_base = join(nnUNet_results, dataset_name) + + models = dumb_trainer_config_plans_to_trained_models_dict(['nnUNetTrainer_5epochs'], + ['2d', + '3d_lowres', + '3d_cascade_fullres', + '3d_fullres'], + ['nnUNetPlans']) + ret = find_best_configuration(d, models, allow_ensembling=True, num_processes=8, overwrite=True, + folds=(0, 1, 2, 3, 4), strict=True) + + has_ensemble = len(ret['best_model_or_ensemble']['selected_model_or_models']) > 1 + + # we don't use all folds to speed stuff up + used_folds = (0, 3) + output_folders = [] + for im in ret['best_model_or_ensemble']['selected_model_or_models']: + output_dir = join(target_dir_base, f"pred_{im['configuration']}") + model_folder = get_output_folder(d, im['trainer'], im['plans_identifier'], im['configuration']) + # note that if the best model is the enseble of 3d_lowres and 3d cascade then 3d_lowres will be predicted + # twice (once standalone and once to generate the predictions for the cascade) because we don't reuse the + # prediction here. Proper way would be to check for that and + # then give the output of 3d_lowres inference to the folder_with_segs_from_prev_stage kwarg in + # predict_from_raw_data. Since we allow for + # dynamically setting 'previous_stage' in the plans I am too lazy to implement this here. This is just an + # integration test after all. 
Take a closer look at how this in handled in predict_from_raw_data + predictor = nnUNetPredictor(verbose=False, allow_tqdm=False) + predictor.initialize_from_trained_model_folder(model_folder, used_folds) + predictor.predict_from_files(source_dir, output_dir, has_ensemble, overwrite=True) + # predict_from_raw_data(list_of_lists_or_source_folder=source_dir, output_folder=output_dir, + # model_training_output_dir=model_folder, use_folds=used_folds, + # save_probabilities=has_ensemble, verbose=False, overwrite=True) + output_folders.append(output_dir) + + # if we have an ensemble, we need to ensemble the results + if has_ensemble: + ensemble_folders(output_folders, join(target_dir_base, 'ensemble_predictions'), save_merged_probabilities=False) + folder_for_pp = join(target_dir_base, 'ensemble_predictions') + else: + folder_for_pp = output_folders[0] + + # apply postprocessing + pp_fns, pp_fn_kwargs = load_pickle(ret['best_model_or_ensemble']['postprocessing_file']) + apply_postprocessing_to_folder(folder_for_pp, join(target_dir_base, 'ensemble_predictions_postprocessed'), + pp_fns, + pp_fn_kwargs, plans_file_or_dict=ret['best_model_or_ensemble']['some_plans_file']) diff --git a/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh new file mode 100644 index 0000000..5199247 --- /dev/null +++ b/nnUNet/nnunetv2/tests/integration_tests/run_integration_test_trainingOnly_DDP.sh @@ -0,0 +1 @@ +nnUNetv2_train $1 3d_fullres 0 -tr nnUNetTrainer_10epochs -num_gpus 2 diff --git a/nnUNet/nnunetv2/training/__init__.py b/nnUNet/nnunetv2/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/data_augmentation/__init__.py b/nnUNet/nnunetv2/training/data_augmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py b/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py new file mode 100644 index 0000000..a772bc2 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/compute_initial_patch_size.py @@ -0,0 +1,24 @@ +import numpy as np + + +def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range): + if isinstance(rot_x, (tuple, list)): + rot_x = max(np.abs(rot_x)) + if isinstance(rot_y, (tuple, list)): + rot_y = max(np.abs(rot_y)) + if isinstance(rot_z, (tuple, list)): + rot_z = max(np.abs(rot_z)) + rot_x = min(90 / 360 * 2. * np.pi, rot_x) + rot_y = min(90 / 360 * 2. * np.pi, rot_y) + rot_z = min(90 / 360 * 2. 
* np.pi, rot_z) + from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d + coords = np.array(final_patch_size) + final_shape = np.copy(coords) + if len(coords) == 3: + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0) + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0) + final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0) + elif len(coords) == 2: + final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0) + final_shape /= min(scale_range) + return final_shape.astype(int) diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/__init__.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py new file mode 100644 index 0000000..378bab2 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/cascade_transforms.py @@ -0,0 +1,136 @@ +from typing import Union, List, Tuple, Callable + +import numpy as np +from acvl_utils.morphology.morphology_helper import label_with_component_sizes +from batchgenerators.transforms.abstract_transforms import AbstractTransform +from skimage.morphology import ball +from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening + + +class MoveSegAsOneHotToData(AbstractTransform): + def __init__(self, index_in_origin: int, all_labels: Union[Tuple[int, ...], List[int]], + key_origin="seg", key_target="data", remove_from_origin=True): + """ + Takes data_dict[seg][:, index_in_origin], converts it to one hot encoding and appends it to + data_dict[key_target]. Optionally removes index_in_origin from data_dict[seg]. + """ + self.remove_from_origin = remove_from_origin + self.all_labels = all_labels + self.key_target = key_target + self.key_origin = key_origin + self.index_in_origin = index_in_origin + + def __call__(self, **data_dict): + seg = data_dict[self.key_origin][:, self.index_in_origin:self.index_in_origin+1] + + seg_onehot = np.zeros((seg.shape[0], len(self.all_labels), *seg.shape[2:]), + dtype=data_dict[self.key_target].dtype) + for i, l in enumerate(self.all_labels): + seg_onehot[:, i][seg[:, 0] == l] = 1 + + data_dict[self.key_target] = np.concatenate((data_dict[self.key_target], seg_onehot), 1) + + if self.remove_from_origin: + remaining_channels = [i for i in range(data_dict[self.key_origin].shape[1]) if i != self.index_in_origin] + data_dict[self.key_origin] = data_dict[self.key_origin][:, remaining_channels] + + return data_dict + + +class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform): + def __init__(self, channel_idx: Union[int, List[int]], key: str = "data", p_per_sample: float = 0.2, + fill_with_other_class_p: float = 0.25, + dont_do_if_covers_more_than_x_percent: float = 0.25, p_per_label: float = 1): + """ + Randomly removes connected components in the specified channel_idx of data_dict[key]. Only considers components + smaller than dont_do_if_covers_more_than_X_percent of the sample. 
Also has the option of simulating + misclassification as another class (fill_with_other_class_p) + """ + self.p_per_label = p_per_label + self.dont_do_if_covers_more_than_x_percent = dont_do_if_covers_more_than_x_percent + self.fill_with_other_class_p = fill_with_other_class_p + self.p_per_sample = p_per_sample + self.key = key + if not isinstance(channel_idx, (list, tuple)): + channel_idx = [channel_idx] + self.channel_idx = channel_idx + + def __call__(self, **data_dict): + data = data_dict.get(self.key) + for b in range(data.shape[0]): + if np.random.uniform() < self.p_per_sample: + for c in self.channel_idx: + if np.random.uniform() < self.p_per_label: + # print(np.unique(data[b, c])) ## should be [0, 1] + workon = data[b, c].astype(bool) + if not np.any(workon): + continue + num_voxels = np.prod(workon.shape, dtype=np.uint64) + lab, component_sizes = label_with_component_sizes(workon.astype(bool)) + if len(component_sizes) > 0: + valid_component_ids = [i for i, j in component_sizes.items() if j < + num_voxels*self.dont_do_if_covers_more_than_x_percent] + # print('RemoveRandomConnectedComponentFromOneHotEncodingTransform', c, + # np.unique(data[b, c]), len(component_sizes), valid_component_ids, + # len(valid_component_ids)) + if len(valid_component_ids) > 0: + random_component = np.random.choice(valid_component_ids) + data[b, c][lab == random_component] = 0 + if np.random.uniform() < self.fill_with_other_class_p: + other_ch = [i for i in self.channel_idx if i != c] + if len(other_ch) > 0: + other_class = np.random.choice(other_ch) + data[b, other_class][lab == random_component] = 1 + data_dict[self.key] = data + return data_dict + + +class ApplyRandomBinaryOperatorTransform(AbstractTransform): + def __init__(self, + channel_idx: Union[int, List[int], Tuple[int, ...]], + p_per_sample: float = 0.3, + any_of_these: Tuple[Callable] = (binary_dilation, binary_erosion, binary_closing, binary_opening), + key: str = "data", + strel_size: Tuple[int, int] = (1, 10), + p_per_label: float = 1): + """ + Applies random binary operations (specified by any_of_these) with random ball size (radius is uniformly sampled + from interval strel_size) to specified channels. 
Expects the channel_idx to correspond to a hone hot encoded + segmentation (see for example MoveSegAsOneHotToData) + """ + self.p_per_label = p_per_label + self.strel_size = strel_size + self.key = key + self.any_of_these = any_of_these + self.p_per_sample = p_per_sample + + if not isinstance(channel_idx, (list, tuple)): + channel_idx = [channel_idx] + self.channel_idx = channel_idx + + def __call__(self, **data_dict): + for b in range(data_dict[self.key].shape[0]): + if np.random.uniform() < self.p_per_sample: + # this needs to be applied in random order to the channels + np.random.shuffle(self.channel_idx) + for c in self.channel_idx: + if np.random.uniform() < self.p_per_label: + operation = np.random.choice(self.any_of_these) + selem = ball(np.random.uniform(*self.strel_size)) + workon = data_dict[self.key][b, c].astype(bool) + if not np.any(workon): + continue + # print(np.unique(workon)) + res = operation(workon, selem).astype(data_dict[self.key].dtype) + # print('ApplyRandomBinaryOperatorTransform', c, operation, np.sum(workon), np.sum(res)) + data_dict[self.key][b, c] = res + + # if class was added, we need to remove it in ALL other channels to keep one hot encoding + # properties + other_ch = [i for i in self.channel_idx if i != c] + if len(other_ch) > 0: + was_added_mask = (res - workon) > 0 + for oc in other_ch: + data_dict[self.key][b, oc][was_added_mask] = 0 + # if class was removed, leave it at background + return data_dict diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py new file mode 100644 index 0000000..6469ee2 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py @@ -0,0 +1,55 @@ +from typing import Tuple, Union, List + +from batchgenerators.augmentations.utils import resize_segmentation +from batchgenerators.transforms.abstract_transforms import AbstractTransform +import numpy as np + + +class DownsampleSegForDSTransform2(AbstractTransform): + ''' + data_dict['output_key'] will be a list of segmentations scaled according to ds_scales + ''' + def __init__(self, ds_scales: Union[List, Tuple], + order: int = 0, input_key: str = "seg", + output_key: str = "seg", axes: Tuple[int] = None): + """ + Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specified one deep supervision + output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape. + ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling + for each axis independently + """ + self.axes = axes + self.output_key = output_key + self.input_key = input_key + self.order = order + self.ds_scales = ds_scales + + def __call__(self, **data_dict): + if self.axes is None: + axes = list(range(2, len(data_dict[self.input_key].shape))) + else: + axes = self.axes + + output = [] + for s in self.ds_scales: + if not isinstance(s, (tuple, list)): + s = [s] * len(axes) + else: + assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \ + f'for each axis) then the number of entried in that tuple (here ' \ + f'{len(s)}) must be the same as the number of axes (here {len(axes)}).' 
+ + if all([i == 1 for i in s]): + output.append(data_dict[self.input_key]) + else: + new_shape = np.array(data_dict[self.input_key].shape).astype(float) + for i, a in enumerate(axes): + new_shape[a] *= s[i] + new_shape = np.round(new_shape).astype(int) + out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype) + for b in range(data_dict[self.input_key].shape[0]): + for c in range(data_dict[self.input_key].shape[1]): + out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order) + output.append(out_seg) + data_dict[self.output_key] = output + return data_dict diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py new file mode 100644 index 0000000..dd8368c --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/limited_length_multithreaded_augmenter.py @@ -0,0 +1,10 @@ +from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter + + +class LimitedLenWrapper(NonDetMultiThreadedAugmenter): + def __init__(self, my_imaginary_length, *args, **kwargs): + super().__init__(*args, **kwargs) + self.len = my_imaginary_length + + def __len__(self): + return self.len diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py new file mode 100644 index 0000000..587acd7 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/manipulating_data_dict.py @@ -0,0 +1,10 @@ +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class RemoveKeyTransform(AbstractTransform): + def __init__(self, key_to_remove: str): + self.key_to_remove = key_to_remove + + def __call__(self, **data_dict): + _ = data_dict.pop(self.key_to_remove, None) + return data_dict diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py new file mode 100644 index 0000000..b009993 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/masking.py @@ -0,0 +1,22 @@ +from typing import List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class MaskTransform(AbstractTransform): + def __init__(self, apply_to_channels: List[int], mask_idx_in_seg: int = 0, set_outside_to: int = 0, + data_key: str = "data", seg_key: str = "seg"): + """ + Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!! 
+ """ + self.apply_to_channels = apply_to_channels + self.seg_key = seg_key + self.data_key = data_key + self.set_outside_to = set_outside_to + self.mask_idx_in_seg = mask_idx_in_seg + + def __call__(self, **data_dict): + mask = data_dict[self.seg_key][:, self.mask_idx_in_seg] < 0 + for c in self.apply_to_channels: + data_dict[self.data_key][:, c][mask] = self.set_outside_to + return data_dict diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py new file mode 100644 index 0000000..52d2fc0 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -0,0 +1,38 @@ +from typing import List, Tuple, Union + +from batchgenerators.transforms.abstract_transforms import AbstractTransform +import numpy as np + + +class ConvertSegmentationToRegionsTransform(AbstractTransform): + def __init__(self, regions: Union[List, Tuple], + seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0): + """ + regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, + example: + regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2 + :param regions: + :param seg_key: + :param output_key: + """ + self.seg_channel = seg_channel + self.output_key = output_key + self.seg_key = seg_key + self.regions = regions + + def __call__(self, **data_dict): + seg = data_dict.get(self.seg_key) + num_regions = len(self.regions) + if seg is not None: + seg_shp = seg.shape + output_shape = list(seg_shp) + output_shape[1] = num_regions + region_output = np.zeros(output_shape, dtype=seg.dtype) + for b in range(seg_shp[0]): + for region_id, region_source_labels in enumerate(self.regions): + if not isinstance(region_source_labels, (list, tuple)): + region_source_labels = (region_source_labels, ) + for label_value in region_source_labels: + region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1 + data_dict[self.output_key] = region_output + return data_dict diff --git a/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py new file mode 100644 index 0000000..340fce7 --- /dev/null +++ b/nnUNet/nnunetv2/training/data_augmentation/custom_transforms/transforms_for_dummy_2d.py @@ -0,0 +1,45 @@ +from typing import Tuple, Union, List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + + +class Convert3DTo2DTransform(AbstractTransform): + def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): + """ + Transforms a 5D array (b, c, x, y, z) to a 4D array (b, c * x, y, z) by overloading the color channel + """ + self.apply_to_keys = apply_to_keys + + def __call__(self, **data_dict): + for k in self.apply_to_keys: + shp = data_dict[k].shape + assert len(shp) == 5, 'This transform only works on 3D data, so expects 5D tensor (b, c, x, y, z) as input.' + data_dict[k] = data_dict[k].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4])) + shape_key = f'orig_shape_{k}' + assert shape_key not in data_dict.keys(), f'Convert3DTo2DTransform needs to store the original shape. ' \ + f'It does that using the {shape_key} key. That key is ' \ + f'already taken. Bummer.' 
+ data_dict[shape_key] = shp + return data_dict + + +class Convert2DTo3DTransform(AbstractTransform): + def __init__(self, apply_to_keys: Union[List[str], Tuple[str]] = ('data', 'seg')): + """ + Reverts Convert3DTo2DTransform by transforming a 4D array (b, c * x, y, z) back to 5D (b, c, x, y, z) + """ + self.apply_to_keys = apply_to_keys + + def __call__(self, **data_dict): + for k in self.apply_to_keys: + shape_key = f'orig_shape_{k}' + assert shape_key in data_dict.keys(), f'Did not find key {shape_key} in data_dict. Shitty. ' \ + f'Convert2DTo3DTransform only works in tandem with ' \ + f'Convert3DTo2DTransform and you probably forgot to add ' \ + f'Convert3DTo2DTransform to your pipeline. (Convert3DTo2DTransform ' \ + f'is where the missing key is generated)' + original_shape = data_dict[shape_key] + current_shape = data_dict[k].shape + data_dict[k] = data_dict[k].reshape((original_shape[0], original_shape[1], original_shape[2], + current_shape[-2], current_shape[-1])) + return data_dict diff --git a/nnUNet/nnunetv2/training/dataloading/__init__.py b/nnUNet/nnunetv2/training/dataloading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/dataloading/base_data_loader.py b/nnUNet/nnunetv2/training/dataloading/base_data_loader.py new file mode 100644 index 0000000..6a6a49f --- /dev/null +++ b/nnUNet/nnunetv2/training/dataloading/base_data_loader.py @@ -0,0 +1,139 @@ +from typing import Union, Tuple + +from batchgenerators.dataloading.data_loader import DataLoader +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.utilities.label_handling.label_handling import LabelManager + + +class nnUNetDataLoaderBase(DataLoader): + def __init__(self, + data: nnUNetDataset, + batch_size: int, + patch_size: Union[List[int], Tuple[int, ...], np.ndarray], + final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray], + label_manager: LabelManager, + oversample_foreground_percent: float = 0.0, + sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None, + pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None, + probabilistic_oversampling: bool = False): + super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities) + assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports dictionaries as data' + self.indices = list(data.keys()) + + self.oversample_foreground_percent = oversample_foreground_percent + self.final_patch_size = final_patch_size + self.patch_size = patch_size + self.list_of_keys = list(self._data.keys()) + # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size + # (which is what the network will get) these patches will also cover the border of the images + self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int) + if pad_sides is not None: + if not isinstance(pad_sides, np.ndarray): + pad_sides = np.array(pad_sides) + self.need_to_pad += pad_sides + self.num_channels = None + self.pad_sides = pad_sides + self.data_shape, self.seg_shape = self.determine_shapes() + self.sampling_probabilities = sampling_probabilities + self.annotated_classes_key = tuple(label_manager.all_labels) + self.has_ignore = label_manager.has_ignore_label + self.get_do_oversample = self._oversample_last_XX_percent if not probabilistic_oversampling \ + else self._probabilistic_oversampling + + def 
_oversample_last_XX_percent(self, sample_idx: int) -> bool: + """ + determines whether sample sample_idx in a minibatch needs to be guaranteed foreground + """ + return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent)) + + def _probabilistic_oversampling(self, sample_idx: int) -> bool: + # print('YEAH BOIIIIII') + return np.random.uniform() < self.oversample_foreground_percent + + def determine_shapes(self): + # load one case + data, seg, properties = self._data.load_case(self.indices[0]) + num_color_channels = data.shape[0] + + data_shape = (self.batch_size, num_color_channels, *self.patch_size) + seg_shape = (self.batch_size, seg.shape[0], *self.patch_size) + return data_shape, seg_shape + + def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None], + overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False): + # in dataloader 2d we need to select the slice prior to this and also modify the class_locations to only have + # locations for the given slice + need_to_pad = self.need_to_pad.copy() + dim = len(data_shape) + + for d in range(dim): + # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides + # always + if need_to_pad[d] + data_shape[d] < self.patch_size[d]: + need_to_pad[d] = self.patch_size[d] - data_shape[d] + + # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we + # define what the upper and lower bound can be to then sample form them with np.random.randint + lbs = [- need_to_pad[i] // 2 for i in range(dim)] + ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)] + + # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get + # at least one of the foreground classes in the patch + if not force_fg and not self.has_ignore: + bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)] + # print('I want a random location') + else: + if not force_fg and self.has_ignore: + selected_class = self.annotated_classes_key + if len(class_locations[selected_class]) == 0: + # no annotated pixels in this case. Not good. But we can hardly skip it here + print('Warning! No annotated pixels in image!') + selected_class = None + # print(f'I have ignore labels and want to pick a labeled area. annotated_classes_key: {self.annotated_classes_key}') + elif force_fg: + assert class_locations is not None, 'if force_fg is set class_locations cannot be None' + if overwrite_class is not None: + assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \ + 'have class_locations (missing key)' + # this saves us a np.unique. Preprocessing already did that for all cases. Neat. + # class_locations keys can also be tuple + eligible_classes_or_regions = [i for i in class_locations.keys() if len(class_locations[i]) > 0] + + # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list + # strange formulation needed to circumvent + # ValueError: The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all() + tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] + if any(tmp): + if len(eligible_classes_or_regions) > 1: + eligible_classes_or_regions.pop(np.where(tmp)[0][0]) + + if len(eligible_classes_or_regions) == 0: + # this only happens if some image does not contain foreground voxels at all + selected_class = None + if verbose: + print('case does not contain any foreground classes') + else: + # I hate myself. Future me aint gonna be happy to read this + # 2022_11_25: had to read it today. Wasn't too bad + selected_class = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ + (overwrite_class is None or (overwrite_class not in eligible_classes_or_regions)) else overwrite_class + # print(f'I want to have foreground, selected class: {selected_class}') + else: + raise RuntimeError('lol what!?') + voxels_of_that_class = class_locations[selected_class] if selected_class is not None else None + + if voxels_of_that_class is not None and len(voxels_of_that_class) > 0: + selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))] + # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel. + # Make sure it is within the bounds of lb and ub + # i + 1 because we have first dimension 0! + bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)] + else: + # If the image does not contain any foreground classes, we fall back to random cropping + bbox_lbs = [np.random.randint(lbs[i], ubs[i] + 1) for i in range(dim)] + + bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(dim)] + + return bbox_lbs, bbox_ubs diff --git a/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py b/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py new file mode 100644 index 0000000..b44004f --- /dev/null +++ b/nnUNet/nnunetv2/training/dataloading/data_loader_2d.py @@ -0,0 +1,93 @@ +import numpy as np +from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset + + +class nnUNetDataLoader2D(nnUNetDataLoaderBase): + def generate_train_batch(self): + selected_keys = self.get_indices() + # preallocate memory for data and seg + data_all = np.zeros(self.data_shape, dtype=np.float32) + seg_all = np.zeros(self.seg_shape, dtype=np.int16) + case_properties = [] + + for j, current_key in enumerate(selected_keys): + # oversampling foreground will improve stability of model training, especially if many patches are empty + # (Lung for example) + force_fg = self.get_do_oversample(j) + data, seg, properties = self._data.load_case(current_key) + + # select a class/region first, then a slice where this class is present, then crop to that area + if not force_fg: + if self.has_ignore: + selected_class_or_region = self.annotated_classes_key + else: + selected_class_or_region = None + else: + # filter out all classes that are not present here + eligible_classes_or_regions = [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) > 0] + + # if we have annotated_classes_key locations and other classes are present, remove the annotated_classes_key from the list + # strange formulation needed to circumvent + # ValueError: The truth value of an array with more than one element is ambiguous. 
Use a.any() or a.all() + tmp = [i == self.annotated_classes_key if isinstance(i, tuple) else False for i in eligible_classes_or_regions] + if any(tmp): + if len(eligible_classes_or_regions) > 1: + eligible_classes_or_regions.pop(np.where(tmp)[0][0]) + + selected_class_or_region = eligible_classes_or_regions[np.random.choice(len(eligible_classes_or_regions))] if \ + len(eligible_classes_or_regions) > 0 else None + if selected_class_or_region is not None: + selected_slice = np.random.choice(properties['class_locations'][selected_class_or_region][:, 1]) + else: + selected_slice = np.random.choice(len(data[0])) + + data = data[:, selected_slice] + seg = seg[:, selected_slice] + + # the line of death lol + # this needs to be a separate variable because we could otherwise permanently overwrite + # properties['class_locations'] + # selected_class_or_region is: + # - None if we do not have an ignore label and force_fg is False OR if force_fg is True but there is no foreground in the image + # - A tuple of all (non-ignore) labels if there is an ignore label and force_fg is False + # - a class or region if force_fg is True + class_locations = { + selected_class_or_region: properties['class_locations'][selected_class_or_region][properties['class_locations'][selected_class_or_region][:, 1] == selected_slice][:, (0, 2, 3)] + } if (selected_class_or_region is not None) else None + + # print(properties) + shape = data.shape[1:] + dim = len(shape) + bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg if selected_class_or_region is not None else None, + class_locations, overwrite_class=selected_class_or_region) + + # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the + # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. + # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size + # later + valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] + valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] + + # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. + # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to + # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also + # remove label -1 in the data augmentation but this way it is less error prone) + this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + data = data[this_slice] + + this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + seg = seg[this_slice] + + padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] + data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) + seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) + + return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} + + +if __name__ == '__main__': + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset004_Hippocampus/2d' + ds = nnUNetDataset(folder, None, 1000) # this should not load the properties! 
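+ # Illustrative sketch, not part of the original file: the crop-then-pad arithmetic used in
+ # generate_train_batch above, demonstrated on a toy array. A sampled bbox may extend beyond the
+ # image, so we first crop the part that lies inside and then pad the remainder, which always
+ # yields the requested patch size. All shapes and bbox values below are made up for demonstration.
+ toy = np.arange(36, dtype=np.float32).reshape((1, 6, 6))  # (c, x, y)
+ toy_patch_size = (8, 8)
+ toy_bbox_lbs, toy_bbox_ubs = [-2, 1], [6, 9]  # partially outside the image
+ toy_valid_lbs = [max(0, lb) for lb in toy_bbox_lbs]
+ toy_valid_ubs = [min(s, ub) for s, ub in zip(toy.shape[1:], toy_bbox_ubs)]
+ toy_cropped = toy[:, toy_valid_lbs[0]:toy_valid_ubs[0], toy_valid_lbs[1]:toy_valid_ubs[1]]
+ toy_padding = [(-min(0, lb), max(ub - s, 0)) for lb, ub, s in zip(toy_bbox_lbs, toy_bbox_ubs, toy.shape[1:])]
+ assert np.pad(toy_cropped, ((0, 0), *toy_padding), 'constant').shape[1:] == toy_patch_size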
+ dl = nnUNetDataLoader2D(ds, 366, (65, 65), (56, 40), 0.33, None, None) + a = next(dl) diff --git a/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py b/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py new file mode 100644 index 0000000..ab755e3 --- /dev/null +++ b/nnUNet/nnunetv2/training/dataloading/data_loader_3d.py @@ -0,0 +1,55 @@ +import numpy as np +from nnunetv2.training.dataloading.base_data_loader import nnUNetDataLoaderBase +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset + + +class nnUNetDataLoader3D(nnUNetDataLoaderBase): + def generate_train_batch(self): + selected_keys = self.get_indices() + # preallocate memory for data and seg + data_all = np.zeros(self.data_shape, dtype=np.float32) + seg_all = np.zeros(self.seg_shape, dtype=np.int16) + case_properties = [] + + for j, i in enumerate(selected_keys): + # oversampling foreground will improve stability of model training, especially if many patches are empty + # (Lung for example) + force_fg = self.get_do_oversample(j) + + data, seg, properties = self._data.load_case(i) + + # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by + # self._data.load_case(i) (see nnUNetDataset.load_case) + shape = data.shape[1:] + dim = len(shape) + bbox_lbs, bbox_ubs = self.get_bbox(shape, force_fg, properties['class_locations']) + + # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the + # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad. + # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size + # later + valid_bbox_lbs = [max(0, bbox_lbs[i]) for i in range(dim)] + valid_bbox_ubs = [min(shape[i], bbox_ubs[i]) for i in range(dim)] + + # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage. + # Why not just concatenate them here and forget about the if statements? Well that's because segneeds to + # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also + # remove label -1 in the data augmentation but this way it is less error prone) + this_slice = tuple([slice(0, data.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + data = data[this_slice] + + this_slice = tuple([slice(0, seg.shape[0])] + [slice(i, j) for i, j in zip(valid_bbox_lbs, valid_bbox_ubs)]) + seg = seg[this_slice] + + padding = [(-min(0, bbox_lbs[i]), max(bbox_ubs[i] - shape[i], 0)) for i in range(dim)] + data_all[j] = np.pad(data, ((0, 0), *padding), 'constant', constant_values=0) + seg_all[j] = np.pad(seg, ((0, 0), *padding), 'constant', constant_values=-1) + + return {'data': data_all, 'seg': seg_all, 'properties': case_properties, 'keys': selected_keys} + + +if __name__ == '__main__': + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/3d_fullres' + ds = nnUNetDataset(folder, 0) # this should not load the properties! 
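+ # Illustrative sketch, not part of the original file: the deterministic oversampling rule from
+ # nnUNetDataLoaderBase._oversample_last_XX_percent. With batch_size=5 and
+ # oversample_foreground_percent=0.33 (the values used for the loader below), only the last two
+ # samples of every minibatch are forced to contain foreground.
+ toy_forced_fg = [not i < round(5 * (1 - 0.33)) for i in range(5)]
+ assert toy_forced_fg == [False, False, False, True, True]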
+ dl = nnUNetDataLoader3D(ds, 5, (16, 16, 16), (16, 16, 16), 0.33, None, None) + a = next(dl) diff --git a/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py b/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py new file mode 100644 index 0000000..ae27fc3 --- /dev/null +++ b/nnUNet/nnunetv2/training/dataloading/nnunet_dataset.py @@ -0,0 +1,146 @@ +import os +from typing import List + +import numpy as np +import shutil + +from batchgenerators.utilities.file_and_folder_operations import join, load_pickle, isfile +from nnunetv2.training.dataloading.utils import get_case_identifiers + + +class nnUNetDataset(object): + def __init__(self, folder: str, case_identifiers: List[str] = None, + num_images_properties_loading_threshold: int = 0, + folder_with_segs_from_previous_stage: str = None): + """ + This does not actually load the dataset. It merely creates a dictionary where the keys are training case names and + the values are dictionaries containing the relevant information for that case. + dataset[training_case] -> info + Info has the following key:value pairs: + - dataset[case_identifier]['properties']['data_file'] -> the full path to the npz file associated with the training case + - dataset[case_identifier]['properties']['properties_file'] -> the pkl file containing the case properties + + In addition, if the total number of cases is < num_images_properties_loading_threshold we load all the pickle files + (containing auxiliary information). This is done for small datasets so that we don't spend too much CPU time on + reading pkl files on the fly during training. However, for large datasets storing all the aux info (which also + contains locations of foreground voxels in the images) can cause too much RAM utilization. In that + case is it better to load on the fly. + + If properties are loaded into the RAM, the info dicts each will have an additional entry: + - dataset[case_identifier]['properties'] -> pkl file content + + IMPORTANT! THIS CLASS ITSELF IS READ-ONLY. YOU CANNOT ADD KEY:VALUE PAIRS WITH nnUNetDataset[key] = value + USE THIS INSTEAD: + nnUNetDataset.dataset[key] = value + (not sure why you'd want to do that though. 
So don't do it) + """ + super().__init__() + # print('loading dataset') + if case_identifiers is None: + case_identifiers = get_case_identifiers(folder) + case_identifiers.sort() + + self.dataset = {} + for c in case_identifiers: + self.dataset[c] = {} + self.dataset[c]['data_file'] = join(folder, "%s.npz" % c) + self.dataset[c]['properties_file'] = join(folder, "%s.pkl" % c) + if folder_with_segs_from_previous_stage is not None: + self.dataset[c]['seg_from_prev_stage_file'] = join(folder_with_segs_from_previous_stage, "%s.npz" % c) + + if len(case_identifiers) <= num_images_properties_loading_threshold: + for i in self.dataset.keys(): + self.dataset[i]['properties'] = load_pickle(self.dataset[i]['properties_file']) + + self.keep_files_open = ('nnUNet_keep_files_open' in os.environ.keys()) and \ + (os.environ['nnUNet_keep_files_open'].lower() in ('true', '1', 't')) + # print(f'nnUNetDataset.keep_files_open: {self.keep_files_open}') + + def __getitem__(self, key): + ret = {**self.dataset[key]} + if 'properties' not in ret.keys(): + ret['properties'] = load_pickle(ret['properties_file']) + return ret + + def __setitem__(self, key, value): + return self.dataset.__setitem__(key, value) + + def keys(self): + return self.dataset.keys() + + def __len__(self): + return self.dataset.__len__() + + def items(self): + return self.dataset.items() + + def values(self): + return self.dataset.values() + + def load_case(self, key): + entry = self[key] + if 'open_data_file' in entry.keys(): + data = entry['open_data_file'] + # print('using open data file') + elif isfile(entry['data_file'][:-4] + ".npy"): + data = np.load(entry['data_file'][:-4] + ".npy", 'r') + if self.keep_files_open: + self.dataset[key]['open_data_file'] = data + # print('saving open data file') + else: + data = np.load(entry['data_file'])['data'] + + if 'open_seg_file' in entry.keys(): + seg = entry['open_seg_file'] + # print('using open data file') + elif isfile(entry['data_file'][:-4] + "_seg.npy"): + seg = np.load(entry['data_file'][:-4] + "_seg.npy", 'r') + if self.keep_files_open: + self.dataset[key]['open_seg_file'] = seg + # print('saving open seg file') + else: + seg = np.load(entry['data_file'])['seg'] + + if 'seg_from_prev_stage_file' in entry.keys(): + if isfile(entry['seg_from_prev_stage_file'][:-4] + ".npy"): + seg_prev = np.load(entry['seg_from_prev_stage_file'][:-4] + ".npy", 'r') + else: + seg_prev = np.load(entry['seg_from_prev_stage_file'])['seg'] + seg = np.vstack((seg, seg_prev[None])) + + return data, seg, entry['properties'] + + +if __name__ == '__main__': + # this is a mini test. Todo: We can move this to tests in the future (requires simulated dataset) + + folder = '/media/fabian/data/nnUNet_preprocessed/Dataset003_Liver/3d_lowres' + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) # this should not load the properties! + # this SHOULD HAVE the properties + ks = ds['liver_0'].keys() + assert 'properties' in ks + # amazing. I am the best. 
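+ # Added check, consistent with the lazy loading described in the class docstring: with
+ # num_images_properties_loading_threshold=0 the properties are not kept in RAM, they are only
+ # read from the pkl file when a case is accessed through __getitem__ (as done above).
+ assert 'properties' not in ds.dataset['liver_0'].keys()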
+ + # this should have the properties + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=1000) + # now rename the properties file so that it doesnt exist anymore + shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl')) + # now we should still be able to access the properties because they have already been loaded + ks = ds['liver_0'].keys() + assert 'properties' in ks + # move file back + shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl')) + + # this should not have the properties + ds = nnUNetDataset(folder, num_images_properties_loading_threshold=0) + # now rename the properties file so that it doesnt exist anymore + shutil.move(join(folder, 'liver_0.pkl'), join(folder, 'liver_XXX.pkl')) + # now this should crash + try: + ks = ds['liver_0'].keys() + raise RuntimeError('we should not have come here') + except FileNotFoundError: + print('all good') + # move file back + shutil.move(join(folder, 'liver_XXX.pkl'), join(folder, 'liver_0.pkl')) + diff --git a/nnUNet/nnunetv2/training/dataloading/utils.py b/nnUNet/nnunetv2/training/dataloading/utils.py new file mode 100644 index 0000000..bd145b4 --- /dev/null +++ b/nnUNet/nnunetv2/training/dataloading/utils.py @@ -0,0 +1,48 @@ +import multiprocessing +import os +from multiprocessing import Pool +from typing import List + +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import isfile, subfiles +from nnunetv2.configuration import default_num_processes + + +def _convert_to_npy(npz_file: str, unpack_segmentation: bool = True, overwrite_existing: bool = False) -> None: + try: + a = np.load(npz_file) # inexpensive, no compression is done here. This just reads metadata + if overwrite_existing or not isfile(npz_file[:-3] + "npy"): + np.save(npz_file[:-3] + "npy", a['data']) + if unpack_segmentation and (overwrite_existing or not isfile(npz_file[:-4] + "_seg.npy")): + np.save(npz_file[:-4] + "_seg.npy", a['seg']) + except KeyboardInterrupt: + if isfile(npz_file[:-3] + "npy"): + os.remove(npz_file[:-3] + "npy") + if isfile(npz_file[:-4] + "_seg.npy"): + os.remove(npz_file[:-4] + "_seg.npy") + raise KeyboardInterrupt + + +def unpack_dataset(folder: str, unpack_segmentation: bool = True, overwrite_existing: bool = False, + num_processes: int = default_num_processes): + """ + all npz files in this folder belong to the dataset, unpack them all + """ + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + npz_files = subfiles(folder, True, None, ".npz", True) + p.starmap(_convert_to_npy, zip(npz_files, + [unpack_segmentation] * len(npz_files), + [overwrite_existing] * len(npz_files)) + ) + + +def get_case_identifiers(folder: str) -> List[str]: + """ + finds all npz files in the given folder and reconstructs the training case names from them + """ + case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz") and (i.find("segFromPrevStage") == -1)] + return case_identifiers + + +if __name__ == '__main__': + unpack_dataset('/media/fabian/data/nnUNet_preprocessed/Dataset002_Heart/2d') \ No newline at end of file diff --git a/nnUNet/nnunetv2/training/logging/__init__.py b/nnUNet/nnunetv2/training/logging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/logging/nnunet_logger.py b/nnUNet/nnunetv2/training/logging/nnunet_logger.py new file mode 100644 index 0000000..8409738 --- /dev/null +++ b/nnUNet/nnunetv2/training/logging/nnunet_logger.py @@ -0,0 +1,103 @@ +import matplotlib +from 
batchgenerators.utilities.file_and_folder_operations import join + +matplotlib.use('agg') +import seaborn as sns +import matplotlib.pyplot as plt + + +class nnUNetLogger(object): + """ + This class is really trivial. Don't expect cool functionality here. This is my makeshift solution to problems + arising from out-of-sync epoch numbers and numbers of logged loss values. It also simplifies the trainer class a + little + + YOU MUST LOG EXACTLY ONE VALUE PER EPOCH FOR EACH OF THE LOGGING ITEMS! DONT FUCK IT UP + """ + def __init__(self, verbose: bool = False): + self.my_fantastic_logging = { + 'mean_fg_dice': list(), + 'ema_fg_dice': list(), + 'dice_per_class_or_region': list(), + 'train_losses': list(), + 'val_losses': list(), + 'lrs': list(), + 'epoch_start_timestamps': list(), + 'epoch_end_timestamps': list() + } + self.verbose = verbose + # shut up, this logging is great + + def log(self, key, value, epoch: int): + """ + sometimes shit gets messed up. We try to catch that here + """ + assert key in self.my_fantastic_logging.keys() and isinstance(self.my_fantastic_logging[key], list), \ + 'This function is only intended to log stuff to lists and to have one entry per epoch' + + if self.verbose: print(f'logging {key}: {value} for epoch {epoch}') + + if len(self.my_fantastic_logging[key]) < (epoch + 1): + self.my_fantastic_logging[key].append(value) + else: + assert len(self.my_fantastic_logging[key]) == (epoch + 1), 'something went horribly wrong. My logging ' \ + 'lists length is off by more than 1' + print(f'maybe some logging issue!? logging {key} and {value}') + self.my_fantastic_logging[key][epoch] = value + + # handle the ema_fg_dice special case! It is automatically logged when we add a new mean_fg_dice + if key == 'mean_fg_dice': + new_ema_pseudo_dice = self.my_fantastic_logging['ema_fg_dice'][epoch - 1] * 0.9 + 0.1 * value \ + if len(self.my_fantastic_logging['ema_fg_dice']) > 0 else value + self.log('ema_fg_dice', new_ema_pseudo_dice, epoch) + + def plot_progress_png(self, output_folder): + # we infer the epoch form our internal logging + epoch = min([len(i) for i in self.my_fantastic_logging.values()]) - 1 # lists of epoch 0 have len 1 + sns.set(font_scale=2.5) + fig, ax_all = plt.subplots(3, 1, figsize=(30, 54)) + # regular progress.png as we are used to from previous nnU-Net versions + ax = ax_all[0] + ax2 = ax.twinx() + x_values = list(range(epoch + 1)) + ax.plot(x_values, self.my_fantastic_logging['train_losses'][:epoch + 1], color='b', ls='-', label="loss_tr", linewidth=4) + ax.plot(x_values, self.my_fantastic_logging['val_losses'][:epoch + 1], color='r', ls='-', label="loss_val", linewidth=4) + ax2.plot(x_values, self.my_fantastic_logging['mean_fg_dice'][:epoch + 1], color='g', ls='dotted', label="pseudo dice", + linewidth=3) + ax2.plot(x_values, self.my_fantastic_logging['ema_fg_dice'][:epoch + 1], color='g', ls='-', label="pseudo dice (mov. 
avg.)", + linewidth=4) + ax.set_xlabel("epoch") + ax.set_ylabel("loss") + ax2.set_ylabel("pseudo dice") + ax.legend(loc=(0, 1)) + ax2.legend(loc=(0.2, 1)) + + # epoch times to see whether the training speed is consistent (inconsistent means there are other jobs + # clogging up the system) + ax = ax_all[1] + ax.plot(x_values, [i - j for i, j in zip(self.my_fantastic_logging['epoch_end_timestamps'][:epoch + 1], + self.my_fantastic_logging['epoch_start_timestamps'])][:epoch + 1], color='b', + ls='-', label="epoch duration", linewidth=4) + ylim = [0] + [ax.get_ylim()[1]] + ax.set(ylim=ylim) + ax.set_xlabel("epoch") + ax.set_ylabel("time [s]") + ax.legend(loc=(0, 1)) + + # learning rate + ax = ax_all[2] + ax.plot(x_values, self.my_fantastic_logging['lrs'][:epoch + 1], color='b', ls='-', label="learning rate", linewidth=4) + ax.set_xlabel("epoch") + ax.set_ylabel("learning rate") + ax.legend(loc=(0, 1)) + + plt.tight_layout() + + fig.savefig(join(output_folder, "progress.png")) + plt.close() + + def get_checkpoint(self): + return self.my_fantastic_logging + + def load_checkpoint(self, checkpoint: dict): + self.my_fantastic_logging = checkpoint diff --git a/nnUNet/nnunetv2/training/loss/__init__.py b/nnUNet/nnunetv2/training/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/loss/compound_losses.py b/nnUNet/nnunetv2/training/loss/compound_losses.py new file mode 100644 index 0000000..9db0a42 --- /dev/null +++ b/nnUNet/nnunetv2/training/loss/compound_losses.py @@ -0,0 +1,151 @@ +import torch +from nnunetv2.training.loss.dice import SoftDiceLoss, MemoryEfficientSoftDiceLoss +from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss, TopKLoss +from nnunetv2.utilities.helpers import softmax_helper_dim1 +from torch import nn + + +class DC_and_CE_loss(nn.Module): + def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None, + dice_class=SoftDiceLoss): + """ + Weights for CE and Dice do not need to sum to one. You can set whatever you want. + :param soft_dice_kwargs: + :param ce_kwargs: + :param aggregate: + :param square_dice: + :param weight_ce: + :param weight_dice: + """ + super(DC_and_CE_loss, self).__init__() + if ignore_label is not None: + ce_kwargs['ignore_index'] = ignore_label + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.ignore_label = ignore_label + + self.ce = RobustCrossEntropyLoss(**ce_kwargs) + self.dc = dice_class(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + """ + target must be b, c, x, y(, z) with c=1 + :param net_output: + :param target: + :return: + """ + if self.ignore_label is not None: + assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ + '(DC_and_CE_loss)' + mask = (target != self.ignore_label).bool() + # remove ignore label from target, replace with one of the known labels. 
It doesn't matter because we + # ignore gradients in those areas anyway + target_dice = torch.clone(target) + target_dice[target == self.ignore_label] = 0 + num_fg = mask.sum() + else: + target_dice = target + mask = None + + dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ + if self.weight_dice != 0 else 0 + ce_loss = self.ce(net_output, target[:, 0].long()) \ + if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 + + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result + + +class DC_and_BCE_loss(nn.Module): + def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use_ignore_label: bool = False, + dice_class=MemoryEfficientSoftDiceLoss): + """ + DO NOT APPLY NONLINEARITY IN YOUR NETWORK! + + target mut be one hot encoded + IMPORTANT: We assume use_ignore_label is located in target[:, -1]!!! + + :param soft_dice_kwargs: + :param bce_kwargs: + :param aggregate: + """ + super(DC_and_BCE_loss, self).__init__() + if use_ignore_label: + bce_kwargs['reduction'] = 'none' + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.use_ignore_label = use_ignore_label + + self.ce = nn.BCEWithLogitsLoss(**bce_kwargs) + self.dc = dice_class(apply_nonlin=torch.sigmoid, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + if self.use_ignore_label: + # target is one hot encoded here. invert it so that it is True wherever we can compute the loss + mask = (1 - target[:, -1:]).bool() + # remove ignore channel now that we have the mask + target_regions = torch.clone(target[:, :-1]) + else: + target_regions = target + mask = None + + dc_loss = self.dc(net_output, target_regions, loss_mask=mask) + if mask is not None: + ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8) + else: + ce_loss = self.ce(net_output, target_regions) + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result + + +class DC_and_topk_loss(nn.Module): + def __init__(self, soft_dice_kwargs, ce_kwargs, weight_ce=1, weight_dice=1, ignore_label=None): + """ + Weights for CE and Dice do not need to sum to one. You can set whatever you want. + :param soft_dice_kwargs: + :param ce_kwargs: + :param aggregate: + :param square_dice: + :param weight_ce: + :param weight_dice: + """ + super().__init__() + if ignore_label is not None: + ce_kwargs['ignore_index'] = ignore_label + + self.weight_dice = weight_dice + self.weight_ce = weight_ce + self.ignore_label = ignore_label + + self.ce = TopKLoss(**ce_kwargs) + self.dc = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, **soft_dice_kwargs) + + def forward(self, net_output: torch.Tensor, target: torch.Tensor): + """ + target must be b, c, x, y(, z) with c=1 + :param net_output: + :param target: + :return: + """ + if self.ignore_label is not None: + assert target.shape[1] == 1, 'ignore label is not implemented for one hot encoded target variables ' \ + '(DC_and_CE_loss)' + mask = (target != self.ignore_label).bool() + # remove ignore label from target, replace with one of the known labels. 
It doesn't matter because we + # ignore gradients in those areas anyway + target_dice = torch.clone(target) + target_dice[target == self.ignore_label] = 0 + num_fg = mask.sum() + else: + target_dice = target + mask = None + + dc_loss = self.dc(net_output, target_dice, loss_mask=mask) \ + if self.weight_dice != 0 else 0 + ce_loss = self.ce(net_output, target) \ + if self.weight_ce != 0 and (self.ignore_label is None or num_fg > 0) else 0 + + result = self.weight_ce * ce_loss + self.weight_dice * dc_loss + return result diff --git a/nnUNet/nnunetv2/training/loss/deep_supervision.py b/nnUNet/nnunetv2/training/loss/deep_supervision.py new file mode 100644 index 0000000..db71e80 --- /dev/null +++ b/nnUNet/nnunetv2/training/loss/deep_supervision.py @@ -0,0 +1,35 @@ +from torch import nn + + +class DeepSupervisionWrapper(nn.Module): + def __init__(self, loss, weight_factors=None): + """ + Wraps a loss function so that it can be applied to multiple outputs. Forward accepts an arbitrary number of + inputs. Each input is expected to be a tuple/list. Each tuple/list must have the same length. The loss is then + applied to each entry like this: + l = w0 * loss(input0[0], input1[0], ...) + w1 * loss(input0[1], input1[1], ...) + ... + If weights are None, all w will be 1. + """ + super(DeepSupervisionWrapper, self).__init__() + self.weight_factors = weight_factors + self.loss = loss + + def forward(self, *args): + for i in args: + assert isinstance(i, (tuple, list)), "all args must be either tuple or list, got %s" % type(i) + # we could check for equal lengths here as well but we really shouldn't overdo it with checks because + # this code is executed a lot of times! + + if self.weight_factors is None: + weights = [1] * len(args[0]) + else: + weights = self.weight_factors + + # we initialize the loss like this instead of 0 to ensure it sits on the correct device, not sure if that's + # really necessary + l = weights[0] * self.loss(*[j[0] for j in args]) + for i, inputs in enumerate(zip(*args)): + if i == 0: + continue + l += weights[i] * self.loss(*inputs) + return l \ No newline at end of file diff --git a/nnUNet/nnunetv2/training/loss/dice.py b/nnUNet/nnunetv2/training/loss/dice.py new file mode 100644 index 0000000..84c1e16 --- /dev/null +++ b/nnUNet/nnunetv2/training/loss/dice.py @@ -0,0 +1,194 @@ +from typing import Callable + +import torch +from nnunetv2.utilities.ddp_allgather import AllGatherGrad +from nnunetv2.utilities.tensor_utilities import sum_tensor +from torch import nn + + +class SoftDiceLoss(nn.Module): + def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., + ddp: bool = True, clip_tp: float = None): + """ + """ + super(SoftDiceLoss, self).__init__() + + self.do_bg = do_bg + self.batch_dice = batch_dice + self.apply_nonlin = apply_nonlin + self.smooth = smooth + self.clip_tp = clip_tp + self.ddp = ddp + + def forward(self, x, y, loss_mask=None): + shp_x = x.shape + + if self.batch_dice: + axes = [0] + list(range(2, len(shp_x))) + else: + axes = list(range(2, len(shp_x))) + + if self.apply_nonlin is not None: + x = self.apply_nonlin(x) + + tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False) + + if self.ddp and self.batch_dice: + tp = AllGatherGrad.apply(tp).sum(0) + fp = AllGatherGrad.apply(fp).sum(0) + fn = AllGatherGrad.apply(fn).sum(0) + + if self.clip_tp is not None: + tp = torch.clip(tp, min=self.clip_tp , max=None) + + nominator = 2 * tp + denominator = 2 * tp + fp + fn + + dc = (nominator + self.smooth) / 
(torch.clip(denominator + self.smooth, 1e-8)) + + if not self.do_bg: + if self.batch_dice: + dc = dc[1:] + else: + dc = dc[:, 1:] + dc = dc.mean() + + return -dc + + +class MemoryEfficientSoftDiceLoss(nn.Module): + def __init__(self, apply_nonlin: Callable = None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1., + ddp: bool = True): + """ + saves 1.6 GB on Dataset017 3d_lowres + """ + super(MemoryEfficientSoftDiceLoss, self).__init__() + + self.do_bg = do_bg + self.batch_dice = batch_dice + self.apply_nonlin = apply_nonlin + self.smooth = smooth + self.ddp = ddp + + def forward(self, x, y, loss_mask=None): + shp_x, shp_y = x.shape, y.shape + + if self.apply_nonlin is not None: + x = self.apply_nonlin(x) + + if not self.do_bg: + x = x[:, 1:] + + # make everything shape (b, c) + axes = list(range(2, len(shp_x))) + + with torch.no_grad(): + if len(shp_x) != len(shp_y): + y = y.view((shp_y[0], 1, *shp_y[1:])) + + if all([i == j for i, j in zip(shp_x, shp_y)]): + # if this is the case then gt is probably already a one hot encoding + y_onehot = y + else: + gt = y.long() + y_onehot = torch.zeros(shp_x, device=x.device, dtype=torch.bool) + y_onehot.scatter_(1, gt, 1) + + if not self.do_bg: + y_onehot = y_onehot[:, 1:] + sum_gt = y_onehot.sum(axes) if loss_mask is None else (y_onehot * loss_mask).sum(axes) + + intersect = (x * y_onehot).sum(axes) if loss_mask is None else (x * y_onehot * loss_mask).sum(axes) + sum_pred = x.sum(axes) if loss_mask is None else (x * loss_mask).sum(axes) + + if self.ddp and self.batch_dice: + intersect = AllGatherGrad.apply(intersect).sum(0) + sum_pred = AllGatherGrad.apply(sum_pred).sum(0) + sum_gt = AllGatherGrad.apply(sum_gt).sum(0) + + if self.batch_dice: + intersect = intersect.sum(0) + sum_pred = sum_pred.sum(0) + sum_gt = sum_gt.sum(0) + + dc = (2 * intersect + self.smooth) / (torch.clip(sum_gt + sum_pred + self.smooth, 1e-8)) + + dc = dc.mean() + return -dc + + +def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False): + """ + net_output must be (b, c, x, y(, z))) + gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) + if mask is provided it must have shape (b, 1, x, y(, z))) + :param net_output: + :param gt: + :param axes: can be (, ) = no summation + :param mask: mask must be 1 for valid pixels and 0 for invalid pixels + :param square: if True then fp, tp and fn will be squared before summation + :return: + """ + if axes is None: + axes = tuple(range(2, len(net_output.size()))) + + shp_x = net_output.shape + shp_y = gt.shape + + with torch.no_grad(): + if len(shp_x) != len(shp_y): + gt = gt.view((shp_y[0], 1, *shp_y[1:])) + + if all([i == j for i, j in zip(net_output.shape, gt.shape)]): + # if this is the case then gt is probably already a one hot encoding + y_onehot = gt + else: + gt = gt.long() + y_onehot = torch.zeros(shp_x, device=net_output.device) + y_onehot.scatter_(1, gt, 1) + + tp = net_output * y_onehot + fp = net_output * (1 - y_onehot) + fn = (1 - net_output) * y_onehot + tn = (1 - net_output) * (1 - y_onehot) + + if mask is not None: + with torch.no_grad(): + mask_here = torch.tile(mask, (1, tp.shape[1], *[1 for i in range(2, len(tp.shape))])) + tp *= mask_here + fp *= mask_here + fn *= mask_here + tn *= mask_here + # benchmark whether tiling the mask would be faster (torch.tile). 
It probably is for large batch sizes + # OK it barely makes a difference but the implementation above is a tiny bit faster + uses less vram + # (using nnUNetv2_train 998 3d_fullres 0) + # tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) + # fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) + # fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) + # tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1) + + if square: + tp = tp ** 2 + fp = fp ** 2 + fn = fn ** 2 + tn = tn ** 2 + + if len(axes) > 0: + tp = sum_tensor(tp, axes, keepdim=False) + fp = sum_tensor(fp, axes, keepdim=False) + fn = sum_tensor(fn, axes, keepdim=False) + tn = sum_tensor(tn, axes, keepdim=False) + + return tp, fp, fn, tn + + +if __name__ == '__main__': + from nnunetv2.utilities.helpers import softmax_helper_dim1 + pred = torch.rand((2, 3, 32, 32, 32)) + ref = torch.randint(0, 3, (2, 32, 32, 32)) + + dl_old = SoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) + dl_new = MemoryEfficientSoftDiceLoss(apply_nonlin=softmax_helper_dim1, batch_dice=True, do_bg=False, smooth=0, ddp=False) + res_old = dl_old(pred, ref) + res_new = dl_new(pred, ref) + print(res_old, res_new) diff --git a/nnUNet/nnunetv2/training/loss/robust_ce_loss.py b/nnUNet/nnunetv2/training/loss/robust_ce_loss.py new file mode 100644 index 0000000..ad46659 --- /dev/null +++ b/nnUNet/nnunetv2/training/loss/robust_ce_loss.py @@ -0,0 +1,33 @@ +import torch +from torch import nn, Tensor +import numpy as np + + +class RobustCrossEntropyLoss(nn.CrossEntropyLoss): + """ + this is just a compatibility layer because my target tensor is float and has an extra dimension + + input must be logits, not probabilities! + """ + def forward(self, input: Tensor, target: Tensor) -> Tensor: + if len(target.shape) == len(input.shape): + assert target.shape[1] == 1 + target = target[:, 0] + return super().forward(input, target.long()) + + +class TopKLoss(RobustCrossEntropyLoss): + """ + input must be logits, not probabilities! 
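+ Illustrative usage (a sketch, not from the original file; `logits` and `target` are placeholder tensors):
+ > loss = TopKLoss(k=10)      # average only the hardest 10% of per-voxel CE values
+ > l = loss(logits, target)   # target must have shape (b, 1, x, y(, z))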
+ """ + def __init__(self, weight=None, ignore_index: int = -100, k: float = 10, label_smoothing: float = 0): + self.k = k + super(TopKLoss, self).__init__(weight, False, ignore_index, reduce=False, label_smoothing=label_smoothing) + + def forward(self, inp, target): + target = target[:, 0].long() + res = super(TopKLoss, self).forward(inp, target) + num_voxels = np.prod(res.shape, dtype=np.int64) + res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False) + return res.mean() + diff --git a/nnUNet/nnunetv2/training/lr_scheduler/__init__.py b/nnUNet/nnunetv2/training/lr_scheduler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/lr_scheduler/polylr.py b/nnUNet/nnunetv2/training/lr_scheduler/polylr.py new file mode 100644 index 0000000..44857b5 --- /dev/null +++ b/nnUNet/nnunetv2/training/lr_scheduler/polylr.py @@ -0,0 +1,20 @@ +from torch.optim.lr_scheduler import _LRScheduler + + +class PolyLRScheduler(_LRScheduler): + def __init__(self, optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9, current_step: int = None): + self.optimizer = optimizer + self.initial_lr = initial_lr + self.max_steps = max_steps + self.exponent = exponent + self.ctr = 0 + super().__init__(optimizer, current_step if current_step is not None else -1, False) + + def step(self, current_step=None): + if current_step is None or current_step == -1: + current_step = self.ctr + self.ctr += 1 + + new_lr = self.initial_lr * (1 - current_step / self.max_steps) ** self.exponent + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py b/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py new file mode 100644 index 0000000..85587ac --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py @@ -0,0 +1,254 @@ +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import torch +from torch import nn + +class STUNetTrainer(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 1000 + self.initial_lr = 1e-2 + + @staticmethod + def build_network_architecture(plans_manager, + dataset_json, + configuration_manager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + num_classes=label_manager.num_segmentation_heads + kernel_sizes = [[3,3,3]] * 6 + strides=configuration_manager.pool_op_kernel_sizes[1:] + if len(strides)>5: + strides = strides[:5] + while len(strides)<5: + strides.append([1,1,1]) + return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [32 * x for x in [1, 2, 4, 8, 16, 16]], + pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision) + +class STUNetTrainer_small(STUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager, + dataset_json, + configuration_manager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + num_classes=label_manager.num_segmentation_heads + kernel_sizes = [[3,3,3]] * 6 + strides=configuration_manager.pool_op_kernel_sizes[1:] + if len(strides)>5: + strides = strides[:5] + while len(strides)<5: + 
strides.append([1,1,1]) + return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [16 * x for x in [1, 2, 4, 8, 16, 16]], + pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision) + +class STUNetTrainer_small_ft(STUNetTrainer_small): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + self.num_epochs = 1000 + + +class STUNetTrainer_base(STUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager, + dataset_json, + configuration_manager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + num_classes=label_manager.num_segmentation_heads + kernel_sizes = [[3,3,3]] * 6 + strides=configuration_manager.pool_op_kernel_sizes[1:] + if len(strides)>5: + strides = strides[:5] + while len(strides)<5: + strides.append([1,1,1]) + return STUNet(num_input_channels, num_classes, depth=[1]*6, dims= [32 * x for x in [1, 2, 4, 8, 16, 16]], + pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision) + +class STUNetTrainer_base_ft(STUNetTrainer_base): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + self.num_epochs = 1000 + + +class STUNetTrainer_large(STUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager, + dataset_json, + configuration_manager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + num_classes=label_manager.num_segmentation_heads + kernel_sizes = [[3,3,3]] * 6 + strides=configuration_manager.pool_op_kernel_sizes[1:] + if len(strides)>5: + strides = strides[:5] + while len(strides)<5: + strides.append([1,1,1]) + return STUNet(num_input_channels, num_classes, depth=[2]*6, dims= [64 * x for x in [1, 2, 4, 8, 16, 16]], + pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision) + +class STUNetTrainer_large_ft(STUNetTrainer_large): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + self.num_epochs = 1000 + + +class STUNetTrainer_huge(STUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager, + dataset_json, + configuration_manager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + label_manager = plans_manager.get_label_manager(dataset_json) + num_classes=label_manager.num_segmentation_heads + kernel_sizes = [[3,3,3]] * 6 + strides=configuration_manager.pool_op_kernel_sizes[1:] + if len(strides)>5: + strides = strides[:5] + while len(strides)<5: + strides.append([1,1,1]) + return STUNet(num_input_channels, num_classes, depth=[3]*6, dims= [96 * x for x in [1, 2, 4, 8, 16, 16]], + pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes, enable_deep_supervision=enable_deep_supervision) + +class 
STUNetTrainer_huge_ft(STUNetTrainer_huge): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + self.num_epochs = 1000 + + + +class Decoder(nn.Module): + def __init__(self): + super().__init__() + self.deep_supervision = True + +class STUNet(nn.Module): + def __init__(self, input_channels, num_classes, depth=[1,1,1,1,1,1], dims=[32, 64, 128, 256, 512, 512], + pool_op_kernel_sizes=None, conv_kernel_sizes=None, enable_deep_supervision=True): + super().__init__() + self.conv_op = nn.Conv3d + self.input_channels = input_channels + self.num_classes = num_classes + + self.final_nonlin = lambda x:x + self.decoder = Decoder() + self.decoder.deep_supervision = enable_deep_supervision + self.upscale_logits = False + + self.pool_op_kernel_sizes = pool_op_kernel_sizes + self.conv_kernel_sizes = conv_kernel_sizes + self.conv_pad_sizes = [] + for krnl in self.conv_kernel_sizes: + self.conv_pad_sizes.append([i // 2 for i in krnl]) + + + num_pool = len(pool_op_kernel_sizes) + + assert num_pool == len(dims) - 1 + + # encoder + self.conv_blocks_context = nn.ModuleList() + stage = nn.Sequential(BasicResBlock(input_channels, dims[0], self.conv_kernel_sizes[0], self.conv_pad_sizes[0], use_1x1conv=True), + *[BasicResBlock(dims[0], dims[0], self.conv_kernel_sizes[0], self.conv_pad_sizes[0]) for _ in range(depth[0]-1)]) + self.conv_blocks_context.append(stage) + for d in range(1, num_pool+1): + stage = nn.Sequential(BasicResBlock(dims[d-1], dims[d], self.conv_kernel_sizes[d], self.conv_pad_sizes[d], stride=self.pool_op_kernel_sizes[d-1], use_1x1conv=True), + *[BasicResBlock(dims[d], dims[d], self.conv_kernel_sizes[d], self.conv_pad_sizes[d]) for _ in range(depth[d]-1)]) + self.conv_blocks_context.append(stage) + + # upsample_layers + self.upsample_layers = nn.ModuleList() + for u in range(num_pool): + upsample_layer = Upsample_Layer_nearest(dims[-1-u], dims[-2-u], pool_op_kernel_sizes[-1-u]) + self.upsample_layers.append(upsample_layer) + + # decoder + self.conv_blocks_localization = nn.ModuleList() + for u in range(num_pool): + stage = nn.Sequential(BasicResBlock(dims[-2-u] * 2, dims[-2-u], self.conv_kernel_sizes[-2-u], self.conv_pad_sizes[-2-u], use_1x1conv=True), + *[BasicResBlock(dims[-2-u], dims[-2-u], self.conv_kernel_sizes[-2-u], self.conv_pad_sizes[-2-u]) for _ in range(depth[-2-u]-1)]) + self.conv_blocks_localization.append(stage) + + # outputs + self.seg_outputs = nn.ModuleList() + for ds in range(len(self.conv_blocks_localization)): + self.seg_outputs.append(nn.Conv3d(dims[-2-ds], num_classes, kernel_size=1)) + + self.upscale_logits_ops = [] + for usl in range(num_pool - 1): + self.upscale_logits_ops.append(lambda x: x) + + + def forward(self, x): + skips = [] + seg_outputs = [] + + for d in range(len(self.conv_blocks_context) - 1): + x = self.conv_blocks_context[d](x) + skips.append(x) + + x = self.conv_blocks_context[-1](x) + + for u in range(len(self.conv_blocks_localization)): + x = self.upsample_layers[u](x) + x = torch.cat((x, skips[-(u + 1)]), dim=1) + x = self.conv_blocks_localization[u](x) + seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x))) + + if self.decoder.deep_supervision: + return tuple([seg_outputs[-1]] + [i(j) for i, j in + zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])]) + else: + return seg_outputs[-1] + + +class 
BasicResBlock(nn.Module): + def __init__(self, input_channels, output_channels, kernel_size=3, padding=1, stride=1, use_1x1conv=False): + super().__init__() + self.conv1 = nn.Conv3d(input_channels, output_channels, kernel_size, stride=stride, padding=padding) + self.norm1 = nn.InstanceNorm3d(output_channels, affine=True) + self.act1 = nn.LeakyReLU(inplace=True) + + self.conv2 = nn.Conv3d(output_channels, output_channels, kernel_size, padding=padding) + self.norm2 = nn.InstanceNorm3d(output_channels, affine=True) + self.act2 = nn.LeakyReLU(inplace=True) + + if use_1x1conv: + self.conv3 = nn.Conv3d(input_channels, output_channels, kernel_size=1, stride=stride) + else: + self.conv3 = None + + def forward(self, x): + y = self.conv1(x) + y = self.act1(self.norm1(y)) + y = self.norm2(self.conv2(y)) + if self.conv3: + x = self.conv3(x) + y += x + return self.act2(y) + +class Upsample_Layer_nearest(nn.Module): + def __init__(self, input_channels, output_channels, pool_op_kernel_size, mode='nearest'): + super().__init__() + self.conv = nn.Conv3d(input_channels, output_channels, kernel_size=1) + self.pool_op_kernel_size = pool_op_kernel_size + self.mode = mode + + def forward(self, x): + x = nn.functional.interpolate(x, scale_factor=self.pool_op_kernel_size, mode=self.mode) + x = self.conv(x) + return x \ No newline at end of file diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py new file mode 100644 index 0000000..6dc1277 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -0,0 +1,1241 @@ +import inspect +import multiprocessing +import os +import shutil +import sys +import warnings +from copy import deepcopy +from datetime import datetime +from time import time, sleep +from typing import Union, Tuple, List + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ + ContrastAugmentationTransform, GammaTransform +from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform +from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p +from torch._dynamo import OptimizedModule + +from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.inference.sliding_window_prediction import compute_gaussian +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import 
MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \ + Convert3DTo2DTransform +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger +from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from sklearn.model_selection import KFold +from torch import autocast, nn +from torch import distributed as dist +from torch.cuda import device_count +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP + + +class nnUNetTrainer(object): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + # From https://grugbrain.dev/. Worth a read ya big brains ;-) + + # apex predator of grug is complexity + # complexity bad + # say again: + # complexity very bad + # you say now: + # complexity very, very bad + # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex + # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime + # one day code base understandable and grug can get work done, everything good! + # next day impossible: complexity demon spirit has entered code and very dangerous situation! + + # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz + + self.is_ddp = dist.is_available() and dist.is_initialized() + self.local_rank = 0 if not self.is_ddp else dist.get_rank() + + self.device = device + + # print what device we are using + if self.is_ddp: # implicitly it's clear that we use cuda in this case + print(f"I am local rank {self.local_rank}. 
{device_count()} GPUs are available. The world size is " + f"{dist.get_world_size()}." + f"Setting device to {self.device}") + self.device = torch.device(type='cuda', index=self.local_rank) + else: + if self.device.type == 'cuda': + # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X + self.device = torch.device(type='cuda', index=0) + print(f"Using device: {self.device}") + + # loading and saving this class for continuing from checkpoint should not happen based on pickling. This + # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we + # need. So let's save the init args + self.my_init_kwargs = {} + for k in inspect.signature(self.__init__).parameters.keys(): + self.my_init_kwargs[k] = locals()[k] + + ### Saving all the init args into class variables for later access + self.plans_manager = PlansManager(plans) + self.configuration_manager = self.plans_manager.get_configuration(configuration) + self.configuration_name = configuration + self.dataset_json = dataset_json + self.fold = fold + self.unpack_dataset = unpack_dataset + + ### Setting all the folder names. We need to make sure things don't crash in case we are just running + # inference and some of the folders may not be defined! + self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ + if nnUNet_preprocessed is not None else None + self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ + if nnUNet_results is not None else None + self.output_folder = join(self.output_folder_base, f'fold_{fold}') + + self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, + self.configuration_manager.data_identifier) + # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to + # be a different configuration in the same plans + # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using + # "previous_stage" and "next_stage"). Otherwise it won't work! + self.is_cascaded = self.configuration_manager.previous_stage_name is not None + self.folder_with_segs_from_previous_stage = \ + join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ + if self.is_cascaded else None + + ### Some hyperparameters for you to fiddle with + self.initial_lr = 1e-2 + self.weight_decay = 3e-5 + self.oversample_foreground_percent = 0.33 + self.num_iterations_per_epoch = 250 + self.num_val_iterations_per_epoch = 50 + self.num_epochs = 1000 + self.current_epoch = 0 + + ### Dealing with labels/regions + self.label_manager = self.plans_manager.get_label_manager(dataset_json) + # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) + # needed for predictions. We do sigmoid in case of (overlapping) regions + + self.num_input_channels = None # -> self.initialize() + self.network = None # -> self._get_network() + self.optimizer = self.lr_scheduler = None # -> self.initialize + self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None + self.loss = None # -> self.initialize + + ### Simple logging. Don't take that away from me! + # initialize log file. 
This is just our log for the print statements etc. Not to be confused with lightning + # logging + timestamp = datetime.now() + maybe_mkdir_p(self.output_folder) + self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % + (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, + timestamp.second)) + self.logger = nnUNetLogger() + + ### placeholders + self.dataloader_train = self.dataloader_val = None # see on_train_start + + ### initializing stuff for remembering things and such + self._best_ema = None + + ### inference things + self.inference_allowed_mirroring_axes = None # this variable is set in + # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints + + ### checkpoint saving stuff + self.save_every = 50 + self.disable_checkpointing = False + + ## DDP batch size and oversampling can differ between workers and needs adaptation + # we need to change the batch size in DDP because we don't use any of those distributed samplers + self._set_batch_size_and_oversample() + + self.was_initialized = False + + self.print_to_log_file("\n#######################################################################\n" + "Please cite the following paper when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. " + "Nature methods, 18(2), 203-211.\n" + "#######################################################################\n", + also_print_to_console=True, add_timestamp=False) + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, + self.configuration_manager, + self.num_input_channels, + enable_deep_supervision=True).to(self.device) + # compile network for free speedup + if ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')): + self.print_to_log_file('Compiling network...') + self.network = torch.compile(self.network) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. 
" + "That should not happen.") + + def _save_debug_information(self): + # saving some debug information + if self.local_rank == 0: + dct = {} + for k in self.__dir__(): + if not k.startswith("__"): + if not callable(getattr(self, k)) or k in ['loss', ]: + dct[k] = str(getattr(self, k)) + elif k in ['network', ]: + dct[k] = str(getattr(self, k).__class__.__name__) + else: + # print(k) + pass + if k in ['dataloader_train', 'dataloader_val']: + if hasattr(getattr(self, k), 'generator'): + dct[k + '.generator'] = str(getattr(self, k).generator) + if hasattr(getattr(self, k), 'num_processes'): + dct[k + '.num_processes'] = str(getattr(self, k).num_processes) + if hasattr(getattr(self, k), 'transform'): + dct[k + '.transform'] = str(getattr(self, k).transform) + import subprocess + hostname = subprocess.getoutput(['hostname']) + dct['hostname'] = hostname + torch_version = torch.__version__ + if self.device.type == 'cuda': + gpu_name = torch.cuda.get_device_name() + dct['gpu_name'] = gpu_name + cudnn_version = torch.backends.cudnn.version() + else: + cudnn_version = 'None' + dct['device'] = str(self.device) + dct['torch_version'] = torch_version + dct['cudnn_version'] = cudnn_version + save_json(dct, join(self.output_folder, "debug.json")) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + """ + his is where you build the architecture according to the plans. There is no obligation to use + get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what + you want. Even ignore the plans and just return something static (as long as it can process the requested + patch size) + but don't bug us with your bugs arising from fiddling with this :-P + This is the function that is called in inference as well! This is needed so that all network architecture + variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for + training, so if you change the network architecture during training by deriving a new trainer class then + inference will know about it). + + If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: + > label_manager = plans_manager.get_label_manager(dataset_json) + > label_manager.num_segmentation_heads + (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, + the number of outputs is != the number of classes. Also there is the ignore label for which no output + should be generated. label_manager takes care of all that for you.) 
+ + """ + return get_network_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + def _get_deep_supervision_scales(self): + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + return deep_supervision_scales + + def _set_batch_size_and_oversample(self): + if not self.is_ddp: + # set batch size to what the plan says, leave oversample untouched + self.batch_size = self.configuration_manager.batch_size + else: + # batch size is distributed over DDP workers and we need to change oversample_percent for each worker + batch_sizes = [] + oversample_percents = [] + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + global_batch_size = self.configuration_manager.batch_size + assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ + 'GPUs... Duh.' + + batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) + + for rank in range(world_size): + if (rank + 1) * batch_size_per_GPU > global_batch_size: + batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) + else: + batch_size = batch_size_per_GPU + + batch_sizes.append(batch_size) + + sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) + sample_id_high = np.sum(batch_sizes) + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percents.append(0.0) + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percents.append(1.0) + else: + percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size + oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - + sample_id_low / global_batch_size) / percent_covered_by_this_rank) + oversample_percents.append(oversample_percent_here) + + print("worker", my_rank, "oversample", oversample_percents[my_rank]) + print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) + # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) + + self.batch_size = batch_sizes[my_rank] + self.oversample_foreground_percent = oversample_percents[my_rank] + + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.85, 1.25)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): + if self.local_rank == 0: + timestamp = time() + dt_object = datetime.fromtimestamp(timestamp) + + if add_timestamp: + args = ("%s:" % dt_object, *args) + + successful = False + max_attempts = 5 + ctr = 0 + while not successful and ctr < max_attempts: + try: + with open(self.log_file, 'a+') as f: + for a in args: + f.write(str(a)) + f.write(" ") + f.write("\n") + successful = True + except IOError: + print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) + sleep(0.5) + ctr += 1 + if also_print_to_console: + print(*args) + elif also_print_to_console: + print(*args) + + def print_plans(self): + if self.local_rank == 0: + dct = deepcopy(self.plans_manager.plans) + del dct['configurations'] + self.print_to_log_file(f"\nThis is the configuration used by this " + f"training:\nConfiguration name: {self.configuration_name}\n", + self.configuration_manager, '\n', add_timestamp=False) + self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, 
self.initial_lr, self.num_epochs, exponent=3.) + return optimizer, lr_scheduler + + def plot_network_architecture(self): + if self.local_rank == 0: + try: + # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') + # pip install git+https://github.com/saugatkandel/hiddenlayer.git + + # from torchviz import make_dot + # # not viable. + # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, + # *self.configuration_manager.patch_size), + # device=self.device)))).render( + # join(self.output_folder, "network_architecture.pdf"), format='pdf') + # self.optimizer.zero_grad() + + # broken. + + import hiddenlayer as hl + g = hl.build_graph(self.network, + torch.rand((1, self.num_input_channels, + *self.configuration_manager.patch_size), + device=self.device), + transforms=None) + g.save(join(self.output_folder, "network_architecture.pdf")) + del g + except Exception as e: + self.print_to_log_file("Unable to plot network architecture:") + self.print_to_log_file(e) + + # self.print_to_log_file("\nprinting the network instead:\n") + # self.print_to_log_file(self.network) + # self.print_to_log_file("\n") + finally: + empty_cache(self.device) + + def do_split(self): + """ + The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, + so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + Sometimes you may want to create your own split for various reasons. For this you will need to create your own + splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) + and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to + use a random 80:20 data split. + :return: + """ + if self.fold == "all": + # if fold==all then we use all images for training and validation + case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder) + tr_keys = case_identifiers + val_keys = tr_keys + else: + splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") + dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None, + num_images_properties_loading_threshold=0, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) + # if the split file does not exist we need to create it + if not isfile(splits_file): + self.print_to_log_file("Creating new 5-fold cross-validation split...") + splits = [] + all_keys_sorted = np.sort(list(dataset.keys())) + kfold = KFold(n_splits=5, shuffle=True, random_state=12345) + for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): + train_keys = np.array(all_keys_sorted)[train_idx] + test_keys = np.array(all_keys_sorted)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + save_json(splits, splits_file) + + else: + self.print_to_log_file("Using splits from existing split file:", splits_file) + splits = load_json(splits_file) + self.print_to_log_file("The split file contains %d splits." % len(splits)) + + self.print_to_log_file("Desired fold for training: %d" % self.fold) + if self.fold < len(splits): + tr_keys = splits[self.fold]['train'] + val_keys = splits[self.fold]['val'] + self.print_to_log_file("This split has %d training and %d validation cases." 
+ % (len(tr_keys), len(val_keys))) + else: + self.print_to_log_file("INFO: You requested fold %d for training but splits " + "contain only %d folds. I am now creating a " + "random (but seeded) 80:20 split!" % (self.fold, len(splits))) + # if we request a fold that is not in the split file, create a random 80:20 split + rnd = np.random.RandomState(seed=12345 + self.fold) + keys = np.sort(list(dataset.keys())) + idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) + idx_val = [i for i in range(len(keys)) if i not in idx_tr] + tr_keys = [keys[i] for i in idx_tr] + val_keys = [keys[i] for i in idx_val] + self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + if any([i in val_keys for i in tr_keys]): + self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' + 'splits.json or ignore if this is intentional.') + return tr_keys, val_keys + + def get_tr_and_val_datasets(self): + # create dataset split + tr_keys, val_keys = self.do_split() + + # load the datasets for training and validation. Note that we always draw random samples so we really don't + # care about distributing training cases across GPUs. + dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + return dataset_tr, dataset_val + + def get_dataloaders(self): + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? 
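+        # As an illustration (example pooling schedule assumed, not read from this challenge's plans),
+        # _get_deep_supervision_scales() derives exactly these downscale factors from the cumulative
+        # product of the pooling kernel sizes and drops the lowest resolution:
+        #
+        #     import numpy as np
+        #     pool_op_kernel_sizes = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]   # assumed example
+        #     scales = list(list(i) for i in 1 / np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0))[:-1]
+        #     # -> [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]]
+        #
+        # i.e. the full-resolution output keeps the original target and each deeper output gets it
+        # downsampled by the accumulated pooling factor.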
+ deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=1, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, + num_processes=allowed_num_processes, num_cached=6, seeds=None, + pin_memory=self.device.type == 'cuda', wait_time=0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, + transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', + wait_time=0.02) + return mt_gen_train, mt_gen_val + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + return dl_tr, dl_val + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, 
...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + tr_transforms = [] + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append(SpatialTransform( + patch_size_spatial, patch_center_dist_from_border=None, + do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0), + do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'], + p_rot_per_axis=1, # todo experiment with this + do_scale=True, scale=(0.7, 1.4), + border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data, + border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg, + random_crop=False, # random cropping is part of our dataloaders + p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2, + independent_scale_for_each_axis=False # todo experiment with this + )) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, + p_per_channel=0.5)) + tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) + tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) + tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, + p_per_channel=0.5, + order_downsample=0, order_upsample=3, p_per_sample=0.25, + ignore_axes=ignore_axes)) + tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' + tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + 
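+    # A rough usage sketch for the pipeline built above (illustrative only; the array shapes and the single
+    # GaussianNoiseTransform are assumed for the example, not taken from any plans file). batchgenerators
+    # transforms are plain callables over a dict with 'data'/'seg' numpy arrays of shape (b, c, x, y, z):
+    #
+    #     import numpy as np
+    #     from batchgenerators.transforms.abstract_transforms import Compose
+    #     from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform
+    #
+    #     tr_transforms = Compose([GaussianNoiseTransform(p_per_sample=0.1)])
+    #     batch = {'data': np.random.rand(2, 1, 32, 64, 64).astype(np.float32),
+    #              'seg': np.zeros((2, 1, 32, 64, 64), dtype=np.int16)}
+    #     augmented = tr_transforms(**batch)   # same keys back, 'data' now carries the augmentation
+    #
+    # In training the full Compose is not called directly but handed to SingleThreadedAugmenter or
+    # LimitedLenWrapper in get_dataloaders, which apply it to every batch drawn from the dataloader.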
+ @staticmethod + def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + val_transforms = [] + val_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + + val_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + + val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + val_transforms = Compose(val_transforms) + return val_transforms + + def set_deep_supervision_enabled(self, enabled: bool): + """ + This function is specific for the default architecture in nnU-Net. If you change the architecture, there are + chances you need to change this as well! + """ + if self.is_ddp: + self.network.module.decoder.deep_supervision = enabled + else: + self.network.decoder.deep_supervision = enabled + + def on_train_start(self): + if not self.was_initialized: + self.initialize() + + maybe_mkdir_p(self.output_folder) + + # make sure deep supervision is on in the network + self.set_deep_supervision_enabled(True) + + self.print_plans() + empty_cache(self.device) + + # maybe unpack + if self.unpack_dataset and self.local_rank == 0: + self.print_to_log_file('unpacking dataset...') + unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, + num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) + self.print_to_log_file('unpacking done...') + + if self.is_ddp: + dist.barrier() + + # dataloaders must be instantiated here because they need access to the training data which may not be present + # when doing inference + self.dataloader_train, self.dataloader_val = self.get_dataloaders() + + # copy plans and dataset.json so that they can be used for restoring everything we need for inference + save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) + save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) + + # we don't really need the fingerprint but its still handy to have it with the others + shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), + join(self.output_folder_base, 'dataset_fingerprint.json')) + + # produces a pdf in output folder + self.plot_network_architecture() + + self._save_debug_information() + + # print(f"batch size: {self.batch_size}") + # print(f"oversample: {self.oversample_foreground_percent}") + + def on_train_end(self): + self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) + # now we can delete latest + if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): + os.remove(join(self.output_folder, "checkpoint_latest.pth")) + + # shut down dataloaders + old_stdout = sys.stdout + with open(os.devnull, 'w') as f: + sys.stdout = f + if self.dataloader_train is not None: + self.dataloader_train._finish() + if self.dataloader_val is not None: + 
self.dataloader_val._finish() + sys.stdout = old_stdout + + empty_cache(self.device) + self.print_to_log_file("Training done.") + + def on_train_epoch_start(self): + self.network.train() + self.lr_scheduler.step(self.current_epoch) + self.print_to_log_file('') + self.print_to_log_file(f'Epoch {self.current_epoch}') + self.print_to_log_file( + f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") + # lrs are the same for all workers so we don't need to gather them in case of DDP training + self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + # del data + l = self.loss(output, target) + + if self.grad_scaler is not None: + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + return {'loss': l.detach().cpu().numpy()} + + def on_train_epoch_end(self, train_outputs: List[dict]): + outputs = collate_outputs(train_outputs) + + if self.is_ddp: + losses_tr = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(losses_tr, outputs['loss']) + loss_here = np.vstack(losses_tr).mean() + else: + loss_here = np.mean(outputs['loss']) + + self.logger.log('train_losses', loss_here, self.current_epoch) + + def on_validation_epoch_start(self): + self.network.eval() + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # we only need the output with the highest output resolution + output = output[0] + target = target[0] + + # the following is needed for online evaluation. 
Fake dice (green line) + axes = [0] + list(range(2, len(output.shape))) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def on_validation_epoch_end(self, val_outputs: List[dict]): + outputs_collated = collate_outputs(val_outputs) + tp = np.sum(outputs_collated['tp_hard'], 0) + fp = np.sum(outputs_collated['fp_hard'], 0) + fn = np.sum(outputs_collated['fn_hard'], 0) + + if self.is_ddp: + world_size = dist.get_world_size() + + tps = [None for _ in range(world_size)] + dist.all_gather_object(tps, tp) + tp = np.vstack([i[None] for i in tps]).sum(0) + + fps = [None for _ in range(world_size)] + dist.all_gather_object(fps, fp) + fp = np.vstack([i[None] for i in fps]).sum(0) + + fns = [None for _ in range(world_size)] + dist.all_gather_object(fns, fn) + fn = np.vstack([i[None] for i in fns]).sum(0) + + losses_val = [None for _ in range(world_size)] + dist.all_gather_object(losses_val, outputs_collated['loss']) + loss_here = np.vstack(losses_val).mean() + else: + loss_here = np.mean(outputs_collated['loss']) + + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in + zip(tp, fp, fn)]] + mean_fg_dice = np.nanmean(global_dc_per_class) + self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) + self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) + self.logger.log('val_losses', loss_here, self.current_epoch) + + def on_epoch_start(self): + self.logger.log('epoch_start_timestamps', time(), self.current_epoch) + + def on_epoch_end(self): + self.logger.log('epoch_end_timestamps', time(), self.current_epoch) + + # todo find a solution for this stupid shit + self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) + self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) + self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in + self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) + self.print_to_log_file( + f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - 
self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") + + # handling periodic checkpointing + current_epoch = self.current_epoch + if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): + self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) + + # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this + if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: + self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] + self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") + self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) + + if self.local_rank == 0: + self.logger.plot_progress_png(self.output_folder) + + self.current_epoch += 1 + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.state_dict(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. 
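+        # The renaming loop above covers checkpoints written from a DDP-wrapped network, whose keys carry a
+        # 'module.' prefix. A minimal standalone sketch of the same idea (file name assumed for illustration):
+        #
+        #     ckpt = torch.load('checkpoint_final.pth', map_location='cpu')
+        #     weights = {k[7:] if k.startswith('module.') else k: v
+        #                for k, v in ckpt['network_weights'].items()}
+        #     # network.load_state_dict(weights)
+        #
+        # The branches below then pick the matching target (DDP module, torch.compile OptimizedModule or the
+        # plain network) before loading the cleaned state dict.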
+ if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_state_dict(new_state_dict) + else: + self.network.module.load_state_dict(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_state_dict(new_state_dict) + else: + self.network.load_state_dict(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + def perform_actual_validation(self, save_probabilities: bool = False): + self.set_deep_supervision_enabled(False) + self.network.eval() + + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + perform_everything_on_gpu=True, device=self.device, verbose=False, + verbose_preprocessing=False, allow_tqdm=False) + predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, + self.dataset_json, self.__class__.__name__, + self.inference_allowed_mirroring_axes) + + with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: + worker_list = [i for i in segmentation_export_pool._pool] + validation_output_folder = join(self.output_folder, 'validation') + maybe_mkdir_p(validation_output_folder) + + # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute + # the validation keys across the workers. + _, val_keys = self.do_split() + if self.is_ddp: + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + + next_stages = self.configuration_manager.next_stage_names + + if next_stages is not None: + _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] + + results = [] + + for k in dataset_val.keys(): + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + while not proceed: + sleep(0.1) + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + + self.print_to_log_file(f"predicting {k}") + data, seg, properties = dataset_val.load_case(k) + + if self.is_cascaded: + data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, + output_dtype=data.dtype))) + with warnings.catch_warnings(): + # ignore 'The given NumPy array is not writable' warning + warnings.simplefilter("ignore") + data = torch.from_numpy(data) + + output_filename_truncated = join(validation_output_folder, k) + + try: + prediction = predictor.predict_sliding_window_return_logits(data) + except RuntimeError: + predictor.perform_everything_on_gpu = False + prediction = predictor.predict_sliding_window_return_logits(data) + predictor.perform_everything_on_gpu = True + + prediction = prediction.cpu() + + # this needs to go into background processes + results.append( + segmentation_export_pool.starmap_async( + export_prediction_from_logits, ( + (prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, output_filename_truncated, save_probabilities), + ) + ) + ) + # for debug purposes + # export_prediction(prediction_for_export, properties, self.configuration, 
self.plans, self.dataset_json, + # output_filename_truncated, save_probabilities) + + # if needed, export the softmax prediction for the next stage + if next_stages is not None: + for n in next_stages: + next_stage_config_manager = self.plans_manager.get_configuration(n) + expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, + next_stage_config_manager.data_identifier) + + try: + # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented + tmp = nnUNetDataset(expected_preprocessed_folder, [k], + num_images_properties_loading_threshold=0) + d, s, p = tmp.load_case(k) + except FileNotFoundError: + self.print_to_log_file( + f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " + f"Run the preprocessing for this configuration first!") + continue + + target_shape = d.shape[1:] + output_folder = join(self.output_folder_base, 'predicted_next_stage', n) + output_file = join(output_folder, k + '.npz') + + # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, + # self.dataset_json) + results.append(segmentation_export_pool.starmap_async( + resample_and_save, ( + (prediction, target_shape, output_file, self.plans_manager, + self.configuration_manager, + properties, + self.dataset_json), + ) + )) + + _ = [r.get() for r in results] + + if self.is_ddp: + dist.barrier() + + if self.local_rank == 0: + metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), + validation_output_folder, + join(validation_output_folder, 'summary.json'), + self.plans_manager.image_reader_writer_class(), + self.dataset_json["file_ending"], + self.label_manager.foreground_regions if self.label_manager.has_regions else + self.label_manager.foreground_labels, + self.label_manager.ignore_label, chill=True) + self.print_to_log_file("Validation complete", also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + + self.set_deep_supervision_enabled(True) + compute_gaussian.cache_clear() + + def run_training(self): + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(next(self.dataloader_train))) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(next(self.dataloader_val))) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py new file mode 100644 index 0000000..8b376eb --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_aghiles.py @@ -0,0 +1,1181 @@ +import inspect +import multiprocessing +import os +import shutil +import sys +import warnings +from copy import deepcopy +from datetime import datetime +from time import time, sleep +from typing import Union, Tuple, List + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms 
import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ + ContrastAugmentationTransform, GammaTransform +from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform +from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p +from torch._dynamo import OptimizedModule + +from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.inference.sliding_window_prediction import compute_gaussian +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \ + Convert3DTo2DTransform +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger +from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from sklearn.model_selection import KFold +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from torch import autocast, nn +from torch import 
distributed as dist +from torch.cuda import device_count +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP + + +class nnUNetTrainer_aghiles(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + # From https://grugbrain.dev/. Worth a read ya big brains ;-) + + # apex predator of grug is complexity + # complexity bad + # say again: + # complexity very bad + # you say now: + # complexity very, very bad + # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex + # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime + # one day code base understandable and grug can get work done, everything good! + # next day impossible: complexity demon spirit has entered code and very dangerous situation! + + # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz + + self.is_ddp = dist.is_available() and dist.is_initialized() + self.local_rank = 0 if not self.is_ddp else dist.get_rank() + + self.device = device + + # print what device we are using + if self.is_ddp: # implicitly it's clear that we use cuda in this case + print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is " + f"{dist.get_world_size()}." + f"Setting device to {self.device}") + self.device = torch.device(type='cuda', index=self.local_rank) + else: + if self.device.type == 'cuda': + # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X + self.device = torch.device(type='cuda', index=0) + print(f"Using device: {self.device}") + + # loading and saving this class for continuing from checkpoint should not happen based on pickling. This + # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we + # need. So let's save the init args + self.my_init_kwargs = {} + for k in inspect.signature(self.__init__).parameters.keys(): + self.my_init_kwargs[k] = locals()[k] + + ### Saving all the init args into class variables for later access + self.plans_manager = PlansManager(plans) + self.configuration_manager = self.plans_manager.get_configuration(configuration) + self.configuration_name = configuration + self.dataset_json = dataset_json + self.fold = fold + self.unpack_dataset = unpack_dataset + + ### Setting all the folder names. We need to make sure things don't crash in case we are just running + # inference and some of the folders may not be defined! + self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ + if nnUNet_preprocessed is not None else None + self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ + if nnUNet_results is not None else None + self.output_folder = join(self.output_folder_base, f'fold_{fold}') + + self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, + self.configuration_manager.data_identifier) + # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to + # be a different configuration in the same plans + # IMPORTANT! 
the mapping must be bijective, so lowres must point to fullres and vice versa (using + # "previous_stage" and "next_stage"). Otherwise it won't work! + self.is_cascaded = self.configuration_manager.previous_stage_name is not None + self.folder_with_segs_from_previous_stage = \ + join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ + if self.is_cascaded else None + + ### Some hyperparameters for you to fiddle with + self.initial_lr = 1e-2 + self.weight_decay = 3e-5 + self.oversample_foreground_percent = 0.33 + self.num_iterations_per_epoch = 50 + self.num_val_iterations_per_epoch = 30 + self.num_epochs = 40 + self.current_epoch = 0 + + ### Dealing with labels/regions + self.label_manager = self.plans_manager.get_label_manager(dataset_json) + # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) + # needed for predictions. We do sigmoid in case of (overlapping) regions + + self.num_input_channels = None # -> self.initialize() + self.network = None # -> self._get_network() + self.optimizer = self.lr_scheduler = None # -> self.initialize + self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None + self.loss = None # -> self.initialize + + ### Simple logging. Don't take that away from me! + # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning + # logging + timestamp = datetime.now() + maybe_mkdir_p(self.output_folder) + self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % + (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, + timestamp.second)) + self.logger = nnUNetLogger() + + ### placeholders + self.dataloader_train = self.dataloader_val = None # see on_train_start + + ### initializing stuff for remembering things and such + self._best_ema = None + + ### inference things + self.inference_allowed_mirroring_axes = None # this variable is set in + # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints + + ### checkpoint saving stuff + self.save_every = 50 + self.disable_checkpointing = False + + ## DDP batch size and oversampling can differ between workers and needs adaptation + # we need to change the batch size in DDP because we don't use any of those distributed samplers + self._set_batch_size_and_oversample() + + self.was_initialized = False + + self.print_to_log_file("\n#######################################################################\n" + "Please cite the following paper when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
" + "Nature methods, 18(2), 203-211.\n" + "#######################################################################\n", + also_print_to_console=True, add_timestamp=False) + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, + self.configuration_manager, + self.num_input_channels, + enable_deep_supervision=False).to(self.device) + # compile network for free speedup + if ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')): + self.print_to_log_file('Compiling network...') + self.network = torch.compile(self.network) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " + "That should not happen.") + + def _save_debug_information(self): + # saving some debug information + if self.local_rank == 0: + dct = {} + for k in self.__dir__(): + if not k.startswith("__"): + if not callable(getattr(self, k)) or k in ['loss', ]: + dct[k] = str(getattr(self, k)) + elif k in ['network', ]: + dct[k] = str(getattr(self, k).__class__.__name__) + else: + # print(k) + pass + if k in ['dataloader_train', 'dataloader_val']: + if hasattr(getattr(self, k), 'generator'): + dct[k + '.generator'] = str(getattr(self, k).generator) + if hasattr(getattr(self, k), 'num_processes'): + dct[k + '.num_processes'] = str(getattr(self, k).num_processes) + if hasattr(getattr(self, k), 'transform'): + dct[k + '.transform'] = str(getattr(self, k).transform) + import subprocess + hostname = subprocess.getoutput(['hostname']) + dct['hostname'] = hostname + torch_version = torch.__version__ + if self.device.type == 'cuda': + gpu_name = torch.cuda.get_device_name() + dct['gpu_name'] = gpu_name + cudnn_version = torch.backends.cudnn.version() + else: + cudnn_version = 'None' + dct['device'] = str(self.device) + dct['torch_version'] = torch_version + dct['cudnn_version'] = cudnn_version + save_json(dct, join(self.output_folder, "debug.json")) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = False) -> nn.Module: + """ + his is where you build the architecture according to the plans. There is no obligation to use + get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what + you want. Even ignore the plans and just return something static (as long as it can process the requested + patch size) + but don't bug us with your bugs arising from fiddling with this :-P + This is the function that is called in inference as well! This is needed so that all network architecture + variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for + training, so if you change the network architecture during training by deriving a new trainer class then + inference will know about it). 
+ + If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: + > label_manager = plans_manager.get_label_manager(dataset_json) + > label_manager.num_segmentation_heads + (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, + the number of outputs is != the number of classes. Also there is the ignore label for which no output + should be generated. label_manager takes care of all that for you.) + + """ + return get_network_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + def _get_deep_supervision_scales(self): + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + return deep_supervision_scales + + def _set_batch_size_and_oversample(self): + if not self.is_ddp: + # set batch size to what the plan says, leave oversample untouched + self.batch_size = self.configuration_manager.batch_size + else: + # batch size is distributed over DDP workers and we need to change oversample_percent for each worker + batch_sizes = [] + oversample_percents = [] + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + global_batch_size = self.configuration_manager.batch_size + assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ + 'GPUs... Duh.' + + batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) + + for rank in range(world_size): + if (rank + 1) * batch_size_per_GPU > global_batch_size: + batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) + else: + batch_size = batch_size_per_GPU + + batch_sizes.append(batch_size) + + sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) + sample_id_high = np.sum(batch_sizes) + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percents.append(0.0) + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percents.append(1.0) + else: + percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size + oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - + sample_id_low / global_batch_size) / percent_covered_by_this_rank) + oversample_percents.append(oversample_percent_here) + + print("worker", my_rank, "oversample", oversample_percents[my_rank]) + print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) + # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) + + self.batch_size = batch_sizes[my_rank] + self.oversample_foreground_percent = oversample_percents[my_rank] + + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) + + deep_supervision_scales = self._get_deep_supervision_scales() + + 
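+ # Worked example (illustrative, assuming four deep supervision outputs): the weighting computed next is
+ #   raw weights    = [1, 0.5, 0.25, 0.125]        # 1 / 2**i
+ #   lowest zeroed  = [1, 0.5, 0.25, 0]
+ #   normalized     = [0.571, 0.286, 0.143, 0]     # divided by 1.75 so the weights sum to 1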
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.85, 1.25)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): + if self.local_rank == 0: + timestamp = time() + dt_object = datetime.fromtimestamp(timestamp) + + if add_timestamp: + args = ("%s:" % dt_object, *args) + + successful = False + max_attempts = 5 + ctr = 0 + while not successful and ctr < max_attempts: + try: + with open(self.log_file, 'a+') as f: + for a in args: + f.write(str(a)) + f.write(" ") + f.write("\n") + successful = True + except IOError: + print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) + sleep(0.5) + ctr += 1 + if also_print_to_console: + print(*args) + elif also_print_to_console: + print(*args) + + def print_plans(self): + if self.local_rank == 0: + dct = deepcopy(self.plans_manager.plans) + del dct['configurations'] + self.print_to_log_file(f"\nThis is the configuration used by this " + f"training:\nConfiguration name: {self.configuration_name}\n", + self.configuration_manager, '\n', add_timestamp=False) + self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=3.) + return optimizer, lr_scheduler + + def plot_network_architecture(self): + if self.local_rank == 0: + try: + # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') + # pip install git+https://github.com/saugatkandel/hiddenlayer.git + + # from torchviz import make_dot + # # not viable. + # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, + # *self.configuration_manager.patch_size), + # device=self.device)))).render( + # join(self.output_folder, "network_architecture.pdf"), format='pdf') + # self.optimizer.zero_grad() + + # broken. + + import hiddenlayer as hl + g = hl.build_graph(self.network, + torch.rand((1, self.num_input_channels, + *self.configuration_manager.patch_size), + device=self.device), + transforms=None) + g.save(join(self.output_folder, "network_architecture.pdf")) + del g + except Exception as e: + self.print_to_log_file("Unable to plot network architecture:") + self.print_to_log_file(e) + + # self.print_to_log_file("\nprinting the network instead:\n") + # self.print_to_log_file(self.network) + # self.print_to_log_file("\n") + finally: + empty_cache(self.device) + + def do_split(self): + """ + The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, + so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + Sometimes you may want to create your own split for various reasons. For this you will need to create your own + splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + it. You can create as many splits in this file as you want. 
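+ For illustration (hypothetical case names), a splits_final.json with two folds could look like
+ [{"train": ["case_000", "case_001"], "val": ["case_002"]}, {"train": ["case_000", "case_002"], "val": ["case_001"]}].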
Note that if you define only 4 splits (fold 0-3) + and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to + use a random 80:20 data split. + :return: + """ + if self.fold == "all": + # if fold==all then we use all images for training and validation + case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder) + tr_keys = case_identifiers + val_keys = tr_keys + else: + splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") + dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None, + num_images_properties_loading_threshold=0, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) + # if the split file does not exist we need to create it + if not isfile(splits_file): + self.print_to_log_file("Creating new 5-fold cross-validation split...") + splits = [] + all_keys_sorted = np.sort(list(dataset.keys())) + kfold = KFold(n_splits=5, shuffle=True, random_state=12345) + for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): + train_keys = np.array(all_keys_sorted)[train_idx] + test_keys = np.array(all_keys_sorted)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + save_json(splits, splits_file) + + else: + self.print_to_log_file("Using splits from existing split file:", splits_file) + splits = load_json(splits_file) + self.print_to_log_file("The split file contains %d splits." % len(splits)) + + self.print_to_log_file("Desired fold for training: %d" % self.fold) + if self.fold < len(splits): + tr_keys = splits[self.fold]['train'] + val_keys = splits[self.fold]['val'] + self.print_to_log_file("This split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + else: + self.print_to_log_file("INFO: You requested fold %d for training but splits " + "contain only %d folds. I am now creating a " + "random (but seeded) 80:20 split!" % (self.fold, len(splits))) + # if we request a fold that is not in the split file, create a random 80:20 split + rnd = np.random.RandomState(seed=12345 + self.fold) + keys = np.sort(list(dataset.keys())) + idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) + idx_val = [i for i in range(len(keys)) if i not in idx_tr] + tr_keys = [keys[i] for i in idx_tr] + val_keys = [keys[i] for i in idx_val] + self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + if any([i in val_keys for i in tr_keys]): + self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' + 'splits.json or ignore if this is intentional.') + return tr_keys, val_keys + + def get_tr_and_val_datasets(self): + # create dataset split + tr_keys, val_keys = self.do_split() + + # load the datasets for training and validation. Note that we always draw random samples so we really don't + # care about distributing training cases across GPUs. 
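+ # For a non-cascaded configuration such as 3d_fullres, folder_with_segs_from_previous_stage is None
+ # (see __init__), so the datasets below are built from the preprocessed cases alone;
+ # num_images_properties_loading_threshold=0 appears to keep case properties on disk until they are needed.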
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + return dataset_tr, dataset_val + + def get_dataloaders(self): + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=1, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, + num_processes=allowed_num_processes, num_cached=6, seeds=None, + pin_memory=self.device.type == 'cuda', wait_time=0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, + transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', + wait_time=0.02) + return mt_gen_train, mt_gen_val + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size, + initial_patch_size, + 
self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + return dl_tr, dl_val + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + tr_transforms = [] + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + + @staticmethod + def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + val_transforms = [] + val_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + + val_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + + val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + val_transforms = Compose(val_transforms) + return val_transforms + + def set_deep_supervision_enabled(self, enabled: bool): + """ + This function is specific for the default architecture in nnU-Net. If you change the architecture, there are + chances you need to change this as well! 
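+ In particular it assumes the model (or, under DDP, self.network.module) exposes a decoder with a boolean
+ deep_supervision attribute, which is exactly what is toggled below.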
+ """ + if self.is_ddp: + self.network.module.decoder.deep_supervision = enabled + else: + self.network.decoder.deep_supervision = enabled + + def on_train_start(self): + if not self.was_initialized: + self.initialize() + + maybe_mkdir_p(self.output_folder) + + # make sure deep supervision is on in the network + self.set_deep_supervision_enabled(True) + + self.print_plans() + empty_cache(self.device) + + # maybe unpack + if self.unpack_dataset and self.local_rank == 0: + self.print_to_log_file('unpacking dataset...') + unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, + num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) + self.print_to_log_file('unpacking done...') + + if self.is_ddp: + dist.barrier() + + # dataloaders must be instantiated here because they need access to the training data which may not be present + # when doing inference + self.dataloader_train, self.dataloader_val = self.get_dataloaders() + + # copy plans and dataset.json so that they can be used for restoring everything we need for inference + save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) + save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) + + # we don't really need the fingerprint but its still handy to have it with the others + shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), + join(self.output_folder_base, 'dataset_fingerprint.json')) + + # produces a pdf in output folder + self.plot_network_architecture() + + self._save_debug_information() + + # print(f"batch size: {self.batch_size}") + # print(f"oversample: {self.oversample_foreground_percent}") + + def on_train_end(self): + self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) + # now we can delete latest + if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): + os.remove(join(self.output_folder, "checkpoint_latest.pth")) + + # shut down dataloaders + old_stdout = sys.stdout + with open(os.devnull, 'w') as f: + sys.stdout = f + if self.dataloader_train is not None: + self.dataloader_train._finish() + if self.dataloader_val is not None: + self.dataloader_val._finish() + sys.stdout = old_stdout + + empty_cache(self.device) + self.print_to_log_file("Training done.") + + def on_train_epoch_start(self): + self.network.train() + self.lr_scheduler.step(self.current_epoch) + self.print_to_log_file('') + self.print_to_log_file(f'Epoch {self.current_epoch}') + self.print_to_log_file( + f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") + # lrs are the same for all workers so we don't need to gather them in case of DDP training + self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. 
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + # del data + l = self.loss(output, target) + + if self.grad_scaler is not None: + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + return {'loss': l.detach().cpu().numpy()} + + def on_train_epoch_end(self, train_outputs: List[dict]): + outputs = collate_outputs(train_outputs) + + if self.is_ddp: + losses_tr = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(losses_tr, outputs['loss']) + loss_here = np.vstack(losses_tr).mean() + else: + loss_here = np.mean(outputs['loss']) + + self.logger.log('train_losses', loss_here, self.current_epoch) + + def on_validation_epoch_start(self): + self.network.eval() + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # we only need the output with the highest output resolution + output = output[0] + target = target[0] + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, len(output.shape))) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. 
We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def on_validation_epoch_end(self, val_outputs: List[dict]): + outputs_collated = collate_outputs(val_outputs) + tp = np.sum(outputs_collated['tp_hard'], 0) + fp = np.sum(outputs_collated['fp_hard'], 0) + fn = np.sum(outputs_collated['fn_hard'], 0) + + if self.is_ddp: + world_size = dist.get_world_size() + + tps = [None for _ in range(world_size)] + dist.all_gather_object(tps, tp) + tp = np.vstack([i[None] for i in tps]).sum(0) + + fps = [None for _ in range(world_size)] + dist.all_gather_object(fps, fp) + fp = np.vstack([i[None] for i in fps]).sum(0) + + fns = [None for _ in range(world_size)] + dist.all_gather_object(fns, fn) + fn = np.vstack([i[None] for i in fns]).sum(0) + + losses_val = [None for _ in range(world_size)] + dist.all_gather_object(losses_val, outputs_collated['loss']) + loss_here = np.vstack(losses_val).mean() + else: + loss_here = np.mean(outputs_collated['loss']) + + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in + zip(tp, fp, fn)]] + mean_fg_dice = np.nanmean(global_dc_per_class) + self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) + self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) + self.logger.log('val_losses', loss_here, self.current_epoch) + + def on_epoch_start(self): + self.logger.log('epoch_start_timestamps', time(), self.current_epoch) + + def on_epoch_end(self): + self.logger.log('epoch_end_timestamps', time(), self.current_epoch) + + # todo find a solution for this stupid shit + self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) + self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) + self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in + self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) + self.print_to_log_file( + f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") + + # handling periodic checkpointing + current_epoch = self.current_epoch + if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): + self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) + + # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this + if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: + self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] + self.print_to_log_file(f"Yayy! 
New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") + self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) + + if self.local_rank == 0: + self.logger.plot_progress_png(self.output_folder) + + self.current_epoch += 1 + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.state_dict(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. 
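+ # The cleaned state dict (any leading 'module.' prefix was stripped above) is routed to the right wrapper
+ # level: DDP keeps the real model under self.network.module and torch.compile keeps it under the
+ # _orig_mod of an OptimizedModule, so the four combinations are handled explicitly below.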
+ if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_state_dict(new_state_dict) + else: + self.network.module.load_state_dict(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_state_dict(new_state_dict) + else: + self.network.load_state_dict(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + def perform_actual_validation(self, save_probabilities: bool = False): + self.set_deep_supervision_enabled(False) + self.network.eval() + + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + perform_everything_on_gpu=True, device=self.device, verbose=False, + verbose_preprocessing=False, allow_tqdm=False) + predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, + self.dataset_json, self.__class__.__name__, + self.inference_allowed_mirroring_axes) + + with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: + worker_list = [i for i in segmentation_export_pool._pool] + validation_output_folder = join(self.output_folder, 'validation') + maybe_mkdir_p(validation_output_folder) + + # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute + # the validation keys across the workers. + _, val_keys = self.do_split() + if self.is_ddp: + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + + next_stages = self.configuration_manager.next_stage_names + + if next_stages is not None: + _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] + + results = [] + + for k in dataset_val.keys(): + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + while not proceed: + sleep(0.1) + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + + self.print_to_log_file(f"predicting {k}") + data, seg, properties = dataset_val.load_case(k) + + if self.is_cascaded: + data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, + output_dtype=data.dtype))) + with warnings.catch_warnings(): + # ignore 'The given NumPy array is not writable' warning + warnings.simplefilter("ignore") + data = torch.from_numpy(data) + + output_filename_truncated = join(validation_output_folder, k) + + try: + prediction = predictor.predict_sliding_window_return_logits(data) + except RuntimeError: + predictor.perform_everything_on_gpu = False + prediction = predictor.predict_sliding_window_return_logits(data) + predictor.perform_everything_on_gpu = True + + prediction = prediction.cpu() + + # this needs to go into background processes + results.append( + segmentation_export_pool.starmap_async( + export_prediction_from_logits, ( + (prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, output_filename_truncated, save_probabilities), + ) + ) + ) + # for debug purposes + # export_prediction(prediction_for_export, properties, self.configuration, 
self.plans, self.dataset_json, + # output_filename_truncated, save_probabilities) + + # if needed, export the softmax prediction for the next stage + if next_stages is not None: + for n in next_stages: + next_stage_config_manager = self.plans_manager.get_configuration(n) + expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, + next_stage_config_manager.data_identifier) + + try: + # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented + tmp = nnUNetDataset(expected_preprocessed_folder, [k], + num_images_properties_loading_threshold=0) + d, s, p = tmp.load_case(k) + except FileNotFoundError: + self.print_to_log_file( + f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " + f"Run the preprocessing for this configuration first!") + continue + + target_shape = d.shape[1:] + output_folder = join(self.output_folder_base, 'predicted_next_stage', n) + output_file = join(output_folder, k + '.npz') + + # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, + # self.dataset_json) + results.append(segmentation_export_pool.starmap_async( + resample_and_save, ( + (prediction, target_shape, output_file, self.plans_manager, + self.configuration_manager, + properties, + self.dataset_json), + ) + )) + + _ = [r.get() for r in results] + + if self.is_ddp: + dist.barrier() + + if self.local_rank == 0: + metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), + validation_output_folder, + join(validation_output_folder, 'summary.json'), + self.plans_manager.image_reader_writer_class(), + self.dataset_json["file_ending"], + self.label_manager.foreground_regions if self.label_manager.has_regions else + self.label_manager.foreground_labels, + self.label_manager.ignore_label, chill=True) + self.print_to_log_file("Validation complete", also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + + self.set_deep_supervision_enabled(True) + compute_gaussian.cache_clear() + + def run_training(self): + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(next(self.dataloader_train))) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(next(self.dataloader_val))) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py new file mode 100644 index 0000000..51b8f0d --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_autopet.py @@ -0,0 +1,1339 @@ +import inspect +import multiprocessing +import os +import shutil +import sys +import warnings +from copy import deepcopy +from datetime import datetime +from time import time, sleep +from typing import Union, Tuple, List + +import numpy as np +import torch +from torch.nn import BCEWithLogitsLoss +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from 
batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ + ContrastAugmentationTransform, GammaTransform +from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform +from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p +from torch._dynamo import OptimizedModule + +from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.inference.sliding_window_prediction import compute_gaussian +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \ + Convert3DTo2DTransform +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger +from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from sklearn.model_selection import KFold +from torch import autocast, nn +from torch import distributed as dist +from 
torch.cuda import device_count +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from monai.networks.nets import ViT +from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim + + +class AutoPETNet(nn.Module): + def __init__(self, encoder, decoder, cl_a, cl_c, cl_s, classifier, fs_a, fs_c, fs_s): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.cl_a = cl_a + self.cl_c = cl_c + self.cl_s = cl_s + self.classifier = classifier + self.fs_a = fs_a + self.fs_c = fs_c + self.fs_s = fs_s + self.hidden_size = 768 + spatial_dims = 3 + self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims)) + + def proj_feat(self, x, feat_size): + self.proj_view_shape = list(feat_size) + [self.hidden_size] + new_view = [x.size(0)] + self.proj_view_shape + x = x.view(new_view) + x = x.permute(self.proj_axes).contiguous() + return x + + def forward(self, x): + + mip_axial = torch.cat([torch.max(x[idx].unsqueeze(0), 4, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0) + mip_coro = torch.cat([torch.max(x[idx].unsqueeze(0), 3, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0) + mip_sagi = torch.cat([torch.max(x[idx].unsqueeze(0), 2, keepdim=True)[0] for idx in range(np.shape(x)[0])], dim=0) + skips = self.encoder(x) + output = self.decoder(skips) + feature_a, hs_a = self.cl_a(mip_axial) + feature_c, hs_c = self.cl_c(mip_coro) + feature_s, hs_s = self.cl_s(mip_sagi) + features = torch.nn.AvgPool3d((4, 4, 4))(skips[-1]) + feature_a = torch.nn.AvgPool3d((8, 8, 1))(self.proj_feat(feature_a, self.fs_a)) + feature_c = torch.nn.AvgPool3d((8, 1, 8))(self.proj_feat(feature_c, self.fs_c)) + feature_s = torch.nn.AvgPool3d((1, 8, 8))(self.proj_feat(feature_s, self.fs_s)) + all_features = torch.cat([features, feature_a, feature_c, feature_s], dim=1).squeeze(-1).squeeze(-1).squeeze(-1) + classif = torch.softmax(self.classifier(all_features), dim=1) + if self.training: + return output, classif + else: + if isinstance(output, list): + for i in range(output[0].shape[0]): + output[0][i] = output[0][i] * torch.argmax(classif, dim=1)[i] + return output + else: + for i in range(output.shape[0]): + output[i] = output[i] * torch.argmax(classif, dim=1)[i] + return output + + def compute_conv_feature_map_size(self, input_size): + # Ok this is not good. Later ? + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +class nnUNetTrainer_autopet(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + # From https://grugbrain.dev/. 
Worth a read ya big brains ;-) + + # apex predator of grug is complexity + # complexity bad + # say again: + # complexity very bad + # you say now: + # complexity very, very bad + # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex + # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime + # one day code base understandable and grug can get work done, everything good! + # next day impossible: complexity demon spirit has entered code and very dangerous situation! + + # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz + + self.is_ddp = dist.is_available() and dist.is_initialized() + self.local_rank = 0 if not self.is_ddp else dist.get_rank() + + self.device = device + + # print what device we are using + if self.is_ddp: # implicitly it's clear that we use cuda in this case + print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is " + f"{dist.get_world_size()}." + f"Setting device to {self.device}") + self.device = torch.device(type='cuda', index=self.local_rank) + else: + if self.device.type == 'cuda': + # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X + self.device = torch.device(type='cuda', index=0) + print(f"Using device: {self.device}") + + # loading and saving this class for continuing from checkpoint should not happen based on pickling. This + # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we + # need. So let's save the init args + self.my_init_kwargs = {} + for k in inspect.signature(self.__init__).parameters.keys(): + self.my_init_kwargs[k] = locals()[k] + + ### Saving all the init args into class variables for later access + self.plans_manager = PlansManager(plans) + self.configuration_manager = self.plans_manager.get_configuration(configuration) + self.configuration_name = configuration + self.dataset_json = dataset_json + self.fold = fold + self.unpack_dataset = unpack_dataset + + ### Setting all the folder names. We need to make sure things don't crash in case we are just running + # inference and some of the folders may not be defined! + self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ + if nnUNet_preprocessed is not None else None + self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ + if nnUNet_results is not None else None + self.output_folder = join(self.output_folder_base, f'fold_{fold}') + + self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, + self.configuration_manager.data_identifier) + # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to + # be a different configuration in the same plans + # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using + # "previous_stage" and "next_stage"). Otherwise it won't work! 
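+ # previous_stage_name is only set for cascaded plans (e.g. 3d_lowres feeding 3d_cascade_fullres); for a
+ # standalone 3d_fullres configuration it is None, so is_cascaded ends up False and no previous-stage
+ # segmentations are loaded.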
+ self.is_cascaded = self.configuration_manager.previous_stage_name is not None + self.folder_with_segs_from_previous_stage = \ + join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ + if self.is_cascaded else None + + ### Some hyperparameters for you to fiddle with + self.initial_lr = 1e-4 + self.weight_decay = 3e-5 + self.oversample_foreground_percent = 0.33 + self.num_iterations_per_epoch = 825 + self.num_val_iterations_per_epoch = 200 + self.num_epochs = 250 + self.current_epoch = 0 + + ### Dealing with labels/regions + self.label_manager = self.plans_manager.get_label_manager(dataset_json) + # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) + # needed for predictions. We do sigmoid in case of (overlapping) regions + + self.num_input_channels = None # -> self.initialize() + self.network = None # -> self._get_network() + self.optimizer = self.lr_scheduler = None # -> self.initialize + self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None + self.loss = None # -> self.initialize + + ### Simple logging. Don't take that away from me! + # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning + # logging + timestamp = datetime.now() + maybe_mkdir_p(self.output_folder) + self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % + (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, + timestamp.second)) + self.logger = nnUNetLogger() + + ### placeholders + self.dataloader_train = self.dataloader_val = None # see on_train_start + + ### initializing stuff for remembering things and such + self._best_ema = None + + ### inference things + self.inference_allowed_mirroring_axes = None # this variable is set in + # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints + + ### checkpoint saving stuff + self.save_every = 50 + self.disable_checkpointing = False + + ## DDP batch size and oversampling can differ between workers and needs adaptation + # we need to change the batch size in DDP because we don't use any of those distributed samplers + self._set_batch_size_and_oversample() + self.hidden_size = 768 + spatial_dims = 3 + self.proj_axes = (0, spatial_dims + 1) + tuple(d + 1 for d in range(spatial_dims)) + + self.was_initialized = False + + self.print_to_log_file("\n#######################################################################\n" + "Please cite the following paper when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
" + "Nature methods, 18(2), 203-211.\n" + "#######################################################################\n", + also_print_to_console=True, add_timestamp=False) + + def proj_feat(self, x, feat_size): + self.proj_view_shape = list(feat_size) + [self.hidden_size] + new_view = [x.size(0)] + self.proj_view_shape + x = x.view(new_view) + x = x.permute(self.proj_axes).contiguous() + return x + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, + self.configuration_manager, + self.num_input_channels, + enable_deep_supervision=True).to(self.device) + # compile network for free speedup + if ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')): + self.print_to_log_file('Compiling network...') + self.network = torch.compile(self.network) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " + "That should not happen.") + + def _save_debug_information(self): + # saving some debug information + if self.local_rank == 0: + dct = {} + for k in self.__dir__(): + if not k.startswith("__"): + if not callable(getattr(self, k)) or k in ['loss', ]: + dct[k] = str(getattr(self, k)) + elif k in ['network', ]: + dct[k] = str(getattr(self, k).__class__.__name__) + else: + # print(k) + pass + if k in ['dataloader_train', 'dataloader_val']: + if hasattr(getattr(self, k), 'generator'): + dct[k + '.generator'] = str(getattr(self, k).generator) + if hasattr(getattr(self, k), 'num_processes'): + dct[k + '.num_processes'] = str(getattr(self, k).num_processes) + if hasattr(getattr(self, k), 'transform'): + dct[k + '.transform'] = str(getattr(self, k).transform) + import subprocess + hostname = subprocess.getoutput(['hostname']) + dct['hostname'] = hostname + torch_version = torch.__version__ + if self.device.type == 'cuda': + gpu_name = torch.cuda.get_device_name() + dct['gpu_name'] = gpu_name + cudnn_version = torch.backends.cudnn.version() + else: + cudnn_version = 'None' + dct['device'] = str(self.device) + dct['torch_version'] = torch_version + dct['cudnn_version'] = cudnn_version + save_json(dct, join(self.output_folder, "debug.json")) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + """ + his is where you build the architecture according to the plans. There is no obligation to use + get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what + you want. Even ignore the plans and just return something static (as long as it can process the requested + patch size) + but don't bug us with your bugs arising from fiddling with this :-P + This is the function that is called in inference as well! 
This is needed so that all network architecture + variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for + training, so if you change the network architecture during training by deriving a new trainer class then + inference will know about it). + + If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: + > label_manager = plans_manager.get_label_manager(dataset_json) + > label_manager.num_segmentation_heads + (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, + the number of outputs is != the number of classes. Also there is the ignore label for which no output + should be generated. label_manager takes care of all that for you.) + + """ + + axial_ps = (configuration_manager.patch_size[0], configuration_manager.patch_size[1], 1) + coro_ps = (configuration_manager.patch_size[0], 1, configuration_manager.patch_size[2]) + sagi_ps = (1, configuration_manager.patch_size[1], configuration_manager.patch_size[2]) + feat_size_axial = tuple(img_d // p_d for img_d, p_d in zip(axial_ps, (16, 16, 1))) + feat_size_coro = tuple(img_d // p_d for img_d, p_d in zip(coro_ps, (16, 1, 16))) + feat_size_sagi = tuple(img_d // p_d for img_d, p_d in zip(sagi_ps, (1, 16, 16))) + model_classiff_axial = ViT(in_channels=num_input_channels, + img_size=axial_ps, + patch_size=(16, 16, 1), classification=False) + model_classiff_coro = ViT(in_channels=num_input_channels, + img_size=coro_ps, + patch_size=(16, 1, 16), classification=False) + model_classiff_sagi = ViT(in_channels=num_input_channels, + img_size=sagi_ps, + patch_size=(1, 16, 16), classification=False) + classifier = nn.Linear(3 * 768 + configuration_manager.unet_max_num_features, 2) + network = get_network_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + net = AutoPETNet(network.encoder, network.decoder, model_classiff_axial, model_classiff_coro, model_classiff_sagi, + classifier, feat_size_axial, feat_size_coro, feat_size_sagi) + return net + + def _get_deep_supervision_scales(self): + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + return deep_supervision_scales + + def _set_batch_size_and_oversample(self): + if not self.is_ddp: + # set batch size to what the plan says, leave oversample untouched + self.batch_size = self.configuration_manager.batch_size + else: + # batch size is distributed over DDP workers and we need to change oversample_percent for each worker + batch_sizes = [] + oversample_percents = [] + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + global_batch_size = self.configuration_manager.batch_size + assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ + 'GPUs... Duh.' 
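+ # Worked example (illustrative numbers): global_batch_size=5, world_size=2, oversample_foreground_percent=0.33
+ #   batch_size_per_GPU = ceil(5 / 2) = 3 -> rank 0 keeps 3 samples, rank 1 keeps 3 - (6 - 5) = 2
+ #   rank 0 covers samples [0, 3), i.e. 60% of the batch, below the 67% "random" share -> oversample 0.0
+ #   rank 1 covers samples [3, 5) -> oversample 1 - ((0.67 - 0.6) / 0.4) = 0.825
+ # Across both workers 0.0 * 3 + 0.825 * 2 = 1.65 of 5 samples (33%) are foreground-oversampled, as intended.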
+ + batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) + + for rank in range(world_size): + if (rank + 1) * batch_size_per_GPU > global_batch_size: + batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) + else: + batch_size = batch_size_per_GPU + + batch_sizes.append(batch_size) + + sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) + sample_id_high = np.sum(batch_sizes) + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percents.append(0.0) + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percents.append(1.0) + else: + percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size + oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - + sample_id_low / global_batch_size) / percent_covered_by_this_rank) + oversample_percents.append(oversample_percent_here) + + print("worker", my_rank, "oversample", oversample_percents[my_rank]) + print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) + # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) + + self.batch_size = batch_sizes[my_rank] + self.oversample_foreground_percent = oversample_percents[my_rank] + + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) + self.classif_loss = BCEWithLogitsLoss() + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. 
We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.85, 1.25)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): + if self.local_rank == 0: + timestamp = time() + dt_object = datetime.fromtimestamp(timestamp) + + if add_timestamp: + args = ("%s:" % dt_object, *args) + + successful = False + max_attempts = 5 + ctr = 0 + while not successful and ctr < max_attempts: + try: + with open(self.log_file, 'a+') as f: + for a in args: + f.write(str(a)) + f.write(" ") + f.write("\n") + successful = True + except IOError: + print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) + sleep(0.5) + ctr += 1 + if also_print_to_console: + print(*args) + elif also_print_to_console: + print(*args) + + def print_plans(self): + if self.local_rank == 0: + dct = deepcopy(self.plans_manager.plans) + del dct['configurations'] + self.print_to_log_file(f"\nThis is the configuration used by this " + f"training:\nConfiguration name: {self.configuration_name}\n", + self.configuration_manager, '\n', add_timestamp=False) + self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + def plot_network_architecture(self): + if self.local_rank == 0: + try: + # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') + # pip install git+https://github.com/saugatkandel/hiddenlayer.git + + # from torchviz import make_dot + # # not viable. + # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, + # *self.configuration_manager.patch_size), + # device=self.device)))).render( + # join(self.output_folder, "network_architecture.pdf"), format='pdf') + # self.optimizer.zero_grad() + + # broken. 
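As an aside on the optimizer configured above: PolyLRScheduler decays the learning rate polynomially over the training run. A quick sketch of the resulting schedule, assuming the usual poly formula lr = initial_lr * (1 - epoch / num_epochs) ** exponent and nnU-Net's default exponent of 0.9 (the segrap trainer further below passes exponent=3.0 instead):

```python
# Assumed poly-decay formula for PolyLRScheduler(optimizer, 1e-2, 1000); values for illustration only.
initial_lr, num_epochs, exponent = 1e-2, 1000, 0.9
for epoch in (0, 250, 500, 999):
    print(epoch, round(initial_lr * (1 - epoch / num_epochs) ** exponent, 6))
# 0 -> 0.01, 250 -> ~0.0077, 500 -> ~0.0054, 999 -> ~0.00002
```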
+ + import hiddenlayer as hl + g = hl.build_graph(self.network, + torch.rand((1, self.num_input_channels, + *self.configuration_manager.patch_size), + device=self.device), + transforms=None) + g.save(join(self.output_folder, "network_architecture.pdf")) + del g + except Exception as e: + self.print_to_log_file("Unable to plot network architecture:") + self.print_to_log_file(e) + + # self.print_to_log_file("\nprinting the network instead:\n") + # self.print_to_log_file(self.network) + # self.print_to_log_file("\n") + finally: + empty_cache(self.device) + + def do_split(self): + """ + The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, + so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + Sometimes you may want to create your own split for various reasons. For this you will need to create your own + splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3) + and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to + use a random 80:20 data split. + :return: + """ + if self.fold == "all": + # if fold==all then we use all images for training and validation + case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder) + tr_keys = case_identifiers + val_keys = tr_keys + else: + splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") + dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None, + num_images_properties_loading_threshold=0, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) + # if the split file does not exist we need to create it + if not isfile(splits_file): + self.print_to_log_file("Creating new 5-fold cross-validation split...") + splits = [] + all_keys_sorted = np.sort(list(dataset.keys())) + kfold = KFold(n_splits=5, shuffle=True, random_state=12345) + for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): + train_keys = np.array(all_keys_sorted)[train_idx] + test_keys = np.array(all_keys_sorted)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + save_json(splits, splits_file) + + else: + self.print_to_log_file("Using splits from existing split file:", splits_file) + splits = load_json(splits_file) + self.print_to_log_file("The split file contains %d splits." % len(splits)) + + self.print_to_log_file("Desired fold for training: %d" % self.fold) + if self.fold < len(splits): + tr_keys = splits[self.fold]['train'] + val_keys = splits[self.fold]['val'] + self.print_to_log_file("This split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + else: + self.print_to_log_file("INFO: You requested fold %d for training but splits " + "contain only %d folds. I am now creating a " + "random (but seeded) 80:20 split!" 
% (self.fold, len(splits))) + # if we request a fold that is not in the split file, create a random 80:20 split + rnd = np.random.RandomState(seed=12345 + self.fold) + keys = np.sort(list(dataset.keys())) + idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) + idx_val = [i for i in range(len(keys)) if i not in idx_tr] + tr_keys = [keys[i] for i in idx_tr] + val_keys = [keys[i] for i in idx_val] + self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + if any([i in val_keys for i in tr_keys]): + self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' + 'splits.json or ignore if this is intentional.') + return tr_keys, val_keys + + def get_tr_and_val_datasets(self): + # create dataset split + tr_keys, val_keys = self.do_split() + + # load the datasets for training and validation. Note that we always draw random samples so we really don't + # care about distributing training cases across GPUs. + dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + return dataset_tr, dataset_val + + def get_dataloaders(self): + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? 
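Concretely, `_get_deep_supervision_scales()` turns the pooling kernel sizes from the plans into per-output downsampling factors for the segmentation targets. A standalone sketch with an assumed `pool_op_kernel_sizes` value (the real value comes from the dataset's plans file):

```python
# Standalone illustration of _get_deep_supervision_scales(); kernel sizes are an assumed example.
import numpy as np

pool_op_kernel_sizes = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 1]]
scales = (1 / np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)).tolist()[:-1]
print(scales)
# [[1.0, 1.0, 1.0], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]]
# -> targets are additionally produced at 1/2 and 1/4 resolution for the deep supervision
#    outputs; the coarsest output is dropped by the [:-1].
```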
+ deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=1, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, + num_processes=allowed_num_processes, num_cached=6, seeds=None, + pin_memory=self.device.type == 'cuda', wait_time=0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, + transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', + wait_time=0.02) + return mt_gen_train, mt_gen_val + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + return dl_tr, dl_val + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, 
...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + tr_transforms = [] + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append(SpatialTransform( + patch_size_spatial, patch_center_dist_from_border=None, + do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0), + do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'], + p_rot_per_axis=1, # todo experiment with this + do_scale=True, scale=(0.7, 1.4), + border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data, + border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg, + random_crop=False, # random cropping is part of our dataloaders + p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2, + independent_scale_for_each_axis=False # todo experiment with this + )) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, + p_per_channel=0.5)) + tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) + tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) + tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, + p_per_channel=0.5, + order_downsample=0, order_upsample=3, p_per_sample=0.25, + ignore_axes=ignore_axes)) + tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' + tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + 
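For orientation, here is a minimal, self-contained sketch of how a batchgenerators pipeline like the one assembled in `get_training_transforms` is composed and applied to one batch dict. The shapes and the reduced transform list are illustrative assumptions; only three of the transforms used above are kept so the example stays short:

```python
# Minimal sketch of the batchgenerators transform pipeline pattern used above (assumed shapes).
import numpy as np
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform
from batchgenerators.transforms.utility_transforms import RenameTransform, NumpyToTensor

pipeline = Compose([
    GaussianNoiseTransform(p_per_sample=0.1),      # same augmentation as in get_training_transforms
    RenameTransform('seg', 'target', True),        # segmentation is renamed to 'target'
    NumpyToTensor(['data', 'target'], 'float'),    # convert to torch tensors at the very end
])

batch = {
    'data': np.random.randn(2, 1, 32, 32, 32).astype(np.float32),                 # (b, c, x, y, z)
    'seg': np.random.randint(0, 2, size=(2, 1, 32, 32, 32)).astype(np.float32),
}
out = pipeline(**batch)
print(out['data'].shape, out['target'].shape)      # torch tensors, shapes unchanged
```

In the full pipeline the dataloaders deliver such dicts, and with deep supervision enabled `out['target']` becomes a list of tensors, one per supervised resolution.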
+ @staticmethod + def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + val_transforms = [] + val_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + + val_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + + val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + val_transforms = Compose(val_transforms) + return val_transforms + + def set_deep_supervision_enabled(self, enabled: bool): + """ + This function is specific for the default architecture in nnU-Net. If you change the architecture, there are + chances you need to change this as well! + """ + if self.is_ddp: + self.network.module.decoder.deep_supervision = enabled + else: + self.network.decoder.deep_supervision = enabled + + def on_train_start(self): + if not self.was_initialized: + self.initialize() + + maybe_mkdir_p(self.output_folder) + + # make sure deep supervision is on in the network + self.set_deep_supervision_enabled(True) + + self.print_plans() + empty_cache(self.device) + + # maybe unpack + if self.unpack_dataset and self.local_rank == 0: + self.print_to_log_file('unpacking dataset...') + unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, + num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) + self.print_to_log_file('unpacking done...') + + if self.is_ddp: + dist.barrier() + + # dataloaders must be instantiated here because they need access to the training data which may not be present + # when doing inference + self.dataloader_train, self.dataloader_val = self.get_dataloaders() + + # copy plans and dataset.json so that they can be used for restoring everything we need for inference + save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) + save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) + + # we don't really need the fingerprint but its still handy to have it with the others + shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), + join(self.output_folder_base, 'dataset_fingerprint.json')) + + # produces a pdf in output folder + self.plot_network_architecture() + + self._save_debug_information() + + # print(f"batch size: {self.batch_size}") + # print(f"oversample: {self.oversample_foreground_percent}") + + def on_train_end(self): + self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) + # now we can delete latest + if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): + os.remove(join(self.output_folder, "checkpoint_latest.pth")) + + # shut down dataloaders + old_stdout = sys.stdout + with open(os.devnull, 'w') as f: + sys.stdout = f + if self.dataloader_train is not None: + self.dataloader_train._finish() + if self.dataloader_val is not None: + 
self.dataloader_val._finish() + sys.stdout = old_stdout + + empty_cache(self.device) + self.print_to_log_file("Training done.") + + def on_train_epoch_start(self): + self.network.train() + self.lr_scheduler.step(self.current_epoch) + self.print_to_log_file('') + self.print_to_log_file(f'Epoch {self.current_epoch}') + self.print_to_log_file( + f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") + # lrs are the same for all workers so we don't need to gather them in case of DDP training + self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + target_class = torch.nn.functional.one_hot( + torch.cat([torch.max(target[0][idx]).long().unsqueeze(0) for idx in range(np.shape(target[0])[0])], dim=0), num_classes=2) + + self.optimizer.zero_grad() + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + outputs, classif = self.network(data) + # del data + l = self.loss(outputs, target) + self.classif_loss(classif, target_class.float()) + + # Seg backward loop + if self.grad_scaler is not None: + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + + return {'loss': l.detach().cpu().numpy()} + + def on_train_epoch_end(self, train_outputs: List[dict]): + outputs = collate_outputs(train_outputs) + + if self.is_ddp: + losses_tr = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(losses_tr, outputs['loss']) + loss_here = np.vstack(losses_tr).mean() + else: + loss_here = np.mean(outputs['loss']) + + self.logger.log('train_losses', loss_here, self.current_epoch) + + def on_validation_epoch_start(self): + self.network.eval() + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + target_class = torch.nn.functional.one_hot( + torch.cat([torch.max(target[0][idx]).long().unsqueeze(0) for idx in range(np.shape(target[0])[0])], dim=0), num_classes=2) + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. 
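The same CUDA-only autocast guard appears in both train_step and validation_step. In isolation the pattern looks like the sketch below; `nullcontext` stands in for nnU-Net's `dummy_context`, and the tensors are made up:

```python
# Isolated sketch of the "autocast only on CUDA" guard used in train_step/validation_step.
import torch
from contextlib import nullcontext   # stand-in for nnunetv2.utilities.helpers.dummy_context

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.autocast(device.type, enabled=True) if device.type == 'cuda' else nullcontext()
with ctx:
    a = torch.randn(4, 8, device=device)
    b = torch.randn(8, 2, device=device)
    out = a @ b            # reduced precision under CUDA autocast, plain fp32 otherwise
print(out.dtype)           # typically torch.float16 on CUDA, torch.float32 on CPU
```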
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # we only need the output with the highest output resolution + output = output[0] + target = target[0] + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, len(output.shape))) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def on_validation_epoch_end(self, val_outputs: List[dict]): + outputs_collated = collate_outputs(val_outputs) + tp = np.sum(outputs_collated['tp_hard'], 0) + fp = np.sum(outputs_collated['fp_hard'], 0) + fn = np.sum(outputs_collated['fn_hard'], 0) + + if self.is_ddp: + world_size = dist.get_world_size() + + tps = [None for _ in range(world_size)] + dist.all_gather_object(tps, tp) + tp = np.vstack([i[None] for i in tps]).sum(0) + + fps = [None for _ in range(world_size)] + dist.all_gather_object(fps, fp) + fp = np.vstack([i[None] for i in fps]).sum(0) + + fns = [None for _ in range(world_size)] + dist.all_gather_object(fns, fn) + fn = np.vstack([i[None] for i in fns]).sum(0) + + losses_val = [None for _ in range(world_size)] + dist.all_gather_object(losses_val, outputs_collated['loss']) + loss_here = np.vstack(losses_val).mean() + else: + loss_here = np.mean(outputs_collated['loss']) + + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in + zip(tp, fp, fn)]] + mean_fg_dice = np.nanmean(global_dc_per_class) + self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) + self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) + self.logger.log('val_losses', loss_here, self.current_epoch) + + def on_epoch_start(self): + self.logger.log('epoch_start_timestamps', time(), self.current_epoch) + + def on_epoch_end(self): + self.logger.log('epoch_end_timestamps', time(), self.current_epoch) + + # todo find a solution for this stupid shit + self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) + self.print_to_log_file('val_loss', 
np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) + self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in + self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) + self.print_to_log_file( + f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") + + # handling periodic checkpointing + current_epoch = self.current_epoch + if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): + self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) + + # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this + if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: + self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] + self.print_to_log_file(f"Yayy! New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") + self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) + + if self.local_rank == 0: + self.logger.plot_progress_png(self.output_folder) + + self.current_epoch += 1 + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.state_dict(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. 
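The key cleanup above exists because state dicts saved from a DataParallel/DDP-wrapped model carry a `module.` prefix that a plain network does not expect. A standalone sketch of that stripping step (the key names are invented for illustration):

```python
# Illustration of the 'module.' prefix stripping done in load_checkpoint (keys are made up).
saved = {'module.encoder.conv.weight': 0, 'module.decoder.seg.bias': 1}
cleaned = {k[7:] if k.startswith('module.') else k: v for k, v in saved.items()}  # 7 == len('module.')
print(cleaned)   # {'encoder.conv.weight': 0, 'decoder.seg.bias': 1}
```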
+ if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_state_dict(new_state_dict) + else: + self.network.module.load_state_dict(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_state_dict(new_state_dict) + else: + self.network.load_state_dict(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + def perform_actual_validation(self, save_probabilities: bool = False): + self.set_deep_supervision_enabled(False) + self.network.eval() + + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + perform_everything_on_gpu=True, device=self.device, verbose=False, + verbose_preprocessing=False, allow_tqdm=False) + predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, + self.dataset_json, self.__class__.__name__, + self.inference_allowed_mirroring_axes) + + with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: + worker_list = [i for i in segmentation_export_pool._pool] + validation_output_folder = join(self.output_folder, 'validation') + maybe_mkdir_p(validation_output_folder) + + # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute + # the validation keys across the workers. + _, val_keys = self.do_split() + if self.is_ddp: + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + + next_stages = self.configuration_manager.next_stage_names + + if next_stages is not None: + _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] + + results = [] + for k in dataset_val.keys(): + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + while not proceed: + sleep(0.1) + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + + self.print_to_log_file(f"predicting {k}") + data, seg, properties = dataset_val.load_case(k) + + if self.is_cascaded: + data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, + output_dtype=data.dtype))) + with warnings.catch_warnings(): + # ignore 'The given NumPy array is not writable' warning + warnings.simplefilter("ignore") + data = torch.from_numpy(data) + + output_filename_truncated = join(validation_output_folder, k) + + try: + prediction = predictor.predict_sliding_window_return_logits(data) + except RuntimeError: + predictor.perform_everything_on_gpu = False + prediction = predictor.predict_sliding_window_return_logits(data) + predictor.perform_everything_on_gpu = True + + prediction = prediction.cpu() + + # this needs to go into background processes + results.append( + segmentation_export_pool.starmap_async( + export_prediction_from_logits, ( + (prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, output_filename_truncated, save_probabilities), + ) + ) + ) + # for debug purposes + # export_prediction(prediction_for_export, properties, self.configuration, 
self.plans, self.dataset_json, + # output_filename_truncated, save_probabilities) + + # if needed, export the softmax prediction for the next stage + if next_stages is not None: + for n in next_stages: + next_stage_config_manager = self.plans_manager.get_configuration(n) + expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, + next_stage_config_manager.data_identifier) + + try: + # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented + tmp = nnUNetDataset(expected_preprocessed_folder, [k], + num_images_properties_loading_threshold=0) + d, s, p = tmp.load_case(k) + except FileNotFoundError: + self.print_to_log_file( + f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " + f"Run the preprocessing for this configuration first!") + continue + + target_shape = d.shape[1:] + output_folder = join(self.output_folder_base, 'predicted_next_stage', n) + output_file = join(output_folder, k + '.npz') + + # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, + # self.dataset_json) + results.append(segmentation_export_pool.starmap_async( + resample_and_save, ( + (prediction, target_shape, output_file, self.plans_manager, + self.configuration_manager, + properties, + self.dataset_json), + ) + )) + + _ = [r.get() for r in results] + + if self.is_ddp: + dist.barrier() + + if self.local_rank == 0: + metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), + validation_output_folder, + join(validation_output_folder, 'summary.json'), + self.plans_manager.image_reader_writer_class(), + self.dataset_json["file_ending"], + self.label_manager.foreground_regions if self.label_manager.has_regions else + self.label_manager.foreground_labels, + self.label_manager.ignore_label, chill=True) + self.print_to_log_file("Validation complete", also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + + self.set_deep_supervision_enabled(True) + compute_gaussian.cache_clear() + + def run_training(self): + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(next(self.dataloader_train))) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(next(self.dataloader_val))) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py new file mode 100644 index 0000000..6d95b26 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_segrap.py @@ -0,0 +1,1242 @@ +import inspect +import multiprocessing +import os +import shutil +import sys +import warnings +from copy import deepcopy +from datetime import datetime +from time import time, sleep +from typing import Union, Tuple, List + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms 
import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \ + ContrastAugmentationTransform, GammaTransform +from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform +from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor +from batchgenerators.utilities.file_and_folder_operations import join, load_json, isfile, save_json, maybe_mkdir_p +from torch._dynamo import OptimizedModule + +from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes +from nnunetv2.evaluation.evaluate_predictions import compute_metrics_on_folder +from nnunetv2.inference.export_prediction import export_prediction_from_logits, resample_and_save +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +from nnunetv2.inference.sliding_window_prediction import compute_gaussian +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_results +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert2DTo3DTransform, \ + Convert3DTo2DTransform +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset +from nnunetv2.training.dataloading.utils import get_case_identifiers, unpack_dataset +from nnunetv2.training.logging.nnunet_logger import nnUNetLogger +from nnunetv2.training.loss.compound_losses import DC_and_CE_loss, DC_and_BCE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.utilities.collate_outputs import collate_outputs +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA +from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy +from nnunetv2.utilities.get_network_from_plans import get_network_from_plans +from nnunetv2.utilities.helpers import empty_cache, dummy_context +from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels +from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager +from sklearn.model_selection import KFold +from torch import autocast, nn +from torch import distributed as dist +from torch.cuda import device_count +from 
torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP + + +class nnUNetTrainer_segrap(object): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + # From https://grugbrain.dev/. Worth a read ya big brains ;-) + + # apex predator of grug is complexity + # complexity bad + # say again: + # complexity very bad + # you say now: + # complexity very, very bad + # given choice between complexity or one on one against t-rex, grug take t-rex: at least grug see t-rex + # complexity is spirit demon that enter codebase through well-meaning but ultimately very clubbable non grug-brain developers and project managers who not fear complexity spirit demon or even know about sometime + # one day code base understandable and grug can get work done, everything good! + # next day impossible: complexity demon spirit has entered code and very dangerous situation! + + # OK OK I am guilty. But I tried. http://tiny.cc/gzgwuz + + self.is_ddp = dist.is_available() and dist.is_initialized() + self.local_rank = 0 if not self.is_ddp else dist.get_rank() + + self.device = device + + # print what device we are using + if self.is_ddp: # implicitly it's clear that we use cuda in this case + print(f"I am local rank {self.local_rank}. {device_count()} GPUs are available. The world size is " + f"{dist.get_world_size()}." + f"Setting device to {self.device}") + self.device = torch.device(type='cuda', index=self.local_rank) + else: + if self.device.type == 'cuda': + # we might want to let the user pick this but for now please pick the correct GPU with CUDA_VISIBLE_DEVICES=X + self.device = torch.device(type='cuda', index=0) + print(f"Using device: {self.device}") + + # loading and saving this class for continuing from checkpoint should not happen based on pickling. This + # would also pickle the network etc. Bad, bad. Instead we just reinstantiate and then load the checkpoint we + # need. So let's save the init args + self.my_init_kwargs = {} + for k in inspect.signature(self.__init__).parameters.keys(): + self.my_init_kwargs[k] = locals()[k] + + ### Saving all the init args into class variables for later access + self.plans_manager = PlansManager(plans) + self.configuration_manager = self.plans_manager.get_configuration(configuration) + self.configuration_name = configuration + self.dataset_json = dataset_json + self.fold = fold + self.unpack_dataset = unpack_dataset + + ### Setting all the folder names. We need to make sure things don't crash in case we are just running + # inference and some of the folders may not be defined! + self.preprocessed_dataset_folder_base = join(nnUNet_preprocessed, self.plans_manager.dataset_name) \ + if nnUNet_preprocessed is not None else None + self.output_folder_base = join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + configuration) \ + if nnUNet_results is not None else None + self.output_folder = join(self.output_folder_base, f'fold_{fold}') + + self.preprocessed_dataset_folder = join(self.preprocessed_dataset_folder_base, + self.configuration_manager.data_identifier) + # unlike the previous nnunet folder_with_segs_from_previous_stage is now part of the plans. For now it has to + # be a different configuration in the same plans + # IMPORTANT! the mapping must be bijective, so lowres must point to fullres and vice versa (using + # "previous_stage" and "next_stage"). 
Otherwise it won't work! + self.is_cascaded = self.configuration_manager.previous_stage_name is not None + self.folder_with_segs_from_previous_stage = \ + join(nnUNet_results, self.plans_manager.dataset_name, + self.__class__.__name__ + '__' + self.plans_manager.plans_name + "__" + + self.configuration_manager.previous_stage_name, 'predicted_next_stage', self.configuration_name) \ + if self.is_cascaded else None + + ### Some hyperparameters for you to fiddle with + self.initial_lr = 1e-2 + self.weight_decay = 3e-5 + self.oversample_foreground_percent = 0.33 + self.num_iterations_per_epoch = 250 + self.num_val_iterations_per_epoch = 50 + self.num_epochs = 1000 + self.current_epoch = 0 + + ### Dealing with labels/regions + self.label_manager = self.plans_manager.get_label_manager(dataset_json) + # labels can either be a list of int (regular training) or a list of tuples of int (region-based training) + # needed for predictions. We do sigmoid in case of (overlapping) regions + + self.num_input_channels = None # -> self.initialize() + self.network = None # -> self._get_network() + self.optimizer = self.lr_scheduler = None # -> self.initialize + self.grad_scaler = GradScaler() if self.device.type == 'cuda' else None + self.loss = None # -> self.initialize + + ### Simple logging. Don't take that away from me! + # initialize log file. This is just our log for the print statements etc. Not to be confused with lightning + # logging + timestamp = datetime.now() + maybe_mkdir_p(self.output_folder) + self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" % + (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute, + timestamp.second)) + self.logger = nnUNetLogger() + + ### placeholders + self.dataloader_train = self.dataloader_val = None # see on_train_start + + ### initializing stuff for remembering things and such + self._best_ema = None + + ### inference things + self.inference_allowed_mirroring_axes = None # this variable is set in + # self.configure_rotation_dummyDA_mirroring_and_inital_patch_size and will be saved in checkpoints + + ### checkpoint saving stuff + self.save_every = 50 + self.disable_checkpointing = False + + ## DDP batch size and oversampling can differ between workers and needs adaptation + # we need to change the batch size in DDP because we don't use any of those distributed samplers + self._set_batch_size_and_oversample() + + self.was_initialized = False + + self.print_to_log_file("\n#######################################################################\n" + "Please cite the following paper when using nnU-Net:\n" + "Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). " + "nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. 
" + "Nature methods, 18(2), 203-211.\n" + "#######################################################################\n", + also_print_to_console=True, add_timestamp=False) + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, + self.configuration_manager, + self.num_input_channels, + enable_deep_supervision=True).to(self.device) + self.reconstruction_decoder = copy.deepcopy(self.network.decoder) + # compile network for free speedup + if ('nnUNet_compile' in os.environ.keys()) and ( + os.environ['nnUNet_compile'].lower() in ('true', '1', 't')): + self.print_to_log_file('Compiling network...') + self.network = torch.compile(self.network) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " + "That should not happen.") + + def _save_debug_information(self): + # saving some debug information + if self.local_rank == 0: + dct = {} + for k in self.__dir__(): + if not k.startswith("__"): + if not callable(getattr(self, k)) or k in ['loss', ]: + dct[k] = str(getattr(self, k)) + elif k in ['network', ]: + dct[k] = str(getattr(self, k).__class__.__name__) + else: + # print(k) + pass + if k in ['dataloader_train', 'dataloader_val']: + if hasattr(getattr(self, k), 'generator'): + dct[k + '.generator'] = str(getattr(self, k).generator) + if hasattr(getattr(self, k), 'num_processes'): + dct[k + '.num_processes'] = str(getattr(self, k).num_processes) + if hasattr(getattr(self, k), 'transform'): + dct[k + '.transform'] = str(getattr(self, k).transform) + import subprocess + hostname = subprocess.getoutput(['hostname']) + dct['hostname'] = hostname + torch_version = torch.__version__ + if self.device.type == 'cuda': + gpu_name = torch.cuda.get_device_name() + dct['gpu_name'] = gpu_name + cudnn_version = torch.backends.cudnn.version() + else: + cudnn_version = 'None' + dct['device'] = str(self.device) + dct['torch_version'] = torch_version + dct['cudnn_version'] = cudnn_version + save_json(dct, join(self.output_folder, "debug.json")) + + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + """ + his is where you build the architecture according to the plans. There is no obligation to use + get_network_from_plans, this is just a utility we use for the nnU-Net default architectures. You can do what + you want. Even ignore the plans and just return something static (as long as it can process the requested + patch size) + but don't bug us with your bugs arising from fiddling with this :-P + This is the function that is called in inference as well! This is needed so that all network architecture + variants can be loaded at inference time (inference will use the same nnUNetTrainer that was used for + training, so if you change the network architecture during training by deriving a new trainer class then + inference will know about it). 
+ + If you need to know how many segmentation outputs your custom architecture needs to have, use the following snippet: + > label_manager = plans_manager.get_label_manager(dataset_json) + > label_manager.num_segmentation_heads + (why so complicated? -> We can have either classical training (classes) or regions. If we have regions, + the number of outputs is != the number of classes. Also there is the ignore label for which no output + should be generated. label_manager takes care of all that for you.) + + """ + return get_network_from_plans(plans_manager, dataset_json, configuration_manager, + num_input_channels, deep_supervision=enable_deep_supervision) + + def _get_deep_supervision_scales(self): + deep_supervision_scales = list(list(i) for i in 1 / np.cumprod(np.vstack( + self.configuration_manager.pool_op_kernel_sizes), axis=0))[:-1] + return deep_supervision_scales + + def _set_batch_size_and_oversample(self): + if not self.is_ddp: + # set batch size to what the plan says, leave oversample untouched + self.batch_size = self.configuration_manager.batch_size + else: + # batch size is distributed over DDP workers and we need to change oversample_percent for each worker + batch_sizes = [] + oversample_percents = [] + + world_size = dist.get_world_size() + my_rank = dist.get_rank() + + global_batch_size = self.configuration_manager.batch_size + assert global_batch_size >= world_size, 'Cannot run DDP if the batch size is smaller than the number of ' \ + 'GPUs... Duh.' + + batch_size_per_GPU = np.ceil(global_batch_size / world_size).astype(int) + + for rank in range(world_size): + if (rank + 1) * batch_size_per_GPU > global_batch_size: + batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - global_batch_size) + else: + batch_size = batch_size_per_GPU + + batch_sizes.append(batch_size) + + sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1]) + sample_id_high = np.sum(batch_sizes) + + if sample_id_high / global_batch_size < (1 - self.oversample_foreground_percent): + oversample_percents.append(0.0) + elif sample_id_low / global_batch_size > (1 - self.oversample_foreground_percent): + oversample_percents.append(1.0) + else: + percent_covered_by_this_rank = sample_id_high / global_batch_size - sample_id_low / global_batch_size + oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) - + sample_id_low / global_batch_size) / percent_covered_by_this_rank) + oversample_percents.append(oversample_percent_here) + + print("worker", my_rank, "oversample", oversample_percents[my_rank]) + print("worker", my_rank, "batch_size", batch_sizes[my_rank]) + # self.print_to_log_file("worker", my_rank, "oversample", oversample_percents[my_rank]) + # self.print_to_log_file("worker", my_rank, "batch_size", batch_sizes[my_rank]) + + self.batch_size = batch_sizes[my_rank] + self.oversample_foreground_percent = oversample_percents[my_rank] + + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, dice_class=MemoryEfficientSoftDiceLoss) + + deep_supervision_scales = self._get_deep_supervision_scales() + + 
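The deep-supervision weighting applied just below can be made concrete with a small standalone computation; four supervised resolutions are assumed here for illustration (the actual count follows from the plans):

```python
# Standalone illustration of the deep supervision weights computed below (n_outputs assumed).
import numpy as np

n_outputs = 4
weights = np.array([1 / (2 ** i) for i in range(n_outputs)])   # [1, 0.5, 0.25, 0.125]
weights[-1] = 0                                                # coarsest supervised output gets weight 0
weights = weights / weights.sum()                              # normalize so the weights sum to 1
print(weights.round(4))                                        # [0.5714 0.2857 0.1429 0.    ]
```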
# we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + weights[-1] = 0 + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.85, 1.25)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True): + if self.local_rank == 0: + timestamp = time() + dt_object = datetime.fromtimestamp(timestamp) + + if add_timestamp: + args = ("%s:" % dt_object, *args) + + successful = False + max_attempts = 5 + ctr = 0 + while not successful and ctr < max_attempts: + try: + with open(self.log_file, 'a+') as f: + for a in args: + f.write(str(a)) + f.write(" ") + f.write("\n") + successful = True + except IOError: + print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info()) + sleep(0.5) + ctr += 1 + if also_print_to_console: + print(*args) + elif also_print_to_console: + print(*args) + + def print_plans(self): + if self.local_rank == 0: + dct = deepcopy(self.plans_manager.plans) + del dct['configurations'] + self.print_to_log_file(f"\nThis is the configuration used by this " + f"training:\nConfiguration name: {self.configuration_name}\n", + self.configuration_manager, '\n', add_timestamp=False) + self.print_to_log_file('These are the global plan.json settings:\n', dct, '\n', add_timestamp=False) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs, exponent=3.) + return optimizer, lr_scheduler + + def plot_network_architecture(self): + if self.local_rank == 0: + try: + # raise NotImplementedError('hiddenlayer no longer works and we do not have a viable alternative :-(') + # pip install git+https://github.com/saugatkandel/hiddenlayer.git + + # from torchviz import make_dot + # # not viable. + # make_dot(tuple(self.network(torch.rand((1, self.num_input_channels, + # *self.configuration_manager.patch_size), + # device=self.device)))).render( + # join(self.output_folder, "network_architecture.pdf"), format='pdf') + # self.optimizer.zero_grad() + + # broken. + + import hiddenlayer as hl + g = hl.build_graph(self.network, + torch.rand((1, self.num_input_channels, + *self.configuration_manager.patch_size), + device=self.device), + transforms=None) + g.save(join(self.output_folder, "network_architecture.pdf")) + del g + except Exception as e: + self.print_to_log_file("Unable to plot network architecture:") + self.print_to_log_file(e) + + # self.print_to_log_file("\nprinting the network instead:\n") + # self.print_to_log_file(self.network) + # self.print_to_log_file("\n") + finally: + empty_cache(self.device) + + def do_split(self): + """ + The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded, + so always the same) and save it as splits_final.pkl file in the preprocessed data directory. + Sometimes you may want to create your own split for various reasons. For this you will need to create your own + splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in + it. You can create as many splits in this file as you want. 
Note that if you define only 4 splits (fold 0-3) + and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to + use a random 80:20 data split. + :return: + """ + if self.fold == "all": + # if fold==all then we use all images for training and validation + case_identifiers = get_case_identifiers(self.preprocessed_dataset_folder) + tr_keys = case_identifiers + val_keys = tr_keys + else: + splits_file = join(self.preprocessed_dataset_folder_base, "splits_final.json") + dataset = nnUNetDataset(self.preprocessed_dataset_folder, case_identifiers=None, + num_images_properties_loading_threshold=0, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage) + # if the split file does not exist we need to create it + if not isfile(splits_file): + self.print_to_log_file("Creating new 5-fold cross-validation split...") + splits = [] + all_keys_sorted = np.sort(list(dataset.keys())) + kfold = KFold(n_splits=5, shuffle=True, random_state=12345) + for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)): + train_keys = np.array(all_keys_sorted)[train_idx] + test_keys = np.array(all_keys_sorted)[test_idx] + splits.append({}) + splits[-1]['train'] = list(train_keys) + splits[-1]['val'] = list(test_keys) + save_json(splits, splits_file) + + else: + self.print_to_log_file("Using splits from existing split file:", splits_file) + splits = load_json(splits_file) + self.print_to_log_file("The split file contains %d splits." % len(splits)) + + self.print_to_log_file("Desired fold for training: %d" % self.fold) + if self.fold < len(splits): + tr_keys = splits[self.fold]['train'] + val_keys = splits[self.fold]['val'] + self.print_to_log_file("This split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + else: + self.print_to_log_file("INFO: You requested fold %d for training but splits " + "contain only %d folds. I am now creating a " + "random (but seeded) 80:20 split!" % (self.fold, len(splits))) + # if we request a fold that is not in the split file, create a random 80:20 split + rnd = np.random.RandomState(seed=12345 + self.fold) + keys = np.sort(list(dataset.keys())) + idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False) + idx_val = [i for i in range(len(keys)) if i not in idx_tr] + tr_keys = [keys[i] for i in idx_tr] + val_keys = [keys[i] for i in idx_val] + self.print_to_log_file("This random 80:20 split has %d training and %d validation cases." + % (len(tr_keys), len(val_keys))) + if any([i in val_keys for i in tr_keys]): + self.print_to_log_file('WARNING: Some validation cases are also in the training set. Please check the ' + 'splits.json or ignore if this is intentional.') + return tr_keys, val_keys + + def get_tr_and_val_datasets(self): + # create dataset split + tr_keys, val_keys = self.do_split() + + # load the datasets for training and validation. Note that we always draw random samples so we really don't + # care about distributing training cases across GPUs. 
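The split creation above amounts to a seeded 5-fold `KFold` over the sorted case identifiers, written out as a list of train/val dictionaries. A standalone sketch (the case names are made up):

```python
import json
import numpy as np
from sklearn.model_selection import KFold

def create_splits(case_ids, n_splits=5, seed=12345):
    keys = np.sort(case_ids)  # sort first so the split is reproducible across runs
    splits = []
    for train_idx, val_idx in KFold(n_splits=n_splits, shuffle=True, random_state=seed).split(keys):
        splits.append({'train': keys[train_idx].tolist(), 'val': keys[val_idx].tolist()})
    return splits

splits = create_splits([f'case_{i:03d}' for i in range(10)])
with open('splits_final.json', 'w') as f:
    json.dump(splits, f, indent=2)
```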
+ dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + return dataset_tr, dataset_val + + def get_dataloaders(self): + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=1, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.foreground_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms, + num_processes=allowed_num_processes, num_cached=6, seeds=None, + pin_memory=self.device.type == 'cuda', wait_time=0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val, + transform=val_transforms, num_processes=max(1, allowed_num_processes // 2), + num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda', + wait_time=0.02) + return mt_gen_train, mt_gen_val + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader2D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size, + initial_patch_size, + 
self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None) + return dl_tr, dl_val + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + tr_transforms = [] + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append(SpatialTransform( + patch_size_spatial, patch_center_dist_from_border=None, + do_elastic_deform=False, alpha=(0, 0), sigma=(0, 0), + do_rotation=True, angle_x=rotation_for_DA['x'], angle_y=rotation_for_DA['y'], angle_z=rotation_for_DA['z'], + p_rot_per_axis=1, # todo experiment with this + do_scale=True, scale=(0.7, 1.4), + border_mode_data="constant", border_cval_data=0, order_data=order_resampling_data, + border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_resampling_seg, + random_crop=False, # random cropping is part of our dataloaders + p_el_per_sample=0, p_scale_per_sample=0.2, p_rot_per_sample=0.2, + independent_scale_for_each_axis=False # todo experiment with this + )) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2, + p_per_channel=0.5)) + tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15)) + tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15)) + tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, + p_per_channel=0.5, + order_downsample=0, order_upsample=3, p_per_sample=0.25, + ignore_axes=ignore_axes)) + tr_transforms.append(GammaTransform((0.7, 1.5), True, True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append(GammaTransform((0.7, 1.5), False, True, retain_stats=True, p_per_sample=0.3)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + assert foreground_labels is not None, 'We need foreground_labels for cascade augmentations' + tr_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + 
p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(foreground_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + + @staticmethod + def get_validation_transforms(deep_supervision_scales: Union[List, Tuple], + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + val_transforms = [] + val_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + val_transforms.append(MoveSegAsOneHotToData(1, foreground_labels, 'seg', 'data')) + + val_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + val_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + + val_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + val_transforms = Compose(val_transforms) + return val_transforms + + def set_deep_supervision_enabled(self, enabled: bool): + """ + This function is specific for the default architecture in nnU-Net. If you change the architecture, there are + chances you need to change this as well! 
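`DownsampleSegForDSTransform2` above turns the single segmentation target into one target per deep-supervision output by rescaling it with the per-axis scale factors. A small sketch of the resulting target sizes (the scale factors shown are hypothetical):

```python
def downsampled_target_shapes(patch_size, deep_supervision_scales):
    # one target per deep-supervision output, each axis scaled by its factor
    return [[int(round(p * s)) for p, s in zip(patch_size, scales)]
            for scales in deep_supervision_scales]

scales = [[1, 1, 1], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]]
print(downsampled_target_shapes((128, 128, 128), scales))
# [[128, 128, 128], [64, 64, 64], [32, 32, 32], [16, 16, 16]]
```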
+ """ + if self.is_ddp: + self.network.module.decoder.deep_supervision = enabled + else: + self.network.decoder.deep_supervision = enabled + + def on_train_start(self): + if not self.was_initialized: + self.initialize() + + maybe_mkdir_p(self.output_folder) + + # make sure deep supervision is on in the network + self.set_deep_supervision_enabled(True) + + self.print_plans() + empty_cache(self.device) + + # maybe unpack + if self.unpack_dataset and self.local_rank == 0: + self.print_to_log_file('unpacking dataset...') + unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False, + num_processes=max(1, round(get_allowed_n_proc_DA() // 2))) + self.print_to_log_file('unpacking done...') + + if self.is_ddp: + dist.barrier() + + # dataloaders must be instantiated here because they need access to the training data which may not be present + # when doing inference + self.dataloader_train, self.dataloader_val = self.get_dataloaders() + + # copy plans and dataset.json so that they can be used for restoring everything we need for inference + save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False) + save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False) + + # we don't really need the fingerprint but its still handy to have it with the others + shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'), + join(self.output_folder_base, 'dataset_fingerprint.json')) + + # produces a pdf in output folder + self.plot_network_architecture() + + self._save_debug_information() + + # print(f"batch size: {self.batch_size}") + # print(f"oversample: {self.oversample_foreground_percent}") + + def on_train_end(self): + self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth")) + # now we can delete latest + if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")): + os.remove(join(self.output_folder, "checkpoint_latest.pth")) + + # shut down dataloaders + old_stdout = sys.stdout + with open(os.devnull, 'w') as f: + sys.stdout = f + if self.dataloader_train is not None: + self.dataloader_train._finish() + if self.dataloader_val is not None: + self.dataloader_val._finish() + sys.stdout = old_stdout + + empty_cache(self.device) + self.print_to_log_file("Training done.") + + def on_train_epoch_start(self): + self.network.train() + self.lr_scheduler.step(self.current_epoch) + self.print_to_log_file('') + self.print_to_log_file(f'Epoch {self.current_epoch}') + self.print_to_log_file( + f"Current learning rate: {np.round(self.optimizer.param_groups[0]['lr'], decimals=5)}") + # lrs are the same for all workers so we don't need to gather them in case of DDP training + self.logger.log('lrs', self.optimizer.param_groups[0]['lr'], self.current_epoch) + + def train_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. 
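The device check described in the comment above (and used on the next line) can be factored into a tiny helper. This is only a sketch, substituting `contextlib.nullcontext` for nnU-Net's `dummy_context`:

```python
import torch
from contextlib import nullcontext

def maybe_autocast(device: torch.device):
    # mixed precision only pays off on CUDA; CPU autocast is slow and MPS
    # complains even with enabled=False, so fall back to a no-op context
    return torch.autocast(device.type, enabled=True) if device.type == 'cuda' else nullcontext()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with maybe_autocast(device):
    pass  # forward pass and loss computation would go here
```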
+ with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + # del data + l = self.loss(output, target) + + if self.grad_scaler is not None: + self.grad_scaler.scale(l).backward() + self.grad_scaler.unscale_(self.optimizer) + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + l.backward() + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12) + self.optimizer.step() + return {'loss': l.detach().cpu().numpy()} + + def on_train_epoch_end(self, train_outputs: List[dict]): + outputs = collate_outputs(train_outputs) + + if self.is_ddp: + losses_tr = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(losses_tr, outputs['loss']) + loss_here = np.vstack(losses_tr).mean() + else: + loss_here = np.mean(outputs['loss']) + + self.logger.log('train_losses', loss_here, self.current_epoch) + + def on_validation_epoch_start(self): + self.network.eval() + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # we only need the output with the highest output resolution + output = output[0] + target = target[0] + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, len(output.shape))) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. 
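To make the point in this comment concrete: the online "pseudo Dice" is computed from true/false positive and false negative counts per class, and the background entry (index 0 under softmax training) is discarded before averaging. A small sketch with made-up counts:

```python
import numpy as np

def foreground_dice(tp, fp, fn):
    # per-class Dice from accumulated counts; index 0 is the softmax background
    # class and is dropped before averaging
    dice = 2 * tp / (2 * tp + fp + fn)
    return dice[1:], float(np.nanmean(dice[1:]))

tp = np.array([5000., 120., 80.])   # background, class 1, class 2 (made-up counts)
fp = np.array([300., 30., 60.])
fn = np.array([200., 20., 40.])
per_class, mean_fg = foreground_dice(tp, fp, fn)
print(per_class, mean_fg)  # roughly [0.828 0.615] and 0.72
```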
We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} + + def on_validation_epoch_end(self, val_outputs: List[dict]): + outputs_collated = collate_outputs(val_outputs) + tp = np.sum(outputs_collated['tp_hard'], 0) + fp = np.sum(outputs_collated['fp_hard'], 0) + fn = np.sum(outputs_collated['fn_hard'], 0) + + if self.is_ddp: + world_size = dist.get_world_size() + + tps = [None for _ in range(world_size)] + dist.all_gather_object(tps, tp) + tp = np.vstack([i[None] for i in tps]).sum(0) + + fps = [None for _ in range(world_size)] + dist.all_gather_object(fps, fp) + fp = np.vstack([i[None] for i in fps]).sum(0) + + fns = [None for _ in range(world_size)] + dist.all_gather_object(fns, fn) + fn = np.vstack([i[None] for i in fns]).sum(0) + + losses_val = [None for _ in range(world_size)] + dist.all_gather_object(losses_val, outputs_collated['loss']) + loss_here = np.vstack(losses_val).mean() + else: + loss_here = np.mean(outputs_collated['loss']) + + global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in + zip(tp, fp, fn)]] + mean_fg_dice = np.nanmean(global_dc_per_class) + self.logger.log('mean_fg_dice', mean_fg_dice, self.current_epoch) + self.logger.log('dice_per_class_or_region', global_dc_per_class, self.current_epoch) + self.logger.log('val_losses', loss_here, self.current_epoch) + + def on_epoch_start(self): + self.logger.log('epoch_start_timestamps', time(), self.current_epoch) + + def on_epoch_end(self): + self.logger.log('epoch_end_timestamps', time(), self.current_epoch) + + # todo find a solution for this stupid shit + self.print_to_log_file('train_loss', np.round(self.logger.my_fantastic_logging['train_losses'][-1], decimals=4)) + self.print_to_log_file('val_loss', np.round(self.logger.my_fantastic_logging['val_losses'][-1], decimals=4)) + self.print_to_log_file('Pseudo dice', [np.round(i, decimals=4) for i in + self.logger.my_fantastic_logging['dice_per_class_or_region'][-1]]) + self.print_to_log_file( + f"Epoch time: {np.round(self.logger.my_fantastic_logging['epoch_end_timestamps'][-1] - self.logger.my_fantastic_logging['epoch_start_timestamps'][-1], decimals=2)} s") + + # handling periodic checkpointing + current_epoch = self.current_epoch + if (current_epoch + 1) % self.save_every == 0 and current_epoch != (self.num_epochs - 1): + self.save_checkpoint(join(self.output_folder, 'checkpoint_latest.pth')) + + # handle 'best' checkpointing. ema_fg_dice is computed by the logger and can be accessed like this + if self._best_ema is None or self.logger.my_fantastic_logging['ema_fg_dice'][-1] > self._best_ema: + self._best_ema = self.logger.my_fantastic_logging['ema_fg_dice'][-1] + self.print_to_log_file(f"Yayy! 
New best EMA pseudo Dice: {np.round(self._best_ema, decimals=4)}") + self.save_checkpoint(join(self.output_folder, 'checkpoint_best.pth')) + + if self.local_rank == 0: + self.logger.plot_progress_png(self.output_folder) + + self.current_epoch += 1 + + def save_checkpoint(self, filename: str) -> None: + if self.local_rank == 0: + if not self.disable_checkpointing: + if self.is_ddp: + mod = self.network.module + else: + mod = self.network + if isinstance(mod, OptimizedModule): + mod = mod._orig_mod + + checkpoint = { + 'network_weights': mod.state_dict(), + 'optimizer_state': self.optimizer.state_dict(), + 'grad_scaler_state': self.grad_scaler.state_dict() if self.grad_scaler is not None else None, + 'logging': self.logger.get_checkpoint(), + '_best_ema': self._best_ema, + 'current_epoch': self.current_epoch + 1, + 'init_args': self.my_init_kwargs, + 'trainer_name': self.__class__.__name__, + 'inference_allowed_mirroring_axes': self.inference_allowed_mirroring_axes, + } + torch.save(checkpoint, filename) + else: + self.print_to_log_file('No checkpoint written, checkpointing is disabled') + + def load_checkpoint(self, filename_or_checkpoint: Union[dict, str]) -> None: + if not self.was_initialized: + self.initialize() + + if isinstance(filename_or_checkpoint, str): + checkpoint = torch.load(filename_or_checkpoint, map_location=self.device) + # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not + # match. Use heuristic to make it match + new_state_dict = {} + for k, value in checkpoint['network_weights'].items(): + key = k + if key not in self.network.state_dict().keys() and key.startswith('module.'): + key = key[7:] + new_state_dict[key] = value + + self.my_init_kwargs = checkpoint['init_args'] + self.current_epoch = checkpoint['current_epoch'] + self.logger.load_checkpoint(checkpoint['logging']) + self._best_ema = checkpoint['_best_ema'] + self.inference_allowed_mirroring_axes = checkpoint[ + 'inference_allowed_mirroring_axes'] if 'inference_allowed_mirroring_axes' in checkpoint.keys() else self.inference_allowed_mirroring_axes + + # messing with state dict naming schemes. Facepalm. 
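The heuristic referred to above is a plain key-renaming pass: weights saved from a DataParallel/DDP-wrapped model carry a `module.` prefix that a non-wrapped model does not expect. A standalone sketch of that remapping:

```python
def remap_state_dict_keys(saved_state_dict: dict, target_keys) -> dict:
    # strip the 'module.' prefix added by nn.DataParallel / DistributedDataParallel
    # whenever the target model's own keys do not use it
    target_keys = set(target_keys)
    remapped = {}
    for key, value in saved_state_dict.items():
        if key not in target_keys and key.startswith('module.'):
            key = key[len('module.'):]
        remapped[key] = value
    return remapped

# usage sketch:
# new_state_dict = remap_state_dict_keys(checkpoint['network_weights'],
#                                        model.state_dict().keys())
# model.load_state_dict(new_state_dict)
```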
+ if self.is_ddp: + if isinstance(self.network.module, OptimizedModule): + self.network.module._orig_mod.load_state_dict(new_state_dict) + else: + self.network.module.load_state_dict(new_state_dict) + else: + if isinstance(self.network, OptimizedModule): + self.network._orig_mod.load_state_dict(new_state_dict) + else: + self.network.load_state_dict(new_state_dict) + self.optimizer.load_state_dict(checkpoint['optimizer_state']) + if self.grad_scaler is not None: + if checkpoint['grad_scaler_state'] is not None: + self.grad_scaler.load_state_dict(checkpoint['grad_scaler_state']) + + def perform_actual_validation(self, save_probabilities: bool = False): + self.set_deep_supervision_enabled(False) + self.network.eval() + + predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True, + perform_everything_on_gpu=True, device=self.device, verbose=False, + verbose_preprocessing=False, allow_tqdm=False) + predictor.manual_initialization(self.network, self.plans_manager, self.configuration_manager, None, + self.dataset_json, self.__class__.__name__, + self.inference_allowed_mirroring_axes) + + with multiprocessing.get_context("spawn").Pool(default_num_processes) as segmentation_export_pool: + worker_list = [i for i in segmentation_export_pool._pool] + validation_output_folder = join(self.output_folder, 'validation') + maybe_mkdir_p(validation_output_folder) + + # we cannot use self.get_tr_and_val_datasets() here because we might be DDP and then we have to distribute + # the validation keys across the workers. + _, val_keys = self.do_split() + if self.is_ddp: + val_keys = val_keys[self.local_rank:: dist.get_world_size()] + + dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys, + folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage, + num_images_properties_loading_threshold=0) + + next_stages = self.configuration_manager.next_stage_names + + if next_stages is not None: + _ = [maybe_mkdir_p(join(self.output_folder_base, 'predicted_next_stage', n)) for n in next_stages] + + results = [] + + for k in dataset_val.keys(): + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + while not proceed: + sleep(0.1) + proceed = not check_workers_alive_and_busy(segmentation_export_pool, worker_list, results, + allowed_num_queued=2) + + self.print_to_log_file(f"predicting {k}") + data, seg, properties = dataset_val.load_case(k) + + if self.is_cascaded: + data = np.vstack((data, convert_labelmap_to_one_hot(seg[-1], self.label_manager.foreground_labels, + output_dtype=data.dtype))) + with warnings.catch_warnings(): + # ignore 'The given NumPy array is not writable' warning + warnings.simplefilter("ignore") + data = torch.from_numpy(data) + + output_filename_truncated = join(validation_output_folder, k) + + try: + prediction = predictor.predict_sliding_window_return_logits(data) + except RuntimeError: + predictor.perform_everything_on_gpu = False + prediction = predictor.predict_sliding_window_return_logits(data) + predictor.perform_everything_on_gpu = True + + prediction = prediction.cpu() + + # this needs to go into background processes + results.append( + segmentation_export_pool.starmap_async( + export_prediction_from_logits, ( + (prediction, properties, self.configuration_manager, self.plans_manager, + self.dataset_json, output_filename_truncated, save_probabilities), + ) + ) + ) + # for debug purposes + # export_prediction(prediction_for_export, properties, self.configuration, 
self.plans, self.dataset_json, + # output_filename_truncated, save_probabilities) + + # if needed, export the softmax prediction for the next stage + if next_stages is not None: + for n in next_stages: + next_stage_config_manager = self.plans_manager.get_configuration(n) + expected_preprocessed_folder = join(nnUNet_preprocessed, self.plans_manager.dataset_name, + next_stage_config_manager.data_identifier) + + try: + # we do this so that we can use load_case and do not have to hard code how loading training cases is implemented + tmp = nnUNetDataset(expected_preprocessed_folder, [k], + num_images_properties_loading_threshold=0) + d, s, p = tmp.load_case(k) + except FileNotFoundError: + self.print_to_log_file( + f"Predicting next stage {n} failed for case {k} because the preprocessed file is missing! " + f"Run the preprocessing for this configuration first!") + continue + + target_shape = d.shape[1:] + output_folder = join(self.output_folder_base, 'predicted_next_stage', n) + output_file = join(output_folder, k + '.npz') + + # resample_and_save(prediction, target_shape, output_file, self.plans_manager, self.configuration_manager, properties, + # self.dataset_json) + results.append(segmentation_export_pool.starmap_async( + resample_and_save, ( + (prediction, target_shape, output_file, self.plans_manager, + self.configuration_manager, + properties, + self.dataset_json), + ) + )) + + _ = [r.get() for r in results] + + if self.is_ddp: + dist.barrier() + + if self.local_rank == 0: + metrics = compute_metrics_on_folder(join(self.preprocessed_dataset_folder_base, 'gt_segmentations'), + validation_output_folder, + join(validation_output_folder, 'summary.json'), + self.plans_manager.image_reader_writer_class(), + self.dataset_json["file_ending"], + self.label_manager.foreground_regions if self.label_manager.has_regions else + self.label_manager.foreground_labels, + self.label_manager.ignore_label, chill=True) + self.print_to_log_file("Validation complete", also_print_to_console=True) + self.print_to_log_file("Mean Validation Dice: ", (metrics['foreground_mean']["Dice"]), also_print_to_console=True) + + self.set_deep_supervision_enabled(True) + compute_gaussian.cache_clear() + + def run_training(self): + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(next(self.dataloader_train))) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(next(self.dataloader_val))) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py new file mode 100644 index 0000000..fad1fff --- /dev/null +++ 
b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py @@ -0,0 +1,65 @@ +import torch +from batchgenerators.utilities.file_and_folder_operations import save_json, join, isfile, load_json + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from torch import distributed as dist + + +class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results." + self.disable_checkpointing = True + self.num_epochs = 5 + assert torch.cuda.is_available(), "This only works on GPU" + self.crashed_with_runtime_error = False + + def perform_actual_validation(self, save_probabilities: bool = False): + pass + + def save_checkpoint(self, filename: str) -> None: + # do not trust people to remember that self.disable_checkpointing must be True for this trainer + pass + + def run_training(self): + try: + super().run_training() + except RuntimeError: + self.crashed_with_runtime_error = True + + def on_train_end(self): + super().on_train_end() + + if not self.is_ddp or self.local_rank == 0: + torch_version = torch.__version__ + cudnn_version = torch.backends.cudnn.version() + gpu_name = torch.cuda.get_device_name() + if self.crashed_with_runtime_error: + fastest_epoch = 'Not enough VRAM!' + else: + epoch_times = [i - j for i, j in zip(self.logger.my_fantastic_logging['epoch_end_timestamps'], + self.logger.my_fantastic_logging['epoch_start_timestamps'])] + fastest_epoch = min(epoch_times) + + if self.is_ddp: + num_gpus = dist.get_world_size() + else: + num_gpus = 1 + + benchmark_result_file = join(self.output_folder, 'benchmark_result.json') + if isfile(benchmark_result_file): + old_results = load_json(benchmark_result_file) + else: + old_results = {} + # generate some unique key + my_key = f"{cudnn_version}__{torch_version.replace(' ', '')}__{gpu_name.replace(' ', '')}__gpus_{num_gpus}" + old_results[my_key] = { + 'torch_version': torch_version, + 'cudnn_version': cudnn_version, + 'gpu_name': gpu_name, + 'fastest_epoch': fastest_epoch, + 'num_gpus': num_gpus, + } + save_json(old_results, + join(self.output_folder, 'benchmark_result.json')) diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py new file mode 100644 index 0000000..6c12ecc --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs_noDataLoading.py @@ -0,0 +1,51 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.variants.benchmarking.nnUNetTrainerBenchmark_5epochs import \ + nnUNetTrainerBenchmark_5epochs +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels + + +class nnUNetTrainerBenchmark_5epochs_noDataLoading(nnUNetTrainerBenchmark_5epochs): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self._set_batch_size_and_oversample() + num_input_channels = 
determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + patch_size = self.configuration_manager.patch_size + dummy_data = torch.rand((self.batch_size, num_input_channels, *patch_size), device=self.device) + dummy_target = [ + torch.round( + torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(patch_size, k)]), device=self.device) * + max(self.label_manager.all_labels) + ) for k in self._get_deep_supervision_scales()] + self.dummy_batch = {'data': dummy_data, 'target': dummy_target} + + def get_dataloaders(self): + return None, None + + def run_training(self): + try: + self.on_train_start() + + for epoch in range(self.current_epoch, self.num_epochs): + self.on_epoch_start() + + self.on_train_epoch_start() + train_outputs = [] + for batch_id in range(self.num_iterations_per_epoch): + train_outputs.append(self.train_step(self.dummy_batch)) + self.on_train_epoch_end(train_outputs) + + with torch.no_grad(): + self.on_validation_epoch_start() + val_outputs = [] + for batch_id in range(self.num_val_iterations_per_epoch): + val_outputs.append(self.validation_step(self.dummy_batch)) + self.on_validation_epoch_end(val_outputs) + + self.on_epoch_end() + + self.on_train_end() + except RuntimeError: + self.crashed_with_runtime_error = True diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py new file mode 100644 index 0000000..bd9c31c --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDA5.py @@ -0,0 +1,410 @@ +from typing import List, Union, Tuple + +import numpy as np +import torch +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter +from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose +from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \ + GammaTransform +from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform +from batchgenerators.transforms.noise_transforms import MedianFilterTransform, GaussianBlurTransform, \ + GaussianNoiseTransform, BlankRectangleTransform, SharpeningTransform +from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform +from batchgenerators.transforms.spatial_transforms import SpatialTransform, Rot90Transform, TransposeAxesTransform, \ + MirrorTransform +from batchgenerators.transforms.utility_transforms import OneOfTransform, RemoveLabelTransform, RenameTransform, \ + NumpyToTensor + +from nnunetv2.configuration import ANISO_THRESHOLD +from nnunetv2.training.data_augmentation.compute_initial_patch_size import get_patch_size +from nnunetv2.training.data_augmentation.custom_transforms.cascade_transforms import MoveSegAsOneHotToData, \ + ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform +from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \ + DownsampleSegForDSTransform2 +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from 
nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform +from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \ + ConvertSegmentationToRegionsTransform +from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \ + Convert2DTo3DTransform +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + + +class nnUNetTrainerDA5(nnUNetTrainer): + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + """ + This function is stupid and certainly one of the weakest spots of this implementation. Not entirely sure how we can fix it. + """ + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + # todo rotation should be defined dynamically based on patch size (more isotropic patch sizes = more rotation) + if dim == 2: + do_dummy_2d_data_aug = False + # todo revisit this parametrization + if max(patch_size) / min(patch_size) > 1.5: + rotation_for_DA = { + 'x': (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + mirror_axes = (0, 1) + elif dim == 3: + # todo this is not ideal. We could also have patch_size (64, 16, 128) in which case a full 180deg 2d rot would be bad + # order of the axes is determined by spacing, not image size + do_dummy_2d_data_aug = (max(patch_size) / patch_size[0]) > ANISO_THRESHOLD + if do_dummy_2d_data_aug: + # why do we rotate 180 deg here all the time? We should also restrict it + rotation_for_DA = { + 'x': (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi), + 'y': (0, 0), + 'z': (0, 0) + } + else: + rotation_for_DA = { + 'x': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'y': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + 'z': (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi), + } + mirror_axes = (0, 1, 2) + else: + raise RuntimeError() + + # todo this function is stupid. 
It doesn't even use the correct scale range (we keep things as they were in the + # old nnunet for now) + initial_patch_size = get_patch_size(patch_size[-dim:], + *rotation_for_DA.values(), + (0.7, 1.43)) + if do_dummy_2d_data_aug: + initial_patch_size[0] = patch_size[0] + + self.print_to_log_file(f'do_dummy_2d_data_aug: {do_dummy_2d_data_aug}') + self.inference_allowed_mirroring_axes = mirror_axes + + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 3, + order_resampling_seg: int = 1, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + matching_axes = np.array([sum([i == j for j in patch_size]) for i in patch_size]) + valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0]) + + tr_transforms = [] + + if do_dummy_2d_data_aug: + ignore_axes = (0,) + tr_transforms.append(Convert3DTo2DTransform()) + patch_size_spatial = patch_size[1:] + else: + patch_size_spatial = patch_size + ignore_axes = None + + tr_transforms.append( + SpatialTransform( + patch_size_spatial, + patch_center_dist_from_border=None, + do_elastic_deform=False, + do_rotation=True, + angle_x=rotation_for_DA['x'], + angle_y=rotation_for_DA['y'], + angle_z=rotation_for_DA['z'], + p_rot_per_axis=0.5, + do_scale=True, + scale=(0.7, 1.43), + border_mode_data="constant", + border_cval_data=0, + order_data=order_resampling_data, + border_mode_seg="constant", + border_cval_seg=-1, + order_seg=order_resampling_seg, + random_crop=False, + p_el_per_sample=0.2, + p_scale_per_sample=0.2, + p_rot_per_sample=0.4, + independent_scale_for_each_axis=True, + ) + ) + + if do_dummy_2d_data_aug: + tr_transforms.append(Convert2DTo3DTransform()) + + if np.any(matching_axes > 1): + tr_transforms.append( + Rot90Transform( + (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5 + ), + ) + + if np.any(matching_axes > 1): + tr_transforms.append( + TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5) + ) + + tr_transforms.append(OneOfTransform([ + MedianFilterTransform( + (2, 8), + same_for_each_channel=False, + p_per_sample=0.2, + p_per_channel=0.5 + ), + GaussianBlurTransform((0.3, 1.5), + different_sigma_per_channel=True, + p_per_sample=0.2, + p_per_channel=0.5) + ])) + + tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1)) + + tr_transforms.append(BrightnessTransform(0, + 0.5, + per_channel=True, + p_per_sample=0.1, + p_per_channel=0.5 + ) + ) + + tr_transforms.append(OneOfTransform( + [ + ContrastAugmentationTransform( + contrast_range=(0.5, 2), + preserve_range=True, + per_channel=True, + data_key='data', + p_per_sample=0.2, + p_per_channel=0.5 + ), + ContrastAugmentationTransform( + contrast_range=(0.5, 2), + preserve_range=False, + per_channel=True, + data_key='data', + p_per_sample=0.2, + p_per_channel=0.5 + ), + ] + )) + + tr_transforms.append( + SimulateLowResolutionTransform(zoom_range=(0.25, 1), + per_channel=True, + p_per_channel=0.5, + order_downsample=0, + order_upsample=3, + p_per_sample=0.15, + ignore_axes=ignore_axes + ) + ) + + tr_transforms.append( + 
GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) + tr_transforms.append( + GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1)) + + if mirror_axes is not None and len(mirror_axes) > 0: + tr_transforms.append(MirrorTransform(mirror_axes)) + + tr_transforms.append( + BlankRectangleTransform([[max(1, p // 10), p // 3] for p in patch_size], + rectangle_value=np.mean, + num_rectangles=(1, 5), + force_square=False, + p_per_sample=0.4, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + BrightnessGradientAdditiveTransform( + lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + (-0.5, 1.5), + max_strength=lambda x, y: np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5), + mean_centered=False, + same_for_all_channels=False, + p_per_sample=0.3, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + LocalGammaTransform( + lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))), + (-0.5, 1.5), + lambda: np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4), + same_for_all_channels=False, + p_per_sample=0.3, + p_per_channel=0.5 + ) + ) + + tr_transforms.append( + SharpeningTransform( + strength=(0.1, 1), + same_for_each_channel=False, + p_per_sample=0.2, + p_per_channel=0.5 + ) + ) + + if use_mask_for_norm is not None and any(use_mask_for_norm): + tr_transforms.append(MaskTransform([i for i in range(len(use_mask_for_norm)) if use_mask_for_norm[i]], + mask_idx_in_seg=0, set_outside_to=0)) + + tr_transforms.append(RemoveLabelTransform(-1, 0)) + + if is_cascaded: + if ignore_label is not None: + raise NotImplementedError('ignore label not yet supported in cascade') + assert foreground_labels is not None, 'We need all_labels for cascade augmentations' + use_labels = [i for i in foreground_labels if i != 0] + tr_transforms.append(MoveSegAsOneHotToData(1, use_labels, 'seg', 'data')) + tr_transforms.append(ApplyRandomBinaryOperatorTransform( + channel_idx=list(range(-len(use_labels), 0)), + p_per_sample=0.4, + key="data", + strel_size=(1, 8), + p_per_label=1)) + tr_transforms.append( + RemoveRandomConnectedComponentFromOneHotEncodingTransform( + channel_idx=list(range(-len(use_labels), 0)), + key="data", + p_per_sample=0.2, + fill_with_other_class_p=0, + dont_do_if_covers_more_than_x_percent=0.15)) + + tr_transforms.append(RenameTransform('seg', 'target', True)) + + if regions is not None: + # the ignore label must also be converted + tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label] + if ignore_label is not None else regions, + 'target', 'target')) + + if deep_supervision_scales is not None: + tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target', + output_key='target')) + tr_transforms.append(NumpyToTensor(['data', 'target'], 'float')) + tr_transforms = Compose(tr_transforms) + return tr_transforms + + +class nnUNetTrainerDA5ord0(nnUNetTrainerDA5): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. 
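The `matching_axes`/`valid_axes` bookkeeping at the top of the DA5 transform builder decides which axes may be rotated by 90 degrees or transposed without changing the patch shape. A standalone sketch of that selection:

```python
import numpy as np

def rot90_transpose_axes(patch_size):
    # count, for every axis, how many axes share its extent; Rot90/Transpose are
    # only applied when at least two axes match, and only over those axes
    matching = np.array([sum(i == j for j in patch_size) for i in patch_size])
    if not np.any(matching > 1):
        return None
    return [int(i) for i in np.where(matching == matching.max())[0]]

print(rot90_transpose_axes((128, 128, 128)))  # [0, 1, 2]: fully interchangeable
print(rot90_transpose_axes((28, 256, 256)))   # [1, 2]: only the in-plane axes
print(rot90_transpose_axes((28, 256, 224)))   # None: no two axes match
```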
We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=0, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +class nnUNetTrainerDA5Segord0(nnUNetTrainerDA5): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? 
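The `allowed_num_processes == 0` branch repeated in these `get_dataloaders` variants means "augment in the main thread". A rough stand-in for `get_allowed_n_proc_DA()` is sketched below; the environment-variable name and the fallback heuristic are assumptions, not the actual implementation:

```python
import os

def allowed_augmentation_processes() -> int:
    # honour an explicit override if set (variable name is an assumption),
    # otherwise leave a couple of cores free for the main training process
    override = os.environ.get('nnUNet_n_proc_DA')
    if override is not None:
        return int(override)
    return max(0, (os.cpu_count() or 1) - 2)

# 0 -> SingleThreadedAugmenter, >0 -> multiprocess LimitedLenWrapper
print(allowed_augmentation_processes())
```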
+ deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 10 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py new file mode 100644 index 0000000..e87ff8f --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py @@ -0,0 +1,104 @@ +from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter + +from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \ + LimitedLenWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA + + +class nnUNetTrainerDAOrd0(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? 
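The only substantive difference between these variants is the interpolation order used when the spatial transform resamples data and segmentations. Order 0 (nearest neighbour) preserves discrete label values, while order 3 (cubic) invents intermediate values, which is acceptable for images but not for label maps. A quick demonstration with `scipy.ndimage.zoom`:

```python
import numpy as np
from scipy.ndimage import zoom

seg = np.zeros((8, 8), dtype=np.float32)
seg[2:6, 2:6] = 3  # a toy label map with labels {0, 3}

print(np.unique(zoom(seg, 2, order=0)))       # [0. 3.]: labels survive intact
print(np.unique(zoom(seg, 2, order=3)).size)  # > 2: cubic interpolation creates in-between values
```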
+ deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=0, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val + + +class nnUNetTrainer_DASegOrd0(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? 
+ deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py new file mode 100644 index 0000000..527e262 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoDA.py @@ -0,0 +1,40 @@ +from typing import Union, Tuple, List + +from batchgenerators.transforms.abstract_transforms import AbstractTransform + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np + + +class nnUNetTrainerNoDA(nnUNetTrainer): + @staticmethod + def get_training_transforms(patch_size: Union[np.ndarray, Tuple[int]], + rotation_for_DA: dict, + deep_supervision_scales: Union[List, Tuple], + mirror_axes: Tuple[int, ...], + do_dummy_2d_data_aug: bool, + order_resampling_data: int = 1, + order_resampling_seg: int = 0, + border_val_seg: int = -1, + use_mask_for_norm: List[bool] = None, + is_cascaded: bool = False, + foreground_labels: Union[Tuple[int, ...], List[int]] = None, + regions: List[Union[List[int], Tuple[int, ...], int]] = None, + ignore_label: int = None) -> AbstractTransform: + return nnUNetTrainer.get_validation_transforms(deep_supervision_scales, is_cascaded, foreground_labels, + regions, ignore_label) + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + return super().get_plain_dataloaders( + initial_patch_size=self.configuration_manager.patch_size, + dim=dim + ) + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + # we need to disable mirroring here so that no mirroring will be applied in inferene! 
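Setting `inference_allowed_mirroring_axes` here (and in the mirroring trainers that follow) controls how many flipped copies the sliding-window predictor averages at test time: with nnU-Net's usual test-time mirroring, every subset of the allowed axes is evaluated. A small sketch of the resulting cost per window:

```python
from itertools import combinations

def num_mirror_passes(mirror_axes) -> int:
    # one forward pass per subset of the allowed axes (including the empty subset)
    if not mirror_axes:
        return 1
    return sum(1 for r in range(len(mirror_axes) + 1)
               for _ in combinations(mirror_axes, r))

print(num_mirror_passes((0, 1, 2)))  # 8 passes per window with full mirroring
print(num_mirror_passes((0, 1)))     # 4 with the onlyMirror01 variant
print(num_mirror_passes(None))       # 1 when mirroring is disabled
```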
+ rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py new file mode 100644 index 0000000..18ea1ea --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerNoMirroring.py @@ -0,0 +1,28 @@ +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerNoMirroring(nnUNetTrainer): + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_onlyMirror01(nnUNetTrainer): + """ + Only mirrors along spatial axes 0 and 1 for 3D and 0 for 2D + """ + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + if dim == 2: + mirror_axes = (0, ) + else: + mirror_axes = (0, 1) + self.inference_allowed_mirroring_axes = mirror_axes + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py new file mode 100644 index 0000000..1a363cc --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerCELoss.py @@ -0,0 +1,23 @@ +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.training.loss.robust_ce_loss import RobustCrossEntropyLoss +import numpy as np + + +class nnUNetTrainerCELoss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, 'regions not supported by this trainer' + loss = RobustCrossEntropyLoss(weight=None, + ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py new file mode 100644 index 0000000..82114c0 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerDiceLoss.py @@ -0,0 +1,56 @@ +import numpy as np +import torch + +from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.helpers import softmax_helper_dim1 + + +class nnUNetTrainerDiceLoss(nnUNetTrainer): + def _build_loss(self): + loss = MemoryEfficientSoftDiceLoss(**{'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp}, + apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerDiceCELoss_noSmooth(nnUNetTrainer): + def _build_loss(self): + # set smooth to 0 + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 0, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 0, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, + dice_class=MemoryEfficientSoftDiceLoss) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py new file mode 100644 index 0000000..4ebfd41 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/loss/nnUNetTrainerTopkLoss.py @@ -0,0 +1,66 @@ +from nnunetv2.training.loss.compound_losses import DC_and_topk_loss +from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np +from nnunetv2.training.loss.robust_ce_loss import TopKLoss + + +class nnUNetTrainerTopk10Loss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, 'regions not supported by this trainer' + loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, + k=10) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerTopk10LossLS01(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, 'regions not supported by this trainer' + loss = TopKLoss(ignore_index=self.label_manager.ignore_label if self.label_manager.has_ignore_label else -100, + k=10, label_smoothing=0.1) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss + + +class nnUNetTrainerDiceTopK10Loss(nnUNetTrainer): + def _build_loss(self): + assert not self.label_manager.has_regions, 'regions not supported by this trainer' + loss = DC_and_topk_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, + {'k': 10, + 'label_smoothing': 0.0}, + weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label) + + deep_supervision_scales = self._get_deep_supervision_scales() + + # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases + # this gives higher resolution outputs more weight in the loss + weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))]) + + # we don't use the lowest 2 outputs. 
Normalize weights so that they sum to 1 + weights = weights / weights.sum() + # now wrap the loss + loss = DeepSupervisionWrapper(loss, weights) + return loss diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py new file mode 100644 index 0000000..60455f2 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/lr_schedule/nnUNetTrainerCosAnneal.py @@ -0,0 +1,13 @@ +import torch +from torch.optim.lr_scheduler import CosineAnnealingLR + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerCosAnneal(nnUNetTrainer): + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + momentum=0.99, nesterov=True) + lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) + return optimizer, lr_scheduler + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py new file mode 100644 index 0000000..b2f26e2 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerBN.py @@ -0,0 +1,73 @@ +from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet, PlainConvUNet +from dynamic_network_architectures.building_blocks.helper import convert_dim_to_conv_op, get_matching_batchnorm +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0, InitWeights_He +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn + + +class nnUNetTrainerBN(nnUNetTrainer): + @staticmethod + def build_network_architecture(plans_manager: PlansManager, + dataset_json, + configuration_manager: ConfigurationManager, + num_input_channels, + enable_deep_supervision: bool = True) -> nn.Module: + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = configuration_manager.UNet_class_name + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_batchnorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_batchnorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network 
architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + # network class name!! + model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=enable_deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + return model diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py new file mode 100644 index 0000000..a07ff8a --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -0,0 +1,114 @@ +import torch +from torch import autocast + +from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss +from nnunetv2.training.loss.dice import get_tp_fp_fn_tn, MemoryEfficientSoftDiceLoss +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from nnunetv2.utilities.helpers import dummy_context +from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels +from torch.nn.parallel import DistributedDataParallel as DDP + + +class nnUNetTrainerNoDeepSupervision(nnUNetTrainer): + def _build_loss(self): + if self.label_manager.has_regions: + loss = DC_and_BCE_loss({}, + {'batch_dice': self.configuration_manager.batch_dice, + 'do_bg': True, 'smooth': 1e-5, 'ddp': self.is_ddp}, + use_ignore_label=self.label_manager.ignore_label is not None, + dice_class=MemoryEfficientSoftDiceLoss) + else: + loss = DC_and_CE_loss({'batch_dice': self.configuration_manager.batch_dice, + 'smooth': 1e-5, 'do_bg': False, 'ddp': self.is_ddp}, {}, weight_ce=1, weight_dice=1, + ignore_label=self.label_manager.ignore_label, + dice_class=MemoryEfficientSoftDiceLoss) + return loss + + def _get_deep_supervision_scales(self): + return None + + def initialize(self): + if not self.was_initialized: + self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager, + self.dataset_json) + + self.network = self.build_network_architecture(self.plans_manager, self.dataset_json, + self.configuration_manager, + self.num_input_channels, + enable_deep_supervision=False).to(self.device) + + self.optimizer, self.lr_scheduler = self.configure_optimizers() + # if ddp, wrap in DDP wrapper + if self.is_ddp: + self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) + self.network = DDP(self.network, device_ids=[self.local_rank]) + + self.loss = self._build_loss() + self.was_initialized = True + 
else: + raise RuntimeError("You have called self.initialize even though the trainer was already initialized. " + "That should not happen.") + + def set_deep_supervision_enabled(self, enabled: bool): + pass + + def validation_step(self, batch: dict) -> dict: + data = batch['data'] + target = batch['target'] + + data = data.to(self.device, non_blocking=True) + if isinstance(target, list): + target = [i.to(self.device, non_blocking=True) for i in target] + else: + target = target.to(self.device, non_blocking=True) + + self.optimizer.zero_grad() + + # Autocast is a little bitch. + # If the device_type is 'cpu' then it's slow as heck and needs to be disabled. + # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False is set. Whyyyyyyy. (this is why we don't make use of enabled=False) + # So autocast will only be active if we have a cuda device. + with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context(): + output = self.network(data) + del data + l = self.loss(output, target) + + # the following is needed for online evaluation. Fake dice (green line) + axes = [0] + list(range(2, len(output.shape))) + + if self.label_manager.has_regions: + predicted_segmentation_onehot = (torch.sigmoid(output) > 0.5).long() + else: + # no need for softmax + output_seg = output.argmax(1)[:, None] + predicted_segmentation_onehot = torch.zeros(output.shape, device=output.device, dtype=torch.float32) + predicted_segmentation_onehot.scatter_(1, output_seg, 1) + del output_seg + + if self.label_manager.has_ignore_label: + if not self.label_manager.has_regions: + mask = (target != self.label_manager.ignore_label).float() + # CAREFUL that you don't rely on target after this line! + target[target == self.label_manager.ignore_label] = 0 + else: + mask = 1 - target[:, -1:] + # CAREFUL that you don't rely on target after this line! + target = target[:, :-1] + else: + mask = None + + tp, fp, fn, _ = get_tp_fp_fn_tn(predicted_segmentation_onehot, target, axes=axes, mask=mask) + + tp_hard = tp.detach().cpu().numpy() + fp_hard = fp.detach().cpu().numpy() + fn_hard = fn.detach().cpu().numpy() + if not self.label_manager.has_regions: + # if we train with regions all segmentation heads predict some kind of foreground. In conventional + # (softmax training) there needs tobe one output for the background. 
We are not interested in the + # background Dice + # [1:] in order to remove background + tp_hard = tp_hard[1:] + fp_hard = fp_hard[1:] + fn_hard = fn_hard[1:] + + return {'loss': l.detach().cpu().numpy(), 'tp_hard': tp_hard, 'fp_hard': fp_hard, 'fn_hard': fn_hard} \ No newline at end of file diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py new file mode 100644 index 0000000..be5a7f4 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdam.py @@ -0,0 +1,58 @@ +import torch +from torch.optim import Adam, AdamW + +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainerAdam(nnUNetTrainer): + def configure_optimizers(self): + optimizer = AdamW(self.network.parameters(), + lr=self.initial_lr, + weight_decay=self.weight_decay, + amsgrad=True) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerVanillaAdam(nnUNetTrainer): + def configure_optimizers(self): + optimizer = Adam(self.network.parameters(), + lr=self.initial_lr, + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerVanillaAdam1en3(nnUNetTrainerVanillaAdam): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerVanillaAdam3en4(nnUNetTrainerVanillaAdam): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 + + +class nnUNetTrainerAdam1en3(nnUNetTrainerAdam): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerAdam3en4(nnUNetTrainerAdam): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py new file 
mode 100644 index 0000000..8747f47 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/optimizer/nnUNetTrainerAdan.py @@ -0,0 +1,66 @@ +import torch + +from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +from torch.optim.lr_scheduler import CosineAnnealingLR +try: + from adan_pytorch import Adan +except ImportError: + Adan = None + + +class nnUNetTrainerAdan(nnUNetTrainer): + def configure_optimizers(self): + if Adan is None: + raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') + optimizer = Adan(self.network.parameters(), + lr=self.initial_lr, + # betas=(0.02, 0.08, 0.01), defaults + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs) + return optimizer, lr_scheduler + + +class nnUNetTrainerAdan1en3(nnUNetTrainerAdan): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-3 + + +class nnUNetTrainerAdan3en4(nnUNetTrainerAdan): + # https://twitter.com/karpathy/status/801621764144971776?lang=en + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 3e-4 + + +class nnUNetTrainerAdan1en1(nnUNetTrainerAdan): + # this trainer makes no sense -> nan! 
+ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.initial_lr = 1e-1 + + +class nnUNetTrainerAdanCosAnneal(nnUNetTrainerAdan): + # def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + # device: torch.device = torch.device('cuda')): + # super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + # self.num_epochs = 15 + + def configure_optimizers(self): + if Adan is None: + raise RuntimeError('This trainer requires adan_pytorch to be installed, install with "pip install adan-pytorch"') + optimizer = Adan(self.network.parameters(), + lr=self.initial_lr, + # betas=(0.02, 0.08, 0.01), defaults + weight_decay=self.weight_decay) + # optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, + # momentum=0.99, nesterov=True) + lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs) + return optimizer, lr_scheduler + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py new file mode 100644 index 0000000..89fef48 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py @@ -0,0 +1,76 @@ +from typing import Tuple + +import torch + +from nnunetv2.training.dataloading.data_loader_2d import nnUNetDataLoader2D +from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer +import numpy as np + + +class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer): + """ + sampling of foreground happens randomly and not for the last 33% of samples in a batch + since most trainings happen with batch size 2 and nnunet guarantees at least one fg sample, effectively this can + be 50% + Here we compute the actual oversampling percentage used by nnUNetTrainer in order to be as consistent as possible. + If we switch to this oversampling then we can keep it at a constant 0.33 or whatever. 
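+    For example, with the usual batch size of 2 and the default value of 0.33, round(2 * (1 - 0.33)) = 1, so one of the
+    two samples per batch is a forced-foreground sample and the value computed below comes out as 0.5.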
+ """ + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = float(np.mean( + [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent)) + for sample_idx in range(self.configuration_manager.batch_size)])) + self.print_to_log_file(f"self.oversample_foreground_percent {self.oversample_foreground_percent}") + + def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int): + dataset_tr, dataset_val = self.get_tr_and_val_datasets() + + if dim == 2: + dl_tr = nnUNetDataLoader2D(dataset_tr, + self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + dl_val = nnUNetDataLoader2D(dataset_val, + self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + else: + dl_tr = nnUNetDataLoader3D(dataset_tr, + self.batch_size, + initial_patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + dl_val = nnUNetDataLoader3D(dataset_val, + self.batch_size, + self.configuration_manager.patch_size, + self.configuration_manager.patch_size, + self.label_manager, + oversample_foreground_percent=self.oversample_foreground_percent, + sampling_probabilities=None, pad_sides=None, probabilistic_oversampling=True) + return dl_tr, dl_val + + +class nnUNetTrainer_probabilisticOversampling_033(nnUNetTrainer_probabilisticOversampling): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = 0.33 + + +class nnUNetTrainer_probabilisticOversampling_010(nnUNetTrainer_probabilisticOversampling): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.oversample_foreground_percent = 0.1 + + diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py new file mode 100644 index 0000000..e3a71a0 --- /dev/null +++ b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py @@ -0,0 +1,76 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainer_5epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, 
dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 5 + + +class nnUNetTrainer_1epoch(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 1 + + +class nnUNetTrainer_10epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + """used for debugging plans etc""" + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 10 + + +class nnUNetTrainer_20epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 20 + + +class nnUNetTrainer_50epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 50 + + +class nnUNetTrainer_100epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 100 + + +class nnUNetTrainer_250epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 250 + + +class nnUNetTrainer_2000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 2000 + + +class nnUNetTrainer_4000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 4000 + + +class nnUNetTrainer_8000epochs(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 8000 diff --git a/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py new file mode 100644 index 0000000..c16b885 --- /dev/null +++ 
b/nnUNet/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py @@ -0,0 +1,60 @@ +import torch + +from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer + + +class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 250 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 2000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 4000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + + +class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer): + def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, + device: torch.device = torch.device('cuda')): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + self.num_epochs = 8000 + + def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + super().configure_rotation_dummyDA_mirroring_and_inital_patch_size() + mirror_axes = None + self.inference_allowed_mirroring_axes = None + return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes + diff --git a/nnUNet/nnunetv2/unet.py b/nnUNet/nnunetv2/unet.py new file mode 100644 index 0000000..df14b63 --- /dev/null +++ b/nnUNet/nnunetv2/unet.py @@ -0,0 +1,344 @@ +from typing import T +import torch +from torch.nn import Conv3d +from einops import rearrange, repeat +from torch import nn, einsum +from inspect import isfunction + + +class UNetEncoderS(nn.Module): + + def __init__(self, channels): + super(UNetEncoderS, self).__init__() + self.inc = (DoubleConv(channels, 16)) + self.down1 = (Down(16, 32, 
pooling=(1, 2, 2))) + self.down2 = (Down(32, 64)) + self.down3 = (Down(64, 128)) + self.down4 = (Down(128, 256)) + + def forward(self, x): + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + skips = [x1, x2, x3, x4] + return x5, skips + + +class UNetEncoderL(nn.Module): + def __init__(self, channels): + super(UNetEncoderL, self).__init__() + self.inc = (DoubleConv(channels, 32)) + self.down1 = (Down(32, 64, pooling=(1, 2, 2))) + self.down2 = (Down(64, 128)) + self.down3 = (Down(128, 256)) + self.down4 = (Down(256, 512)) + + def forward(self, x): + + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + skips = [x1, x2, x3, x4] + return x5, skips + + +class OutConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(OutConv, self).__init__() + self.conv1 = DoubleConv(in_channels, out_channels) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + return self.conv2(x) + + +class SegmentationHeadS(nn.Module): + def __init__(self, in_features, segmentation_classes, do_ds): + super(SegmentationHeadS, self).__init__() + self.up_segmentation1 = (Up(in_features, 128, + bilinear=False) + ) + self.up_segmentation2 = (Up(128, 64, bilinear=False)) + self.up_segmentation3 = (Up(64, 32, bilinear=False)) + self.up_segmentation4 = (Up(32, 16, bilinear=False, + pooling=(1, 2, 2))) + self.outc_segmentation = (OutConv(16, segmentation_classes)) + + self.do_ds = do_ds + self.proj1 = Conv3d(256, segmentation_classes, 1) + self.proj2 = Conv3d(128, segmentation_classes, 1) + self.proj3 = Conv3d(64, segmentation_classes, 1) + self.proj4 = Conv3d(32, segmentation_classes, 1) + self.non_lin = lambda x: x # torch.softmax(x, 1) + + def forward(self, x, skips): + x1, x2, x3, x4 = skips + if self.do_ds: + outputs = [] + outputs.append(self.non_lin(self.proj1(x))) + x = self.up_segmentation1(x, x4) + if self.do_ds: + outputs.append(self.non_lin(self.proj2(x))) + x = self.up_segmentation2(x, x3) + if self.do_ds: + outputs.append(self.non_lin(self.proj3(x))) + x = self.up_segmentation3(x, x2) + if self.do_ds: + outputs.append(self.non_lin(self.proj4(x))) + x = self.up_segmentation4(x, x1) + x = self.outc_segmentation(x) + if self.do_ds: + outputs.append(self.non_lin(x)) + if self.do_ds: + return outputs[::-1] + return self.non_lin(x) + + def eval(self: T) -> T: + a = super(SegmentationHeadS, self).eval() + self.do_ds = False + return a + + def train(self: T, mode: bool = True) -> T: + a = super(SegmentationHeadS, self).train() + self.do_ds = True + return a + + +class DoubleConv(nn.Module): + + """(convolution => [BN] => ReLU) * 2""" + def __init__(self, in_channels, out_channels, + mid_channels=None, dropout_rate: float = 0.): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv3d(in_channels, mid_channels, kernel_size=3, + padding=1, bias=True), + nn.Dropout3d(dropout_rate), + nn.BatchNorm3d(mid_channels), + nn.LeakyReLU(inplace=True), + nn.Conv3d(mid_channels, out_channels, kernel_size=3, + padding=1, bias=True), + nn.Dropout3d(dropout_rate), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels, pooling=(2,2,2)): + super().__init__() + self.maxpool_conv = nn.Sequential( + 
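+            # pooling may be anisotropic: both encoders pass (1, 2, 2) for their first Down stage, halving only the in-plane dimensions while keeping the slice axis at full resolution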
nn.MaxPool3d(pooling), + DoubleConv(in_channels, out_channels, dropout_rate=0.5) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, pooling=(2, 2, 2), bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=pooling, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=pooling, stride=pooling, bias=False) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + # diffY = x2.size()[2] - x1.size()[2] + # diffX = x2.size()[3] - x1.size()[3] + + # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + # diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class SegmentationHeadL(nn.Module): + def __init__(self, in_features, segmentation_classes, do_ds): + super(SegmentationHeadL, self).__init__() + self.up_segmentation1 = (Up(in_features, 256, bilinear=False)) + self.up_segmentation2 = (Up(256, 128, bilinear=False)) + self.up_segmentation3 = (Up(128, 64, bilinear=False)) + self.up_segmentation4 = (Up(64, 32, bilinear=False, + pooling=(1, 2, 2))) + self.outc_segmentation = (OutConv(32, segmentation_classes)) + + self.do_ds = do_ds + self.proj1 = Conv3d(512, segmentation_classes, 1) + self.proj2 = Conv3d(256, segmentation_classes, 1) + self.proj3 = Conv3d(128, segmentation_classes, 1) + self.proj4 = Conv3d(64, segmentation_classes, 1) + self.non_lin = lambda x: x # torch.softmax(x, 1) + + def forward(self, x, skips): + x1, x2, x3, x4 = skips + if self.do_ds: + outputs = [] + outputs.append(self.non_lin(self.proj1(x))) + x = self.up_segmentation1(x, x4) + if self.do_ds: + outputs.append(self.non_lin(self.proj2(x))) + x = self.up_segmentation2(x, x3) + if self.do_ds: + outputs.append(self.non_lin(self.proj3(x))) + x = self.up_segmentation3(x, x2) + if self.do_ds: + outputs.append(self.non_lin(self.proj4(x))) + x = self.up_segmentation4(x, x1) + x = self.outc_segmentation(x) + if self.do_ds: + outputs.append(self.non_lin(x)) + if self.do_ds: + return outputs[::-1] + return self.non_lin(x) + + def eval(self: T) -> T: + a = super(SegmentationHeadL, self).eval() + self.do_ds = False + return a + + def train(self: T, mode: bool = True) -> T: + a = super(SegmentationHeadL, self).train() + self.do_ds = True + return a + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +class CrossAttention(nn.Module): + + def __init__(self, query_dim, context_dim=None, + heads=8, dim_head=64, dropout=0.): + + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Conv3d(query_dim, inner_dim, + kernel_size=1, bias=False) + self.to_k = nn.Conv3d(context_dim, inner_dim, + kernel_size=1, bias=False) + self.to_v = nn.Conv3d(context_dim, inner_dim, + kernel_size=1, bias=False) + + self.to_out = nn.Sequential( + 
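+            # 1x1x1 convolution projecting the concatenated attention heads (inner_dim = heads * dim_head) back to the query channel count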
nn.Conv3d(inner_dim, query_dim, kernel_size=1), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + n = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, c, h, w, d = x.shape + q, k, v = map(lambda t: + rearrange(t, 'b (n c) h w l -> (b n) c (h w l)', + n=n), (q, k, v)) + + # force cast to fp32 to avoid overflowing + with torch.autocast(enabled=False, device_type='cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b n) c (h w l) -> b (n c) h w l', n=n, h=h, w=w, l=d) + out = self.to_out(out) + return out + + +class UNetDeepSupervisionDoubleEncoder(nn.Module): + def __init__(self, n_channels_1, n_channels_2, + n_classes_segmentation, deep_supervision=True, + encoder=UNetEncoderL, + segmentation_head=SegmentationHeadL): + super(UNetDeepSupervisionDoubleEncoder, self).__init__() + self.n_channels_1 = n_channels_1 + self.n_channels_2 = n_channels_2 + self.n_classes_segmentation = n_classes_segmentation + self.nb_decoders = 4 + self.do_ds = deep_supervision + self.deep_supervision = deep_supervision + + self.encoder1 = encoder(self.n_channels_1) + self.encoder2 = encoder(self.n_channels_2) + feature_size = 512 if segmentation_head == SegmentationHeadL else 256 + print(feature_size) + self.segmentation_head = segmentation_head(feature_size, + self.n_classes_segmentation, + self.do_ds) + # self.CA = CrossAttention(query_dim=feature_size) + + def forward(self, x_in): + x, y = x_in[:, 0:1, :, :, :], x_in[:, 1:2, :, :, :] + features1, skips_1 = self.encoder1(x) + features2, skips_2 = self.encoder2(y) + x = torch.cat([features1, features2], dim=1) + skip = [torch.cat([s1, s2], dim=1) for s1, s2 in zip(skips_1, skips_2)] + + # self.CA(features1, context=features2) + return self.segmentation_head(x, skip) + + def eval(self: T) -> T: + super(UNetDeepSupervisionDoubleEncoder, self).eval() + self.do_ds = False + self.segmentation_head.eval() + self.encoder1.eval() + self.encoder2.eval() + + def train(self: T, mode: bool = True) -> T: + super(UNetDeepSupervisionDoubleEncoder, self).train() + self.do_ds = True + self.encoder1.train() + self.encoder2.train() + self.segmentation_head.train() + diff --git a/nnUNet/nnunetv2/utilities/__init__.py b/nnUNet/nnunetv2/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/utilities/collate_outputs.py b/nnUNet/nnunetv2/utilities/collate_outputs.py new file mode 100644 index 0000000..c9d6798 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/collate_outputs.py @@ -0,0 +1,24 @@ +from typing import List + +import numpy as np + + +def collate_outputs(outputs: List[dict]): + """ + used to collate default train_step and validation_step outputs. 
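+    Scalars are gathered into plain lists, numpy arrays are stacked along a new leading axis and lists are
+    concatenated; e.g. collate_outputs([{'loss': 0.5}, {'loss': 0.7}]) yields {'loss': [0.5, 0.7]}.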
If you want something different then you gotta + extend this + + we expect outputs to be a list of dictionaries where each of the dict has the same set of keys + """ + collated = {} + for k in outputs[0].keys(): + if np.isscalar(outputs[0][k]): + collated[k] = [o[k] for o in outputs] + elif isinstance(outputs[0][k], np.ndarray): + collated[k] = np.vstack([o[k][None] for o in outputs]) + elif isinstance(outputs[0][k], list): + collated[k] = [item for o in outputs for item in o[k]] + else: + raise ValueError(f'Cannot collate input of type {type(outputs[0][k])}. ' + f'Modify collate_outputs to add this functionality') + return collated \ No newline at end of file diff --git a/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py b/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py new file mode 100644 index 0000000..1f2c350 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/dataset_name_id_conversion.py @@ -0,0 +1,74 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw, nnUNet_results +from batchgenerators.utilities.file_and_folder_operations import * +import numpy as np + + +def find_candidate_datasets(dataset_id: int): + startswith = "Dataset%03.0d" % dataset_id + if nnUNet_preprocessed is not None and isdir(nnUNet_preprocessed): + candidates_preprocessed = subdirs(nnUNet_preprocessed, prefix=startswith, join=False) + else: + candidates_preprocessed = [] + + if nnUNet_raw is not None and isdir(nnUNet_raw): + candidates_raw = subdirs(nnUNet_raw, prefix=startswith, join=False) + else: + candidates_raw = [] + + candidates_trained_models = [] + if nnUNet_results is not None and isdir(nnUNet_results): + candidates_trained_models += subdirs(nnUNet_results, prefix=startswith, join=False) + + all_candidates = candidates_preprocessed + candidates_raw + candidates_trained_models + unique_candidates = np.unique(all_candidates) + return unique_candidates + + +def convert_id_to_dataset_name(dataset_id: int): + unique_candidates = find_candidate_datasets(dataset_id) + if len(unique_candidates) > 1: + raise RuntimeError("More than one dataset name found for dataset id %d. Please correct that. (I looked in the " + "following folders:\n%s\n%s\n%s" % (dataset_id, nnUNet_raw, nnUNet_preprocessed, nnUNet_results)) + if len(unique_candidates) == 0: + raise RuntimeError(f"Could not find a dataset with the ID {dataset_id}. Make sure the requested dataset ID " + f"exists and that nnU-Net knows where raw and preprocessed data are located " + f"(see Documentation - Installation). 
Here are your currently defined folders:\n" + f"nnUNet_preprocessed={os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None'}\n" + f"nnUNet_results={os.environ.get('nnUNet_results') if os.environ.get('nnUNet_results') is not None else 'None'}\n" + f"nnUNet_raw={os.environ.get('nnUNet_raw') if os.environ.get('nnUNet_raw') is not None else 'None'}\n" + f"If something is not right, adapt your environment variables.") + return unique_candidates[0] + + +def convert_dataset_name_to_id(dataset_name: str): + assert dataset_name.startswith("Dataset") + dataset_id = int(dataset_name[7:10]) + return dataset_id + + +def maybe_convert_to_dataset_name(dataset_name_or_id: Union[int, str]) -> str: + if isinstance(dataset_name_or_id, str) and dataset_name_or_id.startswith("Dataset"): + return dataset_name_or_id + if isinstance(dataset_name_or_id, str): + try: + dataset_name_or_id = int(dataset_name_or_id) + except ValueError: + raise ValueError("dataset_name_or_id was a string and did not start with 'Dataset' so we tried to " + "convert it to a dataset ID (int). That failed, however. Please give an integer number " + "('1', '2', etc) or a correct tast name. Your input: %s" % dataset_name_or_id) + return convert_id_to_dataset_name(dataset_name_or_id) \ No newline at end of file diff --git a/nnUNet/nnunetv2/utilities/ddp_allgather.py b/nnUNet/nnunetv2/utilities/ddp_allgather.py new file mode 100644 index 0000000..c42b3ef --- /dev/null +++ b/nnUNet/nnunetv2/utilities/ddp_allgather.py @@ -0,0 +1,49 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
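+# A minimal usage sketch for AllGatherGrad (hypothetical, assuming an initialized torch.distributed process group):
+#   gathered = AllGatherGrad.apply(local_tensor)   # shape: (world_size, *local_tensor.shape)
+#   loss = gathered.mean()                         # the backward pass routes the gradient of this rank's slice back to local_tensor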
+from typing import Any, Optional, Tuple + +import torch +from torch import distributed + + +def print_if_rank0(*args): + if distributed.get_rank() == 0: + print(*args) + + +class AllGatherGrad(torch.autograd.Function): + # stolen from pytorch lightning + @staticmethod + def forward( + ctx: Any, + tensor: torch.Tensor, + group: Optional["torch.distributed.ProcessGroup"] = None, + ) -> torch.Tensor: + ctx.group = group + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor, group=group) + gathered_tensor = torch.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: + grad_output = torch.cat(grad_output) + + torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) + + return grad_output[torch.distributed.get_rank()], None + diff --git a/nnUNet/nnunetv2/utilities/default_n_proc_DA.py b/nnUNet/nnunetv2/utilities/default_n_proc_DA.py new file mode 100644 index 0000000..3ecc922 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/default_n_proc_DA.py @@ -0,0 +1,44 @@ +import subprocess +import os + + +def get_allowed_n_proc_DA(): + """ + This function is used to set the number of processes used on different Systems. It is specific to our cluster + infrastructure at DKFZ. You can modify it to suit your needs. Everything is allowed. + + IMPORTANT: if the environment variable nnUNet_n_proc_DA is set it will overwrite anything in this script + (see first line). + + Interpret the output as the number of processes used for data augmentation PER GPU. + + The way it is implemented here is simply a look up table. We know the hostnames, CPU and GPU configurations of our + systems and set the numbers accordingly. For example, a system with 4 GPUs and 48 threads can use 12 threads per + GPU without overloading the CPU (technically 11 because we have a main process as well), so that's what we use. 
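+    For example, setting the environment variable nnUNet_n_proc_DA=8 forces 8 augmentation workers per GPU
+    (still capped by os.cpu_count() below), bypassing the hostname lookup table entirely.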
+ """ + + if 'nnUNet_n_proc_DA' in os.environ.keys(): + use_this = int(os.environ['nnUNet_n_proc_DA']) + else: + hostname = subprocess.getoutput(['hostname']) + if hostname in ['Fabian', ]: + use_this = 12 + elif hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'hdf19-gpu18', 'hdf19-gpu19', 'e230-AMDworkstation']: + use_this = 16 + elif hostname.startswith('e230-dgx1'): + use_this = 10 + elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'): + use_this = 16 + elif hostname.startswith('e230-dgx2'): + use_this = 6 + elif hostname.startswith('e230-dgxa100-'): + use_this = 28 + elif hostname.startswith('lsf22-gpu'): + use_this = 28 + elif hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'): + use_this = 12 + else: + use_this = 12 # default value + + use_this = min(use_this, os.cpu_count()) + return use_this diff --git a/nnUNet/nnunetv2/utilities/file_path_utilities.py b/nnUNet/nnunetv2/utilities/file_path_utilities.py new file mode 100644 index 0000000..611f6e2 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/file_path_utilities.py @@ -0,0 +1,123 @@ +from multiprocessing import Pool +from typing import Union, Tuple +import numpy as np +from batchgenerators.utilities.file_and_folder_operations import * + +from nnunetv2.configuration import default_num_processes +from nnunetv2.paths import nnUNet_results +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + +def convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration): + return f'{trainer_name}__{plans_identifier}__{configuration}' + + +def convert_identifier_to_trainer_plans_config(identifier: str): + return os.path.basename(identifier).split('__') + + +def get_output_folder(dataset_name_or_id: Union[str, int], trainer_name: str = 'nnUNetTrainer', + plans_identifier: str = 'nnUNetPlans', configuration: str = '3d_fullres', + fold: Union[str, int] = None) -> str: + tmp = join(nnUNet_results, maybe_convert_to_dataset_name(dataset_name_or_id), + convert_trainer_plans_config_to_identifier(trainer_name, plans_identifier, configuration)) + if fold is not None: + tmp = join(tmp, f'fold_{fold}') + return tmp + + +def parse_dataset_trainer_plans_configuration_from_path(path: str): + folders = split_path(path) + # this here can be a little tricky because we are making assumptions. Let's hope this never fails lol + + # safer to make this depend on two conditions, the fold_x and the DatasetXXX + # first let's see if some fold_X is present + fold_x_present = [i.startswith('fold_') for i in folders] + if any(fold_x_present): + idx = fold_x_present.index(True) + # OK now two entries before that there should be DatasetXXX + assert len(folders[:idx]) >= 2, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + if folders[idx - 2].startswith('Dataset'): + splitted = folders[idx - 1].split('__') + assert len(splitted) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + return folders[idx - 2], *splitted + else: + # we can only check for dataset followed by a string that is separable into three strings by splitting with '__' + # look for DatasetXXX + dataset_folder = [i.startswith('Dataset') for i in folders] + if any(dataset_folder): + idx = dataset_folder.index(True) + assert len(folders) >= (idx + 1), 'Bad path, cannot extract what I need. 
Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + splitted = folders[idx + 1].split('__') + assert len(splitted) == 3, 'Bad path, cannot extract what I need. Your path needs to be at least ' \ + 'DatasetXXX/MODULE__PLANS__CONFIGURATION for this to work' + return folders[idx], *splitted + + +def get_ensemble_name(model1_folder, model2_folder, folds: Tuple[int, ...]): + identifier = 'ensemble___' + os.path.basename(model1_folder) + '___' + \ + os.path.basename(model2_folder) + '___' + folds_tuple_to_string(folds) + return identifier + + +def get_ensemble_name_from_d_tr_c(dataset, tr1, p1, c1, tr2, p2, c2, folds: Tuple[int, ...]): + model1_folder = get_output_folder(dataset, tr1, p1, c1) + model2_folder = get_output_folder(dataset, tr2, p2, c2) + + get_ensemble_name(model1_folder, model2_folder, folds) + + +def convert_ensemble_folder_to_model_identifiers_and_folds(ensemble_folder: str): + prefix, *models, folds = os.path.basename(ensemble_folder).split('___') + return models, folds + + +def folds_tuple_to_string(folds: Union[List[int], Tuple[int, ...]]): + s = str(folds[0]) + for f in folds[1:]: + s += f"_{f}" + return s + + +def folds_string_to_tuple(folds_string: str): + folds = folds_string.split('_') + res = [] + for f in folds: + try: + res.append(int(f)) + except ValueError: + res.append(f) + return res + + +def check_workers_alive_and_busy(export_pool: Pool, worker_list: List, results_list: List, allowed_num_queued: int = 0): + """ + + returns True if the number of results that are not ready is greater than the number of available workers + allowed_num_queued + """ + alive = [i.is_alive() for i in worker_list] + if not all(alive): + raise RuntimeError('Some background workers are no longer alive') + + not_ready = [not i.ready() for i in results_list] + if sum(not_ready) >= (len(export_pool._pool) + allowed_num_queued): + return True + return False + + +if __name__ == '__main__': + ### well at this point I could just write tests... + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + path = 'Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/nnUNetModule__nnUNetPlans__3d_fullres/fold_all' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + try: + path = '/home/fabian/results/nnUNet_remake/Dataset002_Heart/' + print(parse_dataset_trainer_plans_configuration_from_path(path)) + except AssertionError: + print('yayy, assertion works') diff --git a/nnUNet/nnunetv2/utilities/find_class_by_name.py b/nnUNet/nnunetv2/utilities/find_class_by_name.py new file mode 100644 index 0000000..a345d99 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/find_class_by_name.py @@ -0,0 +1,24 @@ +import importlib +import pkgutil + +from batchgenerators.utilities.file_and_folder_operations import * + + +def recursive_find_python_class(folder: str, class_name: str, current_module: str): + tr = None + for importer, modname, ispkg in pkgutil.iter_modules([folder]): + # print(modname, ispkg) + if not ispkg: + m = importlib.import_module(current_module + "." + modname) + if hasattr(m, class_name): + tr = getattr(m, class_name) + break + + if tr is None: + for importer, modname, ispkg in pkgutil.iter_modules([folder]): + if ispkg: + next_current_module = current_module + "." 
+ modname + tr = recursive_find_python_class(join(folder, modname), class_name, current_module=next_current_module) + if tr is not None: + break + return tr \ No newline at end of file diff --git a/nnUNet/nnunetv2/utilities/get_network_from_plans.py b/nnUNet/nnunetv2/utilities/get_network_from_plans.py new file mode 100644 index 0000000..447d1d5 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/get_network_from_plans.py @@ -0,0 +1,77 @@ +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0 +from nnunetv2.utilities.network_initialization import InitWeights_He +from nnunetv2.utilities.plans_handling.plans_handler import ConfigurationManager, PlansManager +from torch import nn + + +def get_network_from_plans(plans_manager: PlansManager, + dataset_json: dict, + configuration_manager: ConfigurationManager, + num_input_channels: int, + deep_supervision: bool = True): + """ + we may have to change this in the future to accommodate other plans -> network mappings + + num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the + trainer rather than inferring it again from the plans here. + """ + num_stages = len(configuration_manager.conv_kernel_sizes) + + dim = len(configuration_manager.conv_kernel_sizes[0]) + conv_op = convert_dim_to_conv_op(dim) + + label_manager = plans_manager.get_label_manager(dataset_json) + + segmentation_network_class_name = configuration_manager.UNet_class_name + mapping = { + 'PlainConvUNet': PlainConvUNet, + 'ResidualEncoderUNet': ResidualEncoderUNet + } + kwargs = { + 'PlainConvUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + }, + 'ResidualEncoderUNet': { + 'conv_bias': True, + 'norm_op': get_matching_instancenorm(conv_op), + 'norm_op_kwargs': {'eps': 1e-5, 'affine': True}, + 'dropout_op': None, 'dropout_op_kwargs': None, + 'nonlin': nn.LeakyReLU, 'nonlin_kwargs': {'inplace': True}, + } + } + assert segmentation_network_class_name in mapping.keys(), 'The network architecture specified by the plans file ' \ + 'is non-standard (maybe your own?). Yo\'ll have to dive ' \ + 'into either this ' \ + 'function (get_network_from_plans) or ' \ + 'the init of your nnUNetModule to accomodate that.' + network_class = mapping[segmentation_network_class_name] + + conv_or_blocks_per_stage = { + 'n_conv_per_stage' + if network_class != ResidualEncoderUNet else 'n_blocks_per_stage': configuration_manager.n_conv_per_stage_encoder, + 'n_conv_per_stage_decoder': configuration_manager.n_conv_per_stage_decoder + } + # network class name!! 
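+    # The instantiation below is driven entirely by the plans: features per stage double with depth (capped at
+    # unet_max_num_features), strides come from pool_op_kernel_sizes and the number of output channels from
+    # label_manager.num_segmentation_heads. Note the key chosen above: ResidualEncoderUNet expects
+    # 'n_blocks_per_stage' while PlainConvUNet expects 'n_conv_per_stage'.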
+ model = network_class( + input_channels=num_input_channels, + n_stages=num_stages, + features_per_stage=[min(configuration_manager.UNet_base_num_features * 2 ** i, + configuration_manager.unet_max_num_features) for i in range(num_stages)], + conv_op=conv_op, + kernel_sizes=configuration_manager.conv_kernel_sizes, + strides=configuration_manager.pool_op_kernel_sizes, + num_classes=label_manager.num_segmentation_heads, + deep_supervision=deep_supervision, + **conv_or_blocks_per_stage, + **kwargs[segmentation_network_class_name] + ) + model.apply(InitWeights_He(1e-2)) + if network_class == ResidualEncoderUNet: + model.apply(init_last_bn_before_add_to_0) + return model diff --git a/nnUNet/nnunetv2/utilities/helpers.py b/nnUNet/nnunetv2/utilities/helpers.py new file mode 100644 index 0000000..42448e3 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/helpers.py @@ -0,0 +1,27 @@ +import torch + + +def softmax_helper_dim0(x: torch.Tensor) -> torch.Tensor: + return torch.softmax(x, 0) + + +def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: + return torch.softmax(x, 1) + + +def empty_cache(device: torch.device): + if device.type == 'cuda': + torch.cuda.empty_cache() + elif device.type == 'mps': + from torch import mps + mps.empty_cache() + else: + pass + + +class dummy_context(object): + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass diff --git a/nnUNet/nnunetv2/utilities/json_export.py b/nnUNet/nnunetv2/utilities/json_export.py new file mode 100644 index 0000000..faed954 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/json_export.py @@ -0,0 +1,59 @@ +from collections.abc import Iterable + +import numpy as np +import torch + + +def recursive_fix_for_json_export(my_dict: dict): + # json is stupid. 'cannot serialize object of type bool_/int64/float64'. Come on bro. + keys = list(my_dict.keys()) # cannot iterate over keys() if we change keys.... + for k in keys: + if isinstance(k, (np.int64, np.int32, np.int8, np.uint8)): + tmp = my_dict[k] + del my_dict[k] + my_dict[int(k)] = tmp + del tmp + k = int(k) + + if isinstance(my_dict[k], dict): + recursive_fix_for_json_export(my_dict[k]) + elif isinstance(my_dict[k], np.ndarray): + assert len(my_dict[k].shape) == 1, 'only 1d arrays are supported' + my_dict[k] = fix_types_iterable(my_dict[k], output_type=list) + elif isinstance(my_dict[k], (np.bool_,)): + my_dict[k] = bool(my_dict[k]) + elif isinstance(my_dict[k], (np.int64, np.int32, np.int8, np.uint8)): + my_dict[k] = int(my_dict[k]) + elif isinstance(my_dict[k], (np.float32, np.float64, np.float16)): + my_dict[k] = float(my_dict[k]) + elif isinstance(my_dict[k], list): + my_dict[k] = fix_types_iterable(my_dict[k], output_type=type(my_dict[k])) + elif isinstance(my_dict[k], tuple): + my_dict[k] = fix_types_iterable(my_dict[k], output_type=tuple) + elif isinstance(my_dict[k], torch.device): + my_dict[k] = str(my_dict[k]) + else: + pass # pray it can be serialized + + +def fix_types_iterable(iterable, output_type): + # this sh!t is hacky as hell and will break if you use it for anything outside nnunet. Keep you hands off of this. 
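+    # Converts numpy scalar types to native Python types while preserving the container type, e.g.
+    # fix_types_iterable((np.float32(0.5), np.int64(3)), output_type=tuple) -> (0.5, 3) (illustrative call).
+    # Dicts are fixed in place via recursive_fix_for_json_export; nested iterables are handled recursively.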
+ out = [] + for i in iterable: + if type(i) in (np.int64, np.int32, np.int8, np.uint8): + out.append(int(i)) + elif isinstance(i, dict): + recursive_fix_for_json_export(i) + out.append(i) + elif type(i) in (np.float32, np.float64, np.float16): + out.append(float(i)) + elif type(i) in (np.bool_,): + out.append(bool(i)) + elif isinstance(i, str): + out.append(i) + elif isinstance(i, Iterable): + # print('recursive call on', i, type(i)) + out.append(fix_types_iterable(i, type(i))) + else: + out.append(i) + return output_type(out) diff --git a/nnUNet/nnunetv2/utilities/label_handling/__init__.py b/nnUNet/nnunetv2/utilities/label_handling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/utilities/label_handling/label_handling.py b/nnUNet/nnunetv2/utilities/label_handling/label_handling.py new file mode 100644 index 0000000..333296d --- /dev/null +++ b/nnUNet/nnunetv2/utilities/label_handling/label_handling.py @@ -0,0 +1,322 @@ +from __future__ import annotations +from time import time +from typing import Union, List, Tuple, Type + +import numpy as np +import torch +from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice +from batchgenerators.utilities.file_and_folder_operations import join + +import nnunetv2 +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.helpers import softmax_helper_dim0 + +from typing import TYPE_CHECKING + +# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ +if TYPE_CHECKING: + from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager + + +class LabelManager(object): + def __init__(self, label_dict: dict, regions_class_order: Union[List[int], None], force_use_labels: bool = False, + inference_nonlin=None): + self._sanity_check(label_dict) + self.label_dict = label_dict + self.regions_class_order = regions_class_order + self._force_use_labels = force_use_labels + + if force_use_labels: + self._has_regions = False + else: + self._has_regions: bool = any( + [isinstance(i, (tuple, list)) and len(i) > 1 for i in self.label_dict.values()]) + + self._ignore_label: Union[None, int] = self._determine_ignore_label() + self._all_labels: List[int] = self._get_all_labels() + + self._regions: Union[None, List[Union[int, Tuple[int, ...]]]] = self._get_regions() + + if self.has_ignore_label: + assert self.ignore_label == max( + self.all_labels) + 1, 'If you use the ignore label it must have the highest ' \ + 'label value! It cannot be 0 or in between other labels. ' \ + 'Sorry bro.' + + if inference_nonlin is None: + self.inference_nonlin = torch.sigmoid if self.has_regions else softmax_helper_dim0 + else: + self.inference_nonlin = inference_nonlin + + def _sanity_check(self, label_dict: dict): + if not 'background' in label_dict.keys(): + raise RuntimeError('Background label not declared (remeber that this should be label 0!)') + bg_label = label_dict['background'] + if isinstance(bg_label, (tuple, list)): + raise RuntimeError(f"Background label must be 0. Not a list. Not a tuple. Your background label: {bg_label}") + assert int(bg_label) == 0, f"Background label must be 0. Your background label: {bg_label}" + # not sure if we want to allow regions that contain background. I don't immediately see how this could cause + # problems so we allow it for now. That doesn't mean that this is explicitly supported. It could be that this + # just crashes. 
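+
+    # Illustrative label_dict conventions (hypothetical values, not taken from a real dataset.json):
+    #   {'background': 0, 'organ': 1, 'lesion': 2}             -> plain label-based training
+    #   {'background': 0, 'whole_organ': (1, 2), 'lesion': 2}  -> region-based training (a value is a tuple/list
+    #                                                             with more than one entry)
+    # Region-based training additionally requires 'regions_class_order' in dataset.json; an optional 'ignore'
+    # entry must be a single int with the highest label value (see the checks in __init__ above).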
+ + def _get_all_labels(self) -> List[int]: + all_labels = [] + for k, r in self.label_dict.items(): + # ignore label is not going to be used, hence the name. Duh. + if k == 'ignore': + continue + if isinstance(r, (tuple, list)): + for ri in r: + all_labels.append(int(ri)) + else: + all_labels.append(int(r)) + all_labels = list(np.unique(all_labels)) + all_labels.sort() + return all_labels + + def _get_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: + if not self._has_regions or self._force_use_labels: + return None + else: + assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ + 'define regions_class_order!' + regions = [] + for k, r in self.label_dict.items(): + # ignore ignore label + if k == 'ignore': + continue + # ignore regions that are background + if (np.isscalar(r) and r == 0) \ + or \ + (isinstance(r, (tuple, list)) and len(np.unique(r)) == 1 and np.unique(r)[0] == 0): + continue + if isinstance(r, list): + r = tuple(r) + regions.append(r) + assert len(self.regions_class_order) == len(regions), 'regions_class_order must have as ' \ + 'many entries as there are ' \ + 'regions' + return regions + + def _determine_ignore_label(self) -> Union[None, int]: + ignore_label = self.label_dict.get('ignore') + if ignore_label is not None: + assert isinstance(ignore_label, int), f'Ignore label has to be an integer. It cannot be a region ' \ + f'(list/tuple). Got {type(ignore_label)}.' + return ignore_label + + @property + def has_regions(self) -> bool: + return self._has_regions + + @property + def has_ignore_label(self) -> bool: + return self.ignore_label is not None + + @property + def all_regions(self) -> Union[None, List[Union[int, Tuple[int, ...]]]]: + return self._regions + + @property + def all_labels(self) -> List[int]: + return self._all_labels + + @property + def ignore_label(self) -> Union[None, int]: + return self._ignore_label + + def apply_inference_nonlin(self, logits: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + """ + logits has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if isinstance(logits, np.ndarray): + logits = torch.from_numpy(logits) + + with torch.no_grad(): + # softmax etc is not implemented for half + logits = logits.float() + probabilities = self.inference_nonlin(logits) + + return probabilities + + def convert_probabilities_to_segmentation(self, predicted_probabilities: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + """ + assumes that inference_nonlinearity was already applied! + + predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)): + raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor," + f" got {type(predicted_probabilities)}") + + if self.has_regions: + assert self.regions_class_order is not None, 'if region-based training is requested then you need to ' \ + 'define regions_class_order!' + # check correct number of outputs + assert predicted_probabilities.shape[0] == self.num_segmentation_heads, \ + f'unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, ' \ + f'got {predicted_probabilities.shape[0]}. Remeber that predicted_probabilities should have shape ' \ + f'(c, x, y(, z)).' 
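+        # Two cases below: with regions, each channel is thresholded at 0.5 and painted in the order given by
+        # regions_class_order (later regions overwrite earlier ones); without regions, a plain argmax over the
+        # channel dimension gives the segmentation.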
+ + if self.has_regions: + if isinstance(predicted_probabilities, np.ndarray): + segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16) + else: + # no uint16 in torch + segmentation = torch.zeros(predicted_probabilities.shape[1:], dtype=torch.int16, + device=predicted_probabilities.device) + for i, c in enumerate(self.regions_class_order): + segmentation[predicted_probabilities[i] > 0.5] = c + else: + segmentation = predicted_probabilities.argmax(0) + + return segmentation + + def convert_logits_to_segmentation(self, predicted_logits: Union[np.ndarray, torch.Tensor]) -> \ + Union[np.ndarray, torch.Tensor]: + input_is_numpy = isinstance(predicted_logits, np.ndarray) + probabilities = self.apply_inference_nonlin(predicted_logits) + if input_is_numpy and isinstance(probabilities, torch.Tensor): + probabilities = probabilities.cpu().numpy() + return self.convert_probabilities_to_segmentation(probabilities) + + def revert_cropping_on_probabilities(self, predicted_probabilities: Union[torch.Tensor, np.ndarray], + bbox: List[List[int]], + original_shape: Union[List[int], Tuple[int, ...]]): + """ + ONLY USE THIS WITH PROBABILITIES, DO NOT USE LOGITS AND DO NOT USE FOR SEGMENTATION MAPS!!! + + predicted_probabilities must be (c, x, y(, z)) + + Why do we do this here? Well if we pad probabilities we need to make sure that convert_logits_to_segmentation + correctly returns background in the padded areas. Also we want to ba able to look at the padded probabilities + and not have strange artifacts. + Only LabelManager knows how this needs to be done. So let's let him/her do it, ok? + """ + # revert cropping + probs_reverted_cropping = np.zeros((predicted_probabilities.shape[0], *original_shape), + dtype=predicted_probabilities.dtype) \ + if isinstance(predicted_probabilities, np.ndarray) else \ + torch.zeros((predicted_probabilities.shape[0], *original_shape), dtype=predicted_probabilities.dtype) + + if not self.has_regions: + probs_reverted_cropping[0] = 1 + + slicer = bounding_box_to_slice(bbox) + probs_reverted_cropping[tuple([slice(None)] + list(slicer))] = predicted_probabilities + return probs_reverted_cropping + + @staticmethod + def filter_background(classes_or_regions: Union[List[int], List[Union[int, Tuple[int, ...]]]]): + # heck yeah + # This is definitely taking list comprehension too far. Enjoy. + return [i for i in classes_or_regions if + ((not isinstance(i, (tuple, list))) and i != 0) + or + (isinstance(i, (tuple, list)) and not ( + len(np.unique(i)) == 1 and np.unique(i)[0] == 0))] + + @property + def foreground_regions(self): + return self.filter_background(self.all_regions) + + @property + def foreground_labels(self): + return self.filter_background(self.all_labels) + + @property + def num_segmentation_heads(self): + if self.has_regions: + return len(self.foreground_regions) + else: + return len(self.all_labels) + + +def get_labelmanager_class_from_plans(plans: dict) -> Type[LabelManager]: + if 'label_manager' not in plans.keys(): + print('No label manager specified in plans. 
Using default: LabelManager') + return LabelManager + else: + labelmanager_class = recursive_find_python_class(join(nnunetv2.__path__[0], "utilities", "label_handling"), + plans['label_manager'], + current_module="nnunetv2.utilities.label_handling") + return labelmanager_class + + +def convert_labelmap_to_one_hot(segmentation: Union[np.ndarray, torch.Tensor], + all_labels: Union[List, torch.Tensor, np.ndarray, tuple], + output_dtype=None) -> Union[np.ndarray, torch.Tensor]: + """ + if output_dtype is None then we use np.uint8/torch.uint8 + if input is torch.Tensor then output will be on the same device + + np.ndarray is faster than torch.Tensor + + if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is somethine else we have + to cast which takes time. + + IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ... + DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo) + """ + if isinstance(segmentation, torch.Tensor): + result = torch.zeros((len(all_labels), *segmentation.shape), + dtype=output_dtype if output_dtype is not None else torch.uint8, + device=segmentation.device) + # variant 1, 2x faster than 2 + result.scatter_(0, segmentation[None].long(), 1) # why does this have to be long!? + # variant 2, slower than 1 + # for i, l in enumerate(all_labels): + # result[i] = segmentation == l + else: + result = np.zeros((len(all_labels), *segmentation.shape), + dtype=output_dtype if output_dtype is not None else np.uint8) + # variant 1, fastest in my testing + for i, l in enumerate(all_labels): + result[i] = segmentation == l + # variant 2. Takes about twice as long so nah + # result = np.eye(len(all_labels))[segmentation].transpose((3, 0, 1, 2)) + return result + + +def determine_num_input_channels(plans_manager: PlansManager, + configuration_or_config_manager: Union[str, ConfigurationManager], + dataset_json: dict) -> int: + if isinstance(configuration_or_config_manager, str): + config_manager = plans_manager.get_configuration(configuration_or_config_manager) + else: + config_manager = configuration_or_config_manager + + label_manager = plans_manager.get_label_manager(dataset_json) + num_modalities = len(dataset_json['modality']) if 'modality' in dataset_json.keys() else len(dataset_json['channel_names']) + + # cascade has different number of input channels + if config_manager.previous_stage_name is not None: + num_label_inputs = len(label_manager.foreground_labels) + num_input_channels = num_modalities + num_label_inputs + else: + num_input_channels = num_modalities + return num_input_channels + + +if __name__ == '__main__': + # this code used to be able to differentiate variant 1 and 2 to measure time. 
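+    # Quick micro-benchmark: one-hot encode a random 256^3 label map twice with numpy and twice with torch,
+    # print the four timings and verify that the numpy and torch results are identical.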
+ num_labels = 7 + seg = np.random.randint(0, num_labels, size=(256, 256, 256), dtype=np.uint8) + seg_torch = torch.from_numpy(seg) + st = time() + onehot_npy = convert_labelmap_to_one_hot(seg, np.arange(num_labels)) + time_1 = time() + onehot_npy2 = convert_labelmap_to_one_hot(seg, np.arange(num_labels)) + time_2 = time() + onehot_torch = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels)) + time_torch = time() + onehot_torch2 = convert_labelmap_to_one_hot(seg_torch, np.arange(num_labels)) + time_torch2 = time() + print( + f'np: {time_1 - st}, np2: {time_2 - time_1}, torch: {time_torch - time_2}, torch2: {time_torch2 - time_torch}') + onehot_torch = onehot_torch.numpy() + onehot_torch2 = onehot_torch2.numpy() + print(np.all(onehot_torch == onehot_npy)) + print(np.all(onehot_torch2 == onehot_npy)) diff --git a/nnUNet/nnunetv2/utilities/network_initialization.py b/nnUNet/nnunetv2/utilities/network_initialization.py new file mode 100644 index 0000000..1ead271 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/network_initialization.py @@ -0,0 +1,12 @@ +from torch import nn + + +class InitWeights_He(object): + def __init__(self, neg_slope=1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) diff --git a/nnUNet/nnunetv2/utilities/overlay_plots.py b/nnUNet/nnunetv2/utilities/overlay_plots.py new file mode 100644 index 0000000..8b0b9d1 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/overlay_plots.py @@ -0,0 +1,275 @@ +# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import multiprocessing +from multiprocessing.pool import Pool +from typing import Tuple, Union + +import numpy as np +import pandas as pd +from batchgenerators.utilities.file_and_folder_operations import * +from nnunetv2.configuration import default_num_processes +from nnunetv2.imageio.base_reader_writer import BaseReaderWriter +from nnunetv2.imageio.reader_writer_registry import determine_reader_writer_from_dataset_json +from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed +from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name +from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \ + get_filenames_of_train_images_and_targets + +color_cycle = ( + "000000", + "4363d8", + "f58231", + "3cb44b", + "e6194B", + "911eb4", + "ffe119", + "bfef45", + "42d4f4", + "f032e6", + "000075", + "9A6324", + "808000", + "800000", + "469990", +) + + +def hex_to_rgb(hex: str): + assert len(hex) == 6 + return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4)) + + +def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, + color_cycle: Tuple[str, ...] = color_cycle, + overlay_intensity: float = 0.6): + """ + image can be 2d greyscale or 2d RGB (color channel in last dimension!) + + Segmentation must be label map of same shape as image (w/o color channels) + + mapping can be label_id -> idx_in_cycle or None + + returned image is scaled to [0, 255] (uint8)!!! + """ + # create a copy of image + image = np.copy(input_image) + + if len(image.shape) == 2: + image = np.tile(image[:, :, None], (1, 1, 3)) + elif len(image.shape) == 3: + if image.shape[2] == 1: + image = np.tile(image, (1, 1, 3)) + else: + raise RuntimeError(f'if 3d image is given the last dimension must be the color channels (3 channels). ' + f'Only 2D images are supported. Your image shape: {image.shape}') + else: + raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in " + "last dimension) are supported") + + # rescale image to [0, 255] + image = image - image.min() + image = image / image.max() * 255 + + # create output + if mapping is None: + uniques = np.sort(pd.unique(segmentation.ravel())) # np.unique(segmentation) + mapping = {i: c for c, i in enumerate(uniques)} + + for l in mapping.keys(): + image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]])) + + # rescale result to [0, 255] + image = image / image.max() * 255 + return image.astype(np.uint8) + + +def select_slice_to_plot(image: np.ndarray, segmentation: np.ndarray) -> int: + """ + image and segmentation are expected to be 3D + + selects the slice with the largest amount of fg (regardless of label) + + we give image so that we can easily replace this function if needed + """ + fg_mask = segmentation != 0 + fg_per_slice = fg_mask.sum((1, 2)) + selected_slice = int(np.argmax(fg_per_slice)) + return selected_slice + + +def select_slice_to_plot2(image: np.ndarray, segmentation: np.ndarray) -> int: + """ + image and segmentation are expected to be 3D (or 1, x, y) + + selects the slice with the largest amount of fg (how much percent of each class are in each slice? 
pick slice + with highest avg percent) + + we give image so that we can easily replace this function if needed + """ + classes = [i for i in np.sort(pd.unique(segmentation.ravel())) if i != 0] + fg_per_slice = np.zeros((image.shape[0], len(classes))) + for i, c in enumerate(classes): + fg_mask = segmentation == c + fg_per_slice[:, i] = fg_mask.sum((1, 2)) + fg_per_slice[:, i] /= fg_per_slice.sum() + fg_per_slice = fg_per_slice.mean(1) + return int(np.argmax(fg_per_slice)) + + +def plot_overlay(image_file: str, segmentation_file: str, image_reader_writer: BaseReaderWriter, output_file: str, + overlay_intensity: float = 0.6): + import matplotlib.pyplot as plt + + image, props = image_reader_writer.read_images((image_file, )) + image = image[0] + seg, props_seg = image_reader_writer.read_seg(segmentation_file) + seg = seg[0] + + assert all([i == j for i, j in zip(image.shape, seg.shape)]), "image and seg do not have the same shape: %s, %s" % ( + image_file, segmentation_file) + + assert len(image.shape) == 3, 'only 3D images/segs are supported' + + selected_slice = select_slice_to_plot2(image, seg) + # print(image.shape, selected_slice) + + overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) + + plt.imsave(output_file, overlay) + + +def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, channel_idx=0): + import matplotlib.pyplot as plt + data = np.load(case_file)['data'] + seg = np.load(case_file)['seg'][0] + + assert channel_idx < (data.shape[0]), 'This dataset only supports channel index up to %d' % (data.shape[0] - 1) + + image = data[channel_idx] + seg[seg < 0] = 0 + + selected_slice = select_slice_to_plot2(image, seg) + + overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity) + + plt.imsave(output_file, overlay) + + +def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, image_reader_writer, + list_of_output_files, overlay_intensity, + num_processes=8): + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + r = p.starmap_async(plot_overlay, zip( + list_of_image_files, list_of_seg_files, [image_reader_writer] * len(list_of_output_files), + list_of_output_files, [overlay_intensity] * len(list_of_output_files) + )) + r.get() + + +def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity, + num_processes=8, channel_idx=0): + with multiprocessing.get_context("spawn").Pool(num_processes) as p: + r = p.starmap_async(plot_overlay_preprocessed, zip( + list_of_case_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files), + [channel_idx] * len(list_of_output_files) + )) + r.get() + + +def generate_overlays_from_raw(dataset_name_or_id: Union[int, str], output_folder: str, + num_processes: int = 8, channel_idx: int = 0, overlay_intensity: float = 0.6): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + folder = join(nnUNet_raw, dataset_name) + dataset_json = load_json(join(folder, 'dataset.json')) + dataset = get_filenames_of_train_images_and_targets(folder, dataset_json) + + image_files = [v['images'][channel_idx] for v in dataset.values()] + seg_files = [v['label'] for v in dataset.values()] + + assert all([isfile(i) for i in image_files]) + assert all([isfile(i) for i in seg_files]) + + maybe_mkdir_p(output_folder) + output_files = [join(output_folder, i + '.png') for i in dataset.keys()] + + image_reader_writer = 
determine_reader_writer_from_dataset_json(dataset_json, image_files[0])() + multiprocessing_plot_overlay(image_files, seg_files, image_reader_writer, output_files, overlay_intensity, num_processes) + + +def generate_overlays_from_preprocessed(dataset_name_or_id: Union[int, str], output_folder: str, + num_processes: int = 8, channel_idx: int = 0, + configuration: str = None, + plans_identifier: str = 'nnUNetPlans', + overlay_intensity: float = 0.6): + dataset_name = maybe_convert_to_dataset_name(dataset_name_or_id) + folder = join(nnUNet_preprocessed, dataset_name) + if not isdir(folder): raise RuntimeError("run preprocessing for that task first") + + plans = load_json(join(folder, plans_identifier + '.json')) + if configuration is None: + if '3d_fullres' in plans['configurations'].keys(): + configuration = '3d_fullres' + else: + configuration = '2d' + data_identifier = plans['configurations'][configuration]["data_identifier"] + preprocessed_folder = join(folder, data_identifier) + + if not isdir(preprocessed_folder): + raise RuntimeError(f"Preprocessed data folder for configuration {configuration} of plans identifier " + f"{plans_identifier} ({dataset_name}) does not exist. Run preprocessing for this " + f"configuration first!") + + identifiers = [i[:-4] for i in subfiles(preprocessed_folder, suffix='.npz', join=False)] + + output_files = [join(output_folder, i + '.png') for i in identifiers] + image_files = [join(preprocessed_folder, i + ".npz") for i in identifiers] + + maybe_mkdir_p(output_folder) + multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=overlay_intensity, + num_processes=num_processes, channel_idx=channel_idx) + + +def entry_point_generate_overlay(): + import argparse + parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this " + "disregards spacing information!") + parser.add_argument('-d', type=str, help="Dataset name or id", required=True) + parser.add_argument('-o', type=str, help="output folder", required=True) + parser.add_argument('-np', type=int, default=default_num_processes, required=False, + help=f"number of processes used. Default: {default_num_processes}") + parser.add_argument('-channel_idx', type=int, default=0, required=False, + help="channel index used (0 = _0000). Default: 0") + parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else " + "we use preprocessed") + parser.add_argument('-p', type=str, required=False, default='nnUNetPlans', + help='plans identifier. Only used if --use_raw is not set! Default: nnUNetPlans') + parser.add_argument('-c', type=str, required=False, default=None, + help='configuration name. Only used if --use_raw is not set! Default: None = ' + '3d_fullres if available, else 2d') + parser.add_argument('-overlay_intensity', type=float, required=False, default=0.6, + help='overlay intensity. 
Higher = brighter/less transparent') + + + args = parser.parse_args() + + if args.use_raw: + generate_overlays_from_raw(args.d, args.o, args.np, args.channel_idx, + overlay_intensity=args.overlay_intensity) + else: + generate_overlays_from_preprocessed(args.d, args.o, args.np, args.channel_idx, args.c, args.p, + overlay_intensity=args.overlay_intensity) + + +if __name__ == '__main__': + entry_point_generate_overlay() \ No newline at end of file diff --git a/nnUNet/nnunetv2/utilities/plans_handling/__init__.py b/nnUNet/nnunetv2/utilities/plans_handling/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py b/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py new file mode 100644 index 0000000..6c39fd1 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +import dynamic_network_architectures +from copy import deepcopy +from functools import lru_cache, partial +from typing import Union, Tuple, List, Type, Callable + +import numpy as np +import torch + +from nnunetv2.preprocessing.resampling.utils import recursive_find_resampling_fn_by_name +from torch import nn + +import nnunetv2 +from batchgenerators.utilities.file_and_folder_operations import load_json, join + +from nnunetv2.imageio.reader_writer_registry import recursive_find_reader_writer_by_name +from nnunetv2.utilities.find_class_by_name import recursive_find_python_class +from nnunetv2.utilities.label_handling.label_handling import get_labelmanager_class_from_plans + + +# see https://adamj.eu/tech/2021/05/13/python-type-hints-how-to-fix-circular-imports/ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from nnunetv2.utilities.label_handling.label_handling import LabelManager + from nnunetv2.imageio.base_reader_writer import BaseReaderWriter + from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor + from nnunetv2.experiment_planning.experiment_planners.default_experiment_planner import ExperimentPlanner + + +class ConfigurationManager(object): + def __init__(self, configuration_dict: dict): + self.configuration = configuration_dict + + def __repr__(self): + return self.configuration.__repr__() + + @property + def data_identifier(self) -> str: + return self.configuration['data_identifier'] + + @property + def preprocessor_name(self) -> str: + return self.configuration['preprocessor_name'] + + @property + @lru_cache(maxsize=1) + def preprocessor_class(self) -> Type[DefaultPreprocessor]: + preprocessor_class = recursive_find_python_class(join(nnunetv2.__path__[0], "preprocessing"), + self.preprocessor_name, + current_module="nnunetv2.preprocessing") + return preprocessor_class + + @property + def batch_size(self) -> int: + return self.configuration['batch_size'] + + @property + def patch_size(self) -> List[int]: + return self.configuration['patch_size'] + + @property + def median_image_size_in_voxels(self) -> List[int]: + return self.configuration['median_image_size_in_voxels'] + + @property + def spacing(self) -> List[float]: + return self.configuration['spacing'] + + @property + def normalization_schemes(self) -> List[str]: + return self.configuration['normalization_schemes'] + + @property + def use_mask_for_norm(self) -> List[bool]: + return self.configuration['use_mask_for_norm'] + + @property + def UNet_class_name(self) -> str: + return self.configuration['UNet_class_name'] + + @property + @lru_cache(maxsize=1) + def UNet_class(self) 
-> Type[nn.Module]: + unet_class = recursive_find_python_class(join(dynamic_network_architectures.__path__[0], "architectures"), + self.UNet_class_name, + current_module="dynamic_network_architectures.architectures") + if unet_class is None: + raise RuntimeError('The network architecture specified by the plans file ' + 'is non-standard (maybe your own?). Fix this by not using ' + 'ConfigurationManager.UNet_class to instantiate ' + 'it (probably just overwrite build_network_architecture of your trainer.') + return unet_class + + @property + def UNet_base_num_features(self) -> int: + return self.configuration['UNet_base_num_features'] + + @property + def n_conv_per_stage_encoder(self) -> List[int]: + return self.configuration['n_conv_per_stage_encoder'] + + @property + def n_conv_per_stage_decoder(self) -> List[int]: + return self.configuration['n_conv_per_stage_decoder'] + + @property + def num_pool_per_axis(self) -> List[int]: + return self.configuration['num_pool_per_axis'] + + @property + def pool_op_kernel_sizes(self) -> List[List[int]]: + return self.configuration['pool_op_kernel_sizes'] + + @property + def conv_kernel_sizes(self) -> List[List[int]]: + return self.configuration['conv_kernel_sizes'] + + @property + def unet_max_num_features(self) -> int: + return self.configuration['unet_max_num_features'] + + @property + @lru_cache(maxsize=1) + def resampling_fn_data(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_data']) + fn = partial(fn, **self.configuration['resampling_fn_data_kwargs']) + return fn + + @property + @lru_cache(maxsize=1) + def resampling_fn_probabilities(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_probabilities']) + fn = partial(fn, **self.configuration['resampling_fn_probabilities_kwargs']) + return fn + + @property + @lru_cache(maxsize=1) + def resampling_fn_seg(self) -> Callable[ + [Union[torch.Tensor, np.ndarray], + Union[Tuple[int, ...], List[int], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray], + Union[Tuple[float, ...], List[float], np.ndarray] + ], + Union[torch.Tensor, np.ndarray]]: + fn = recursive_find_resampling_fn_by_name(self.configuration['resampling_fn_seg']) + fn = partial(fn, **self.configuration['resampling_fn_seg_kwargs']) + return fn + + @property + def batch_dice(self) -> bool: + return self.configuration['batch_dice'] + + @property + def next_stage_names(self) -> Union[List[str], None]: + ret = self.configuration.get('next_stage') + if ret is not None: + if isinstance(ret, str): + ret = [ret] + return ret + + @property + def previous_stage_name(self) -> Union[str, None]: + return self.configuration.get('previous_stage') + + +class PlansManager(object): + def __init__(self, plans_file_or_dict: Union[str, dict]): + """ + Why do we need this? 
+ 1) resolve inheritance in configurations + 2) expose otherwise annoying stuff like getting the label manager or IO class from a string + 3) clearly expose the things that are in the plans instead of hiding them in a dict + 4) cache shit + + This class does not prevent you from going wild. You can still use the plans directly if you prefer + (PlansHandler.plans['key']) + """ + self.plans = plans_file_or_dict if isinstance(plans_file_or_dict, dict) else load_json(plans_file_or_dict) + + def __repr__(self): + return self.plans.__repr__() + + def _internal_resolve_configuration_inheritance(self, configuration_name: str, + visited: Tuple[str, ...] = None) -> dict: + if configuration_name not in self.plans['configurations'].keys(): + raise ValueError(f'The configuration {configuration_name} does not exist in the plans I have. Valid ' + f'configuration names are {list(self.plans["configurations"].keys())}.') + configuration = deepcopy(self.plans['configurations'][configuration_name]) + if 'inherits_from' in configuration: + parent_config_name = configuration['inherits_from'] + + if visited is None: + visited = (configuration_name,) + else: + if parent_config_name in visited: + raise RuntimeError(f"Circular dependency detected. The following configurations were visited " + f"while solving inheritance (in that order!): {visited}. " + f"Current configuration: {configuration_name}. Its parent configuration " + f"is {parent_config_name}.") + visited = (*visited, configuration_name) + + base_config = self._internal_resolve_configuration_inheritance(parent_config_name, visited) + base_config.update(configuration) + configuration = base_config + return configuration + + @lru_cache(maxsize=10) + def get_configuration(self, configuration_name: str): + if configuration_name not in self.plans['configurations'].keys(): + raise RuntimeError(f"Requested configuration {configuration_name} not found in plans. 
" + f"Available configurations: {list(self.plans['configurations'].keys())}") + + configuration_dict = self._internal_resolve_configuration_inheritance(configuration_name) + return ConfigurationManager(configuration_dict) + + @property + def dataset_name(self) -> str: + return self.plans['dataset_name'] + + @property + def plans_name(self) -> str: + return self.plans['plans_name'] + + @property + def original_median_spacing_after_transp(self) -> List[float]: + return self.plans['original_median_spacing_after_transp'] + + @property + def original_median_shape_after_transp(self) -> List[float]: + return self.plans['original_median_shape_after_transp'] + + @property + @lru_cache(maxsize=1) + def image_reader_writer_class(self) -> Type[BaseReaderWriter]: + return recursive_find_reader_writer_by_name(self.plans['image_reader_writer']) + + @property + def transpose_forward(self) -> List[int]: + return self.plans['transpose_forward'] + + @property + def transpose_backward(self) -> List[int]: + return self.plans['transpose_backward'] + + @property + def available_configurations(self) -> List[str]: + return list(self.plans['configurations'].keys()) + + @property + @lru_cache(maxsize=1) + def experiment_planner_class(self) -> Type[ExperimentPlanner]: + planner_name = self.experiment_planner_name + experiment_planner = recursive_find_python_class(join(nnunetv2.__path__[0], "experiment_planning"), + planner_name, + current_module="nnunetv2.experiment_planning") + return experiment_planner + + @property + def experiment_planner_name(self) -> str: + return self.plans['experiment_planner_used'] + + @property + @lru_cache(maxsize=1) + def label_manager_class(self) -> Type[LabelManager]: + return get_labelmanager_class_from_plans(self.plans) + + def get_label_manager(self, dataset_json: dict, **kwargs) -> LabelManager: + return self.label_manager_class(label_dict=dataset_json['labels'], + regions_class_order=dataset_json.get('regions_class_order'), + **kwargs) + + @property + def foreground_intensity_properties_per_channel(self) -> dict: + if 'foreground_intensity_properties_per_channel' not in self.plans.keys(): + if 'foreground_intensity_properties_by_modality' in self.plans.keys(): + return self.plans['foreground_intensity_properties_by_modality'] + return self.plans['foreground_intensity_properties_per_channel'] + + +if __name__ == '__main__': + from nnunetv2.paths import nnUNet_preprocessed + from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name + + plans = load_json(join(nnUNet_preprocessed, maybe_convert_to_dataset_name(3), 'nnUNetPlans.json')) + # build new configuration that inherits from 3d_fullres + plans['configurations']['3d_fullres_bs4'] = { + 'batch_size': 4, + 'inherits_from': '3d_fullres' + } + # now get plans and configuration managers + plans_manager = PlansManager(plans) + configuration_manager = plans_manager.get_configuration('3d_fullres_bs4') + print(configuration_manager) # look for batch size 4 diff --git a/nnUNet/nnunetv2/utilities/tensor_utilities.py b/nnUNet/nnunetv2/utilities/tensor_utilities.py new file mode 100644 index 0000000..b16ffca --- /dev/null +++ b/nnUNet/nnunetv2/utilities/tensor_utilities.py @@ -0,0 +1,15 @@ +from typing import Union, List, Tuple + +import numpy as np +import torch + + +def sum_tensor(inp: torch.Tensor, axes: Union[np.ndarray, Tuple, List], keepdim: bool = False) -> torch.Tensor: + axes = np.unique(axes).astype(int) + if keepdim: + for ax in axes: + inp = inp.sum(int(ax), keepdim=True) + else: + for ax in 
sorted(axes, reverse=True): + inp = inp.sum(int(ax)) + return inp diff --git a/nnUNet/nnunetv2/utilities/utils.py b/nnUNet/nnunetv2/utilities/utils.py new file mode 100644 index 0000000..b0c16a2 --- /dev/null +++ b/nnUNet/nnunetv2/utilities/utils.py @@ -0,0 +1,69 @@ +# Copyright 2021 HIP Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center +# (DKFZ), Heidelberg, Germany +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os.path +from functools import lru_cache +from typing import Union + +from batchgenerators.utilities.file_and_folder_operations import * +import numpy as np +import re + +from nnunetv2.paths import nnUNet_raw + + +def get_identifiers_from_splitted_dataset_folder(folder: str, file_ending: str): + files = subfiles(folder, suffix=file_ending, join=False) + # all files have a 4 digit channel index (_XXXX) + crop = len(file_ending) + 5 + files = [i[:-crop] for i in files] + # only unique image ids + files = np.unique(files) + return files + + +def create_lists_from_splitted_dataset_folder(folder: str, file_ending: str, identifiers: List[str] = None) -> List[ + List[str]]: + """ + does not rely on dataset.json + """ + if identifiers is None: + identifiers = get_identifiers_from_splitted_dataset_folder(folder, file_ending) + files = subfiles(folder, suffix=file_ending, join=False, sort=True) + list_of_lists = [] + for f in identifiers: + p = re.compile(re.escape(f) + r"_\d\d\d\d" + re.escape(file_ending)) + list_of_lists.append([join(folder, i) for i in files if p.fullmatch(i)]) + return list_of_lists + + +def get_filenames_of_train_images_and_targets(raw_dataset_folder: str, dataset_json: dict = None): + if dataset_json is None: + dataset_json = load_json(join(raw_dataset_folder, 'dataset.json')) + + if 'dataset' in dataset_json.keys(): + dataset = dataset_json['dataset'] + for k in dataset.keys(): + dataset[k]['label'] = os.path.abspath(join(raw_dataset_folder, dataset[k]['label'])) if not os.path.isabs(dataset[k]['label']) else dataset[k]['label'] + dataset[k]['images'] = [os.path.abspath(join(raw_dataset_folder, i)) if not os.path.isabs(i) else i for i in dataset[k]['images']] + else: + identifiers = get_identifiers_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending']) + images = create_lists_from_splitted_dataset_folder(join(raw_dataset_folder, 'imagesTr'), dataset_json['file_ending'], identifiers) + segs = [join(raw_dataset_folder, 'labelsTr', i + dataset_json['file_ending']) for i in identifiers] + dataset = {i: {'images': im, 'label': se} for i, im, se in zip(identifiers, images, segs)} + return dataset + + +if __name__ == '__main__': + print(get_filenames_of_train_images_and_targets(join(nnUNet_raw, 'Dataset002_Heart'))) diff --git a/nnUNet/readme.md b/nnUNet/readme.md new file mode 100644 index 0000000..1b6ba5f --- /dev/null +++ b/nnUNet/readme.md @@ -0,0 +1,133 @@ +# Welcome to the new nnU-Net! 
+ +Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead. + +Coming from V1? Check out the [TLDR Migration Guide](documentation/tldr_migration_guide_from_v1.md). Reading the rest of the documentation is still strongly recommended ;-) + +# What is nnU-Net? +Image datasets are enormously diverse: image dimensionality (2D, 3D), modalities/input channels (RGB image, CT, MRI, microscopy, ...), +image sizes, voxel sizes, class ratio, target structure properties and more change substantially between datasets. +Traditionally, given a new problem, a tailored solution needs to be manually designed and optimized - a process that +is prone to errors, not scalable and where success is overwhelmingly determined by the skill of the experimenter. Even +for experts, this process is anything but simple: there are not only many design choices and data properties that need to +be considered, but they are also tightly interconnected, rendering reliable manual pipeline optimization all but impossible! + +![nnU-Net overview](documentation/assets/nnU-Net_overview.png) + +**nnU-Net is a semantic segmentation method that automatically adapts to a given dataset. It will analyze the provided +training cases and automatically configure a matching U-Net-based segmentation pipeline. No expertise required on your +end! You can simply train the models and use them for your application**. + +Upon release, nnU-Net was evaluated on 23 datasets belonging to competitions from the biomedical domain. Despite competing +with handcrafted solutions for each respective dataset, nnU-Net's fully automated pipeline scored several first places on +open leaderboards! Since then nnU-Net has stood the test of time: it continues to be used as a baseline and method +development framework ([9 out of 10 challenge winners at MICCAI 2020](https://arxiv.org/abs/2101.00232) and 5 out of 7 +in MICCAI 2021 built their methods on top of nnU-Net, + [we won AMOS2022 with nnU-Net](https://amos22.grand-challenge.org/final-ranking/))! + +Please cite the [following paper](https://www.google.com/url?q=https://www.nature.com/articles/s41592-020-01008-z&sa=D&source=docs&ust=1677235958581755&usg=AOvVaw3dWL0SrITLhCJUBiNIHCQO) when using nnU-Net: + + Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring + method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211. + + +## What can nnU-Net do for you? +If you are a **domain scientist** (biologist, radiologist, ...) looking to analyze your own images, nnU-Net provides +an out-of-the-box solution that is all but guaranteed to provide excellent results on your individual dataset. Simply +convert your dataset into the nnU-Net format and enjoy the power of AI - no expertise required! + +If you are an **AI researcher** developing segmentation methods, nnU-Net: +- offers a fantastic out-of-the-box applicable baseline algorithm to compete against +- can act as a method development framework to test your contribution on a large number of datasets without having to +tune individual pipelines (for example evaluating a new loss function) +- provides a strong starting point for further dataset-specific optimizations. This is particularly used when competing +in segmentation challenges +- provides a new perspective on the design of segmentation methods: maybe you can find better connections between +dataset properties and best-fitting segmentation pipelines? 
+ +## What is the scope of nnU-Net? +nnU-Net is built for semantic segmentation. It can handle 2D and 3D images with arbitrary +input modalities/channels. It can understand voxel spacings, anisotropies and is robust even when classes are highly +imbalanced. + +nnU-Net relies on supervised learning, which means that you need to provide training cases for your application. The number of +required training cases varies heavily depending on the complexity of the segmentation problem. No +one-fits-all number can be provided here! nnU-Net does not require more training cases than other solutions - maybe +even less due to our extensive use of data augmentation. + +nnU-Net expects to be able to process entire images at once during preprocessing and postprocessing, so it cannot +handle enormous images. As a reference: we tested images from 40x40x40 pixels all the way up to 1500x1500x1500 in 3D +and 40x40 up to ~30000x30000 in 2D! If your RAM allows it, larger is always possible. + +## How does nnU-Net work? +Given a new dataset, nnU-Net will systematically analyze the provided training cases and create a 'dataset fingerprint'. +nnU-Net then creates several U-Net configurations for each dataset: +- `2d`: a 2D U-Net (for 2D and 3D datasets) +- `3d_fullres`: a 3D U-Net that operates on a high image resolution (for 3D datasets only) +- `3d_lowres` → `3d_cascade_fullres`: a 3D U-Net cascade where first a 3D U-Net operates on low resolution images and +then a second high-resolution 3D U-Net refined the predictions of the former (for 3D datasets with large image sizes only) + +**Note that not all U-Net configurations are created for all datasets. In datasets with small image sizes, the +U-Net cascade (and with it the 3d_lowres configuration) is omitted because the patch size of the full +resolution U-Net already covers a large part of the input images.** + +nnU-Net configures its segmentation pipelines based on a three-step recipe: +- **Fixed parameters** are not adapted. During development of nnU-Net we identified a robust configuration (that is, certain architecture and training properties) that can +simply be used all the time. This includes, for example, nnU-Net's loss function, (most of the) data augmentation strategy and learning rate. +- **Rule-based parameters** use the dataset fingerprint to adapt certain segmentation pipeline properties by following +hard-coded heuristic rules. For example, the network topology (pooling behavior and depth of the network architecture) +are adapted to the patch size; the patch size, network topology and batch size are optimized jointly given some GPU +memory constraint. +- **Empirical parameters** are essentially trial-and-error. For example the selection of the best U-net configuration +for the given dataset (2D, 3D full resolution, 3D low resolution, 3D cascade) and the optimization of the postprocessing strategy. + +## How to get started? 
+Read these: +- [Installation instructions](documentation/installation_instructions.md) +- [Dataset conversion](documentation/dataset_format.md) +- [Usage instructions](documentation/how_to_use_nnunet.md) + +Additional information: +- [Region-based training](documentation/region_based_training.md) +- [Manual data splits](documentation/manual_data_splits.md) +- [Pretraining and finetuning](documentation/pretraining_and_finetuning.md) +- [Intensity Normalization in nnU-Net](documentation/explanation_normalization.md) +- [Manually editing nnU-Net configurations](documentation/explanation_plans_files.md) +- [Extending nnU-Net](documentation/extending_nnunet.md) +- [What is different in V2?](documentation/changelog.md) + +[//]: # (- [Ignore label](documentation/ignore_label.md)) + +## Where does nnU-net perform well and where does it not perform? +nnU-Net excels in segmentation problems that need to be solved by training from scratch, +for example: research applications that feature non-standard image modalities and input channels, +challenge datasets from the biomedical domain, majority of 3D segmentation problems, etc . We have yet to find a +dataset for which nnU-Net's working principle fails! + +Note: On standard segmentation +problems, such as 2D RGB images in ADE20k and Cityscapes, fine-tuning a foundation model (that was pretrained on a large corpus of +similar images, e.g. Imagenet 22k, JFT-300M) will provide better performance than nnU-Net! That is simply because these +models allow much better initialization. Foundation models are not supported by nnU-Net as +they 1) are not useful for segmentation problems that deviate from the standard setting (see above mentioned +datasets), 2) would typically only support 2D architectures and 3) conflict with our core design principle of carefully adapting +the network topology for each dataset (if the topology is changed one can no longer transfer pretrained weights!) + +## What happened to the old nnU-Net? +The core of the old nnU-Net was hacked together in a short time period while participating in the Medical Segmentation +Decathlon challenge in 2018. Consequently, code structure and quality were not the best. Many features +were added later on and didn't quite fit into the nnU-Net design principles. Overall quite messy, really. And annoying to work with. + +nnU-Net V2 is a complete overhaul. The "delete everything and start again" kind. So everything is better +(in the author's opinion haha). While the segmentation performance [remains the same](https://docs.google.com/spreadsheets/d/13gqjIKEMPFPyMMMwA1EML57IyoBjfC3-QCTn4zRN_Mg/edit?usp=sharing), a lot of cool stuff has been added. +It is now also much easier to use it as a development framework and to manually fine-tune its configuration to new +datasets. A big driver for the reimplementation was also the emergence of [Helmholtz Imaging](http://helmholtz-imaging.de), +prompting us to extend nnU-Net to more image formats and domains. Take a look [here](documentation/changelog.md) for some highlights. + +# Acknowledgements + + + + +nnU-Net is developed and maintained by the Applied Computer Vision Lab (ACVL) of [Helmholtz Imaging](http://helmholtz-imaging.de) +and the [Division of Medical Image Computing](https://www.dkfz.de/en/mic/index.php) at the +[German Cancer Research Center (DKFZ)](https://www.dkfz.de/en/index.html). 
diff --git a/nnUNet/setup.cfg b/nnUNet/setup.cfg new file mode 100644 index 0000000..002a15d --- /dev/null +++ b/nnUNet/setup.cfg @@ -0,0 +1,2 @@ +[metadata] +description-file = readme.md \ No newline at end of file diff --git a/nnUNet/setup.py b/nnUNet/setup.py new file mode 100755 index 0000000..5b9271f --- /dev/null +++ b/nnUNet/setup.py @@ -0,0 +1,62 @@ +from setuptools import setup, find_namespace_packages + +setup(name='nnunetv2', + packages=find_namespace_packages(include=["nnunetv2", "nnunetv2.*"]), + version='2.1.1', + description='nnU-Net. Framework for out-of-the box biomedical image segmentation.', + url='https://github.com/MIC-DKFZ/nnUNet', + author='Helmholtz Imaging Applied Computer Vision Lab, Division of Medical Image Computing, German Cancer Research Center', + author_email='f.isensee@dkfz-heidelberg.de', + license='Apache License Version 2.0, January 2004', + python_requires=">=3.9", + install_requires=[ + "torch>=2.0.0", + "acvl-utils>=0.2", + "dynamic-network-architectures>=0.2", + "tqdm", + "dicom2nifti", + "scikit-image>=0.14", + "scipy", + "batchgenerators>=0.25", + "numpy", + "scikit-learn", + "scikit-image>=0.19.3", + "SimpleITK>=2.2.1", + "pandas", + "graphviz", + 'tifffile', + 'requests', + "nibabel", + "matplotlib", + "seaborn", + "imagecodecs", + "yacs" + ], + entry_points={ + 'console_scripts': [ + 'nnUNetv2_plan_and_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry', # api available + 'nnUNetv2_extract_fingerprint = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry', # api available + 'nnUNetv2_plan_experiment = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry', # api available + 'nnUNetv2_preprocess = nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry', # api available + 'nnUNetv2_train = nnunetv2.run.run_training:run_training_entry', # api available + 'nnUNetv2_predict_from_modelfolder = nnunetv2.inference.predict_from_raw_data:predict_entry_point_modelfolder', # api available + 'nnUNetv2_predict = nnunetv2.inference.predict_from_raw_data:predict_entry_point', # api available + 'nnUNetv2_convert_old_nnUNet_dataset = nnunetv2.dataset_conversion.convert_raw_dataset_from_old_nnunet_format:convert_entry_point', # api available + 'nnUNetv2_find_best_configuration = nnunetv2.evaluation.find_best_configuration:find_best_configuration_entry_point', # api available + 'nnUNetv2_determine_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_determine_postprocessing_folder', # api available + 'nnUNetv2_apply_postprocessing = nnunetv2.postprocessing.remove_connected_components:entry_point_apply_postprocessing', # api available + 'nnUNetv2_ensemble = nnunetv2.ensembling.ensemble:entry_point_ensemble_folders', # api available + 'nnUNetv2_accumulate_crossval_results = nnunetv2.evaluation.find_best_configuration:accumulate_crossval_results_entry_point', # api available + 'nnUNetv2_plot_overlay_pngs = nnunetv2.utilities.overlay_plots:entry_point_generate_overlay', # api available + 'nnUNetv2_download_pretrained_model_by_url = nnunetv2.model_sharing.entry_points:download_by_url', # api available + 'nnUNetv2_install_pretrained_model_from_zip = nnunetv2.model_sharing.entry_points:install_from_zip_entry_point', # api available + 'nnUNetv2_export_model_to_zip = nnunetv2.model_sharing.entry_points:export_pretrained_model_entry', # api available + 'nnUNetv2_move_plans_between_datasets = 
nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets', # api available + 'nnUNetv2_evaluate_folder = nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point', # api available + 'nnUNetv2_evaluate_simple = nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point', # api available + 'nnUNetv2_convert_MSD_dataset = nnunetv2.dataset_conversion.convert_MSD_dataset:entry_point' # api available + ], + }, + keywords=['deep learning', 'image segmentation', 'medical image analysis', + 'medical image segmentation', 'nnU-Net', 'nnunet'] + ) diff --git a/process.py b/process.py new file mode 100644 index 0000000..bbabfdd --- /dev/null +++ b/process.py @@ -0,0 +1,158 @@ +import time +import SimpleITK as sitk +import numpy as np +np.lib.index_tricks.int = np.uint16 +import ants +from os.path import join +from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor +import json +from custom_algorithm import Hanseg2023Algorithm + +LABEL_dict = { + "background": 0, + "A_Carotid_L": 1, + "A_Carotid_R": 2, + "Arytenoid": 3, + "Bone_Mandible": 4, + "Brainstem": 5, + "BuccalMucosa": 6, + "Cavity_Oral": 7, + "Cochlea_L": 8, + "Cochlea_R": 9, + "Cricopharyngeus": 10, + "Esophagus_S": 11, + "Eye_AL": 12, + "Eye_AR": 13, + "Eye_PL": 14, + "Eye_PR": 15, + "Glnd_Lacrimal_L": 16, + "Glnd_Lacrimal_R": 17, + "Glnd_Submand_L": 18, + "Glnd_Submand_R": 19, + "Glnd_Thyroid": 20, + "Glottis": 21, + "Larynx_SG": 22, + "Lips": 23, + "OpticChiasm": 24, + "OpticNrv_L": 25, + "OpticNrv_R": 26, + "Parotid_L": 27, + "Parotid_R": 28, + "Pituitary": 29, + "SpinalCord": 30, +} + + +def ants_2_itk(image): + imageITK = sitk.GetImageFromArray(image.numpy().T) + imageITK.SetOrigin(image.origin) + imageITK.SetSpacing(image.spacing) + imageITK.SetDirection(image.direction.reshape(9)) + return imageITK + +def itk_2_ants(image): + image_ants = ants.from_numpy(sitk.GetArrayFromImage(image).T, + origin=image.GetOrigin(), + spacing=image.GetSpacing(), + direction=np.array(image.GetDirection()).reshape(3, 3)) + return image_ants + + +class MyHanseg2023Algorithm(Hanseg2023Algorithm): + def __init__(self): + super().__init__() + + def predict(self, *, image_ct: ants.ANTsImage, image_mrt1: ants.ANTsImage) -> sitk.Image: + print("Computing registration", flush=True) + time0reg= time.time_ns() + mytx = ants.registration(fixed=image_ct, moving=image_mrt1, type_of_transform='Affine') #, aff_iterations=(150, 150, 150, 150)) + print(f"Time reg: {(time.time_ns()-time0reg)/1000000000}") + warped_MR = ants.apply_transforms(fixed=image_ct, moving=image_mrt1, + transformlist=mytx['fwdtransforms'], defaultvalue=image_mrt1.min()) + trained_model_path = join("/opt", "algorithm", "checkpoint", "nnUNet", "Dataset777_HaNSeg2023", "nnUNetTrainer__nnUNetPlans__3d_fullres") + # trained_model_path = join("/usr/DATA/backup_home_dir/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/src/HanSeg_2023/nnUNet/dataset/nnUNet_results", + # "Dataset777_HaNSeg2023", "nnUNetTrainer__nnUNetPlans__3d_fullres") + + spacing = tuple(map(float,json.load(open(join(trained_model_path, "plans.json"), "r"))["configurations"]["3d_fullres"]["spacing"])) + ct_image = ants_2_itk(image_ct) + mr_image = ants_2_itk(warped_MR) + del image_mrt1 + del warped_MR + + + properties = { + 'sitk_stuff': + {'spacing': ct_image.GetSpacing(), + 'origin': ct_image.GetOrigin(), + 'direction': ct_image.GetDirection() + }, + # the spacing is inverted with [::-1] because sitk returns the spacing in the wrong order 
lol. Image arrays + # are returned x,y,z but spacing is returned z,y,x. Duh. + 'spacing': ct_image.GetSpacing()[::-1] + } + images = np.vstack([sitk.GetArrayFromImage(ct_image)[None], sitk.GetArrayFromImage(mr_image)[None]]).astype(np.float32) + fin_origin = ct_image.GetOrigin() + fin_spacing = ct_image.GetSpacing() + fin_direction = ct_image.GetDirection() + fin_size = ct_image.GetSize() + print(fin_spacing) + print(spacing) + print(fin_size) + + old_shape = np.shape(sitk.GetArrayFromImage(ct_image)) + del mr_image + del ct_image + # Shamelessly copied from nnUNet/nnunetv2/preprocessing/resampling/default_resampling.py + new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(fin_spacing, spacing[::-1], fin_size)]) + if new_shape.prod()< 1e8: + print(f"Image is not too large ({new_shape.prod()}), using the folds (0,1,2,3,4) with mirror") + predictor = nnUNetPredictor(tile_step_size=0.4, use_mirroring=True, perform_everything_on_gpu=True, + verbose=True, verbose_preprocessing=True, + allow_tqdm=True) + predictor.initialize_from_trained_model_folder(trained_model_path, use_folds=(0,1,2,3), + checkpoint_name="checkpoint_best.pth") + # predictor.allowed_mirroring_axes = (0, 2) + elif new_shape.prod()< 1.3e8: + print(f"Image is not too large ({new_shape.prod()}), using the folds (0,1,2,3,4)") + + predictor = nnUNetPredictor(tile_step_size=0.6, use_mirroring=True, perform_everything_on_gpu=False, + verbose=True, verbose_preprocessing=True, + allow_tqdm=True) + predictor.initialize_from_trained_model_folder(trained_model_path, use_folds=(0,1,2,3), #(0,1,2,3,4) + checkpoint_name="checkpoint_best.pth") + elif new_shape.prod()< 1.7e8: + print(f"Image is not too large ({new_shape.prod()}), using the 'all' fold with mirror") + + predictor = nnUNetPredictor(tile_step_size=0.4, use_mirroring=True, perform_everything_on_gpu=False, + verbose=True, verbose_preprocessing=True, + allow_tqdm=True) + predictor.initialize_from_trained_model_folder(trained_model_path, use_folds="0", + checkpoint_name="checkpoint_best.pth") + # predictor.allowed_mirroring_axes = (0, 2) + + else: + predictor = nnUNetPredictor(tile_step_size=0.6, use_mirroring=True, perform_everything_on_gpu=False, + verbose=True, verbose_preprocessing=True, + allow_tqdm=True) + print(f"Image is too large ({new_shape.prod()}), using the 'all' fold") + predictor.initialize_from_trained_model_folder(trained_model_path, use_folds="0", + checkpoint_name="checkpoint_best.pth") + + img_temp = predictor.predict_single_npy_array(images, properties, None, None, False).astype(np.uint8) + del images + print("Prediction Done", flush=True) + output_seg = sitk.GetImageFromArray(img_temp) + print(f"Seg: {output_seg.GetSize()}, CT: {fin_size}") + # output_seg.CopyInformation(ct_image) + output_seg.SetOrigin(fin_origin) + output_seg.SetSpacing(fin_spacing) + output_seg.SetDirection(fin_direction) + print("Got Image", flush=True) + # save the simpleITK image + # sitk.WriteImage(output_seg, str("output_seg.seg.nrrd"), True) + return output_seg + +if __name__ == "__main__": + time0 = time.time_ns() + MyHanseg2023Algorithm().process() + print((time.time_ns()-time0)/1000000000) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1e174c7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,104 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile --resolver=backtracking +# +antspyx==0.4.2 +arrow==1.2.3 + # via jinja2-time +binaryornot==0.4.4 + # via cookiecutter +build==0.10.0 
+ # via pip-tools +certifi==2022.12.7 + # via requests +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via + # cookiecutter + # evalutils + # pip-tools +cookiecutter==2.1.1 + # via evalutils +evalutils==0.4.0 + # via -r requirements.in +idna==3.4 + # via requests +imageio[tifffile]==2.26.0 + # via evalutils +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +joblib==1.2.0 + # via scikit-learn +markupsafe==2.1.2 + # via jinja2 +# nnunetv2==2.2 +multiprocess==0.70.14 +numba==0.58.1 +numpy==1.24.2 + # via + # evalutils + # imageio + # pandas + # scikit-learn + # scipy + # tifffile +packaging==23.0 + # via build +pandas==1.5.3 + # via evalutils +pillow==9.4.0 + # via imageio +pip-tools==6.12.3 + # via evalutils +pyproject-hooks==1.0.0 + # via build +python-dateutil==2.8.2 + # via + # arrow + # pandas +python-slugify==8.0.1 + # via cookiecutter +pytz==2022.7.1 + # via pandas +pyyaml==6.0 + # via cookiecutter +requests==2.28.2 + # via cookiecutter +scikit-learn==1.2.2 + # via evalutils +scipy==1.10.1 + # via + # evalutils + # scikit-learn +simpleitk==2.2.1 + # via evalutils +six==1.16.0 + # via python-dateutil +text-unidecode==1.3 + # via python-slugify +threadpoolctl==3.1.0 + # via scikit-learn +tifffile==2023.3.15 + # via imageio +tomli==2.0.1 + # via + # build + # pyproject-hooks +torchio==0.18.86 +urllib3==1.26.15 + # via requests +wheel==0.40.0 + # via pip-tools + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..aa2ece8 --- /dev/null +++ b/test.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +echo $SCRIPTPATH +OUTDIR=$SCRIPTPATH/output + +./build.sh + +# Maximum is currently 30g, configurable in your algorithm image settings on grand challenge +MEM_LIMIT="12g" + +# create output dir if it does not exist +if [ ! -d $OUTDIR ]; then + mkdir $OUTDIR; +fi +echo "starting docker" +# Do not change any of the parameters to docker run, these are fixed +docker run --rm \ + --memory="${MEM_LIMIT}" \ + --memory-swap="${MEM_LIMIT}" \ + --network="none" \ + --cap-drop="ALL" \ + --security-opt="no-new-privileges" \ + --shm-size="128m" \ + --pids-limit="256" \ + --gpus="0" \ + -v $SCRIPTPATH/test/:/input/ \ + -v $SCRIPTPATH/output/:/output \ + hanseg2023algorithm_dmx +echo "docker done" + +echo +echo +echo "Compare files in $OUTDIR with the expected results to see if test is successful" +docker run --rm \ + -v $OUTDIR:/output/ \ + python:3.8-slim ls -al /output/images/head_neck_oar
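test.sh only lists what ended up in /output; a small local sanity check of the produced segmentations can be run along these lines. The `.mha` extension and output directory layout are assumptions; the valid label range 0-30 comes from `LABEL_dict` in process.py:

```python
from pathlib import Path
import numpy as np
import SimpleITK as sitk

out_dir = Path("output/images/head_neck_oar")   # $OUTDIR as mounted by test.sh
for seg_path in sorted(out_dir.glob("*.mha")):  # filename pattern is an assumption
    seg = sitk.GetArrayFromImage(sitk.ReadImage(str(seg_path)))
    labels = np.unique(seg)
    assert labels.min() >= 0 and labels.max() <= 30, f"unexpected labels in {seg_path.name}"
    print(seg_path.name, "shape:", seg.shape, "labels:", labels.tolist())
```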