diff --git a/cosas/tests/test_transforms.py b/cosas/tests/test_transforms.py new file mode 100644 index 0000000..d403060 --- /dev/null +++ b/cosas/tests/test_transforms.py @@ -0,0 +1,27 @@ +import numpy as np +from cosas.transforms import get_image_stats + + +def test_get_image_stats(): + + images = np.random.randint(0, 256, size=(10, 224, 224, 3), dtype=np.uint8) + + # get image stats + means, stds = get_image_stats(images) + + assert means.shape == (3,) + assert stds.shape == (3,) + + +def test_get_image_stats_list(): + + images = [ + np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8) + for i in range(0, 10) + ] + + # get image stats + means, stds = get_image_stats(images) + + assert means.shape == (3,) + assert stds.shape == (3,) diff --git a/cosas/transforms.py b/cosas/transforms.py index faeadec..2d0a2b0 100644 --- a/cosas/transforms.py +++ b/cosas/transforms.py @@ -1,10 +1,44 @@ import math +from typing import Tuple, List + import numpy as np import torch import albumentations as A from torchvision.transforms.functional import pad +def get_image_stats( + images: np.ndarray | List[np.ndarray], +) -> Tuple[np.ndarray, np.ndarray]: + """ + 이미지 배열의 평균 및 표준 편차를 계산합니다. + + 이 함수는 입력된 이미지 배열이 4차원 배열인지 확인하고, + 각 차원에 대해 평균 및 표준 편차를 계산하여 반환합니다. + + Params: + images (np.ndarray): 4차원 이미지 배열 (B, W, H, C) + + Returns: + Tuple[np.ndarray, np.ndarray]: 이미지 배열의 평균과 표준 편차. + + Exception: + ValueError: 입력된 이미지 배열이 4차원 배열이 아닐 경우 발생. + """ + if isinstance(images, list): + images = np.stack(images, axis=0) + + if images.ndim != 4: + raise ValueError( + f"Input images array must be 4-dimensional, passed images.ndim({images.ndim})" + ) + + mean = np.mean(images, axis=(0, 1, 2)) + std = np.std(images, axis=(0, 1, 2)) + + return mean, std + + def pad_image_tensor( image_tensor: torch.Tensor, size: tuple = (224, 224) ) -> torch.Tensor: diff --git a/requirements.txt b/requirements.txt index b0a2def..55cd91a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,7 @@ torchvision torchaudio segmentation-models-pytorch -albumentations \ No newline at end of file +albumentations + +# for CI +pytest==8.2.2 \ No newline at end of file