diff --git a/cosas/misc.py b/cosas/misc.py index a842c06..7496917 100644 --- a/cosas/misc.py +++ b/cosas/misc.py @@ -1,5 +1,6 @@ import argparse -from typing import Tuple +from typing import Tuple, Dict, List +from collections import defaultdict import random import numpy as np @@ -7,6 +8,7 @@ from torchvision.transforms import ToPILImage from matplotlib import pyplot as plt from sklearn.model_selection import train_test_split +from torchvision.transforms.functional import rotate from .data_model import COSASData, Scanncers from .transforms import remove_pad, reverse_tesellation @@ -196,3 +198,33 @@ def plot_xypred( plt.tight_layout() return fig, axes + + +def tta_softvoting(n_trials: List[Dict[str, torch.Tensor]]): + """한 데이터포인트(=instance)에 대한 TTA 진행후 tensor""" + + res = defaultdict(list) + for trial in n_trials: + for k, v in trial.items(): + res[k].append(v) + return {k: torch.stack(v, dim=0).mean(dim=0) for k, v in res.items()} + + +@torch.no_grad() +def rotational_tta(xs, model, angles=[0, 90, 180, 270]): + """배치(xs)에 대해서 TTA을 진행""" + + y_hats = [] + for x in xs: + outputs = [] + for angle in angles: + x_new = rotate(x, angle=angle) + output = model(x_new.unsqueeze(0)) + outputs.append( + {k: rotate(tensor.squeeze(0), -angle) for k, tensor in output.items()} + ) + y_hats.append(tta_softvoting(outputs)) + + return { + k: torch.stack([y_hat[k] for y_hat in y_hats], dim=0) for k in y_hats[0].keys() + } diff --git a/experiments/tta.py b/experiments/tta.py index f059520..a7169d1 100644 --- a/experiments/tta.py +++ b/experiments/tta.py @@ -1,13 +1,11 @@ import argparse -from typing import Dict, List, Tuple -from collections import defaultdict +from typing import Tuple import numpy as np import mlflow import albumentations as A import torch from progress.bar import Bar -from torchvision.transforms.functional import rotate from torchvision.transforms import ToPILImage from torch.utils.data import DataLoader from sklearn.model_selection import KFold @@ -26,6 +24,7 @@ ) from cosas.datasets import ImageMaskDataset from cosas.metrics import summarize_metrics +from cosas.misc import rotational_tta MODEL_URI = "file:///vast/AI_team/mlflow_artifact/13/{run_id}/artifacts/model" @@ -157,36 +156,6 @@ def get_args(): return parser.parse_args() -def tta_softvoting(n_trials: List[Dict[str, torch.Tensor]]): - """한 데이터포인트(=instance)에 대한 TTA 진행후 tensor""" - - res = defaultdict(list) - for trial in n_trials: - for k, v in trial.items(): - res[k].append(v) - return {k: torch.stack(v, dim=0).mean(dim=0) for k, v in res.items()} - - -@torch.no_grad() -def rotational_tta(xs, model, angles=[0, 90, 180, 270]): - """배치(xs)에 대해서 TTA을 진행""" - - y_hats = [] - for x in xs: - outputs = [] - for angle in angles: - x_new = rotate(x, angle=angle) - output = model(x_new.unsqueeze(0)) - outputs.append( - {k: rotate(tensor.squeeze(0), -angle) for k, tensor in output.items()} - ) - y_hats.append(tta_softvoting(outputs)) - - return { - k: torch.stack([y_hat[k] for y_hat in y_hats], dim=0) for k in y_hats[0].keys() - } - - def load_data(task: int = 2): cosas_data = COSASData(DATA_DIR, task=task) cosas_data.load() diff --git a/inference.py b/inference.py index 96d9568..7a44c80 100644 --- a/inference.py +++ b/inference.py @@ -1,5 +1,5 @@ import os -from typing import List +import argparse import numpy as np import torch @@ -7,9 +7,18 @@ import albumentations as A from albumentations.pytorch.transforms import ToTensorV2 +from cosas.misc import rotational_tta + CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--dry_run", action="store_true") + + return parser.parse_args() + + def read_image(path) -> np.ndarray: image = SimpleITK.ReadImage(path) return SimpleITK.GetArrayFromImage(image) @@ -51,10 +60,15 @@ def write_image(path, result): def main(): + args = get_args() device = "cuda" if torch.cuda.is_available() else "cpu" - input_dir = "/input/images/adenocarcinoma-image" - output_dir = "/output/images/adenocarcinoma-mask" + if args.dry_run: + input_dir = "task2/input/domain1/images/adenocarcinoma-image" + output_dir = "task2/output/images/adenocarcinoma-mask" + else: + input_dir = "/input/images/adenocarcinoma-image" + output_dir = "/output/images/adenocarcinoma-mask" os.makedirs(output_dir, exist_ok=True) @@ -62,22 +76,22 @@ def main(): model_path = os.path.join(CURRENT_DIR, "model.pth") model = torch.load(model_path).eval().to(device) - with torch.no_grad(): - for filename in os.listdir(input_dir): - if filename.endswith(".mha"): - output_path = os.path.join(output_dir, filename) - try: - input_path = os.path.join(input_dir, filename) - raw_image = read_image(input_path) - original_size = raw_image.shape[:2] - - x: torch.Tensor = preprocess_image(raw_image, device) - confidences: torch.Tensor = model(x)["mask"] - result = postprocess_image(confidences, original_size=original_size) - write_image(output_path, result) - - except Exception as error: - print(error) + for filename in os.listdir(input_dir): + if filename.endswith(".mha"): + output_path = os.path.join(output_dir, filename) + input_path = os.path.join(input_dir, filename) + try: + raw_image = read_image(input_path) + except Exception as e: + print(e) + + original_size = raw_image.shape[:2] + + x: torch.Tensor = preprocess_image(raw_image, device) + confidences: torch.Tensor = rotational_tta(x, model)["mask"] # with no_grad + # confidences: torch.Tensor = model(x)["mask"] + result = postprocess_image(confidences, original_size=original_size) + write_image(output_path, result) if __name__ == "__main__":