Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metrics implementation #24

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
29 changes: 29 additions & 0 deletions bayesian_cnn_prometheus/analysis/analyse_lesions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import sys
from pathlib import Path

from bayesian_cnn_prometheus.analysis.lesions_analyzer import LesionsAnalyzer
from bayesian_cnn_prometheus.evaluation.evaluate_model import PredictionOptions
from bayesian_cnn_prometheus.evaluation.utils import load_config


def analyse_lesions():
model_path = sys.argv[1]
config_path = sys.argv[2]
input_path = sys.argv[3]
output_path = sys.argv[4]

# patients_to_analysis = [3, 4, 5, 6, 8, 10, 11, 14, 17, 21, 26, 27, 29, 30, 32, 34, 35, 36, 37, 39, 41]
patients_to_analysis = [3]

model_config = load_config(Path(config_path))
prediction_options = PredictionOptions(
chunk_size=model_config['evaluation']['chunk_size'],
stride=model_config['evaluation']['stride'],
mc_sample=model_config['mc_samples'])

lesions_analyzer = LesionsAnalyzer(model_path, input_path, prediction_options, patients_to_analysis, output_path)
lesions_analyzer.run_analysis(should_postprocess_variance=True)


if __name__ == '__main__':
analyse_lesions()
14 changes: 14 additions & 0 deletions bayesian_cnn_prometheus/analysis/apply_variance_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import cv2
import numpy as np


def apply_variance_postprocess(variance: np.ndarray):
lower_bound = np.percentile(variance, 99)
variance[lower_bound > variance] = 0

variance[np.nonzero(variance)] = 255
variance = variance.astype(np.uint8)
for i in range(variance.shape[0]):
im = cv2.erode(variance[i], np.ones((3, 3)))
variance[i, :, :] = cv2.dilate(im, np.ones((3, 3)))
return variance
157 changes: 157 additions & 0 deletions bayesian_cnn_prometheus/analysis/lesions_analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import glob
import json
import os
from pathlib import Path
from typing import Tuple, List

import nibabel as nib
import numpy as np
from tqdm import tqdm

from bayesian_cnn_prometheus.analysis.apply_variance_postprocess import apply_variance_postprocess
from bayesian_cnn_prometheus.constants import Paths
from bayesian_cnn_prometheus.evaluation.bayesian_model_evaluator import BayesianModelEvaluator
from bayesian_cnn_prometheus.evaluation.evaluate_model import PredictionOptions, crop_image_to_bounding_box_with_lungs
from bayesian_cnn_prometheus.evaluation.utils import get_patient_index, load_nifti_file, load_lungs_mask, save_as_nifti


class LesionsAnalyzer:
def __init__(self, model_path: str, input_path: str, prediction_options: PredictionOptions,
patients_to_analysis: List[int] = None, output_path: str = None):
"""
Creates LesionsAnalyzer instance.
:param model_path: path to the model in h5 form
:param input_path: path to the directory with images, lesions and segmentations labels
:param prediction_options: parameters for prediction
:param patients_to_analysis: indices of patients to perform analysis on them, if None - every patient
in the input dir will be analyzed
:param output_path: path to the dir where output should be saved
"""
self.model_path = model_path
self.input_path = input_path
self.output_path = output_path or os.path.join(input_path, Paths.RESULTS_DIR)
self.prediction_options = prediction_options
self.patients_to_analysis = patients_to_analysis
self.model_evaluator = BayesianModelEvaluator(self.model_path, prediction_options.chunk_size)
self.results = {'chunks_analyse': {'overall': {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}},
'voxels_analyse': {'overall': {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}}}

def run_analysis(self, should_postprocess_variance: bool = False):
for image_path in tqdm(glob.glob(os.path.join(self.input_path, Paths.IMAGES_DIR, '*.nii.gz'))):
patient_idx = int(get_patient_index(image_path))

if self.patients_to_analysis:
if patient_idx not in self.patients_to_analysis:
continue
print(f'Patient id in analysis: {patient_idx}')

segmentation_path = os.path.join(self.input_path, Paths.REFERENCE_SEGMENTATIONS_DIR,
str(Paths.REFERENCE_SEGMENTATION_FILE_PATTERN)
.format(f'{patient_idx:0>4}', 'nii.gz'))
mask_path = os.path.join(self.input_path, Paths.MASKS_DIR,
str(Paths.MASK_FILE_PATTERN).format(f'{patient_idx:0>4}', 'nii.gz'))
variance_path = os.path.join(self.output_path,
str(Paths.VARIANCE_FILE_PATTERN).format(f'{patient_idx:0>4}', 'nii.gz'))

nifti = nib.load(image_path)
image = nifti.get_fdata()
segmentation = load_lungs_mask(segmentation_path)
mask = load_lungs_mask(mask_path)

cropped_image = crop_image_to_bounding_box_with_lungs(image, segmentation)
cropped_segmentation = crop_image_to_bounding_box_with_lungs(segmentation, segmentation)
cropped_mask = crop_image_to_bounding_box_with_lungs(mask, segmentation)

if os.path.exists(variance_path):
variance = load_nifti_file(variance_path)
else:
variance = self.get_variance(cropped_image, cropped_segmentation, variance_path, nifti.affine,
nifti.header)

if should_postprocess_variance:
variance = apply_variance_postprocess(variance)

self._analyze_chunks(patient_idx, cropped_mask, variance)
self._analyze_voxels(patient_idx, cropped_mask, variance)

self._summarize_results()
self._save_results()

def _analyze_chunks(self, patient_idx, mask: np.ndarray, variance: np.ndarray):
self.results['chunks_analyse'][patient_idx] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}

origin_x, origin_y, origin_z = mask.shape
stride_x, stride_y, stride_z = self.prediction_options.stride

for x in range(0, origin_x, stride_x):
for y in range(0, origin_y, stride_y):
for z in range(0, origin_z, stride_z):
mask_chunk = mask[self._get_window((x, y, z))]
mask_mean = np.mean(mask_chunk)

variance_chunk = variance[self._get_window((x, y, z))]
variance_mean = np.mean(variance_chunk)

self._update_results(patient_idx, mask_mean, variance_mean, 'chunks_analyse')

def _get_window(self, coord: Tuple[int, int, int]) -> Tuple[slice, ...]:
return tuple([slice(dim_start, dim_start + chunk_dim) for (dim_start, chunk_dim) in
zip(coord, self.prediction_options.chunk_size[:3])])

def _analyze_voxels(self, patient_idx, mask: np.ndarray, variance: np.ndarray):
self.results['voxels_analyse'][patient_idx] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}

origin_x, origin_y, origin_z = mask.shape

for x in range(0, origin_x)[:-1]:
for y in range(0, origin_y)[:-1]:
for z in range(0, origin_z)[:-1]:
mask_value = mask[x, y, z]
variance_value = variance[x, y, z]

self._update_results(patient_idx, mask_value, variance_value, 'voxels_analyse')

def get_variance(self, image: np.ndarray, segmentation: np.ndarray, variance_path: str, nifti_affine, nifti_header,
should_save: bool = True):
predictions = self.model_evaluator.evaluate(
image,
segmentation,
self.prediction_options.mc_sample,
self.prediction_options.stride)

variance = self.model_evaluator.get_segmentation_variance(predictions, segmentation)

if should_save:
save_as_nifti(variance, Path(variance_path), nifti_affine, nifti_header)

return variance

def _update_results(self, patient_idx, ground_truth, prediction, analyse_type: str):
if ground_truth and prediction:
self.results[analyse_type][patient_idx]['tp'] += 1
self.results[analyse_type]['overall']['tp'] += 1

if not ground_truth and not prediction:
self.results[analyse_type][patient_idx]['tn'] += 1
self.results[analyse_type]['overall']['tn'] += 1

if not ground_truth and prediction:
self.results[analyse_type][patient_idx]['fp'] += 1
self.results[analyse_type]['overall']['fp'] += 1

if ground_truth and not prediction:
self.results[analyse_type][patient_idx]['fn'] += 1
self.results[analyse_type]['overall']['fn'] += 1

def _summarize_results(self):
for analyse_type in self.results.keys():
overall = self.results[analyse_type]['overall']

self.results[analyse_type]['overall']['tpr'] = overall['tp'] / max(1, (overall['tp'] + overall['fn']))
self.results[analyse_type]['overall']['tnr'] = overall['tn'] / max(1, (overall['tn'] + overall['fp']))

def _save_results(self):
model_name = Path(self.model_path).stem
results_path = os.path.join(self.output_path, f'results_{model_name}.json')
with open(results_path, 'w') as r:
json.dump(self.results, r)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Tuple, List
from typing import Tuple, List, Optional

import cv2 as cv
import numpy as np
Expand Down Expand Up @@ -94,22 +94,22 @@ def _get_window(self, coord: Window) -> Tuple[slice, ...]:
@classmethod
def save_predictions(
cls, dir_path: Path, patient_idx: int, predictions, lungs_mask, affine: np.ndarray,
nifti_header, should_binarize_prediction) -> None:
nifti_header) -> None:
"""
Saves predictions list in nifti file.
:param patient_id: four-digit patient id
:param predictions: list of arrays with predictions
"""
predictions = np.array(predictions)
segmentation = cls.get_segmentation_from_mean(predictions, should_binarize_prediction)
segmentation = cls.get_segmentation_from_mean(predictions)
variance = cls.get_segmentation_variance(predictions, lungs_mask)
predictions_path = dir_path / Paths.PREDICTIONS_FILE_PATTERN.format(patient_idx, 'nii.gz')
save_as_nifti(segmentation, predictions_path, affine, nifti_header)
variance_path = dir_path / Paths.VARIANCE_FILE_PATTERN.format(patient_idx, 'nii.gz')
save_as_nifti(variance, variance_path, affine, nifti_header)

@staticmethod
def get_segmentation_from_mean(predictions, should_binarize_prediction):
def get_segmentation_from_mean(predictions):
segmentation = np.mean(predictions, axis=0)
return segmentation

Expand Down Expand Up @@ -138,7 +138,6 @@ def binarize_prediction(self, mean_prediction: np.ndarray) -> np.ndarray:
mean_prediction = mean_prediction.copy()
mean_prediction = mean_prediction * 255
mean_prediction = mean_prediction.astype(np.uint8)
max_value = np.max(mean_prediction)

for i in range(mean_prediction.shape[0]):
im_slice = mean_prediction[i]
Expand Down
4 changes: 1 addition & 3 deletions bayesian_cnn_prometheus/evaluation/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def make_prediction(weights_path: Path, patient_idx, prediction_options: Predict
predictions,
cropped_segmentation,
nifti.affine,
nifti.header,
True)

nifti.header)

if __name__ == '__main__':
main()
21 changes: 21 additions & 0 deletions bayesian_cnn_prometheus/run_analysis_for_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
import sys
from pathlib import Path

from bayesian_cnn_prometheus.constants import Paths
from bayesian_cnn_prometheus.run_training_for_experiments import create_sbatch_script


def run_analysis_for_experiment():
model_path = sys.argv[1]
config_path = sys.argv[2]
input_path = sys.argv[3]
output_path = sys.argv[4]

script_path = create_sbatch_script(Path(output_path))
os.system(
f'sbatch {script_path} {os.path.join(Paths.PROJECT_DIR, "analysis", "analyse_lesions.py")} {model_path} {config_path} {input_path} {output_path}')


if __name__ == '__main__':
run_analysis_for_experiment()
10 changes: 5 additions & 5 deletions bayesian_cnn_prometheus/run_training_for_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from bayesian_cnn_prometheus.constants import Paths

EXPERIMENTS_DIR = Path('experiments') / 'control_group'
EXPERIMENTS_DIR = Path('experiments_hp50')
EXPERIMENTS_DIR = str(EXPERIMENTS_DIR)


Expand Down Expand Up @@ -144,14 +144,14 @@ def parse():
'name': 'chunk_change',
'overrides': [
{
'alias': 's',
'alias': 'cs',
'key': 'preprocessing.create_chunks.chunk_size',
'values': [[128, 16, 16]]
'values': [[64, 32, 32]]
},
{
'alias': 'cs',
'alias': 's',
'key': 'preprocessing.create_chunks.stride',
'values': [[64, 16, 16]]
'values': [[64, 32, 32]]
}
],
}
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'scikit-image==0.17.2',
'scipy==1.5.4',
'glob2==0.7',
'batchgenerators'
'batchgenerators',
'opencv-python~=4.5.5.62'
],
)