Skip to content

Commit

Permalink
image distribution 계산 함수 추가 (issue #2)
Browse files Browse the repository at this point in the history
  • Loading branch information
4pygmalion committed Jul 17, 2024
1 parent 2ada0f1 commit f395d35
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
27 changes: 27 additions & 0 deletions cosas/tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
from cosas.transforms import get_image_stats


def test_get_image_stats():

images = np.random.randint(0, 256, size=(10, 224, 224, 3), dtype=np.uint8)

# get image stats
means, stds = get_image_stats(images)

assert means.shape == (3,)
assert stds.shape == (3,)


def test_get_image_stats_list():

images = [
np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
for i in range(0, 10)
]

# get image stats
means, stds = get_image_stats(images)

assert means.shape == (3,)
assert stds.shape == (3,)
34 changes: 34 additions & 0 deletions cosas/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
import math
from typing import Tuple, List

import numpy as np
import torch
import albumentations as A
from torchvision.transforms.functional import pad


def get_image_stats(
images: np.ndarray | List[np.ndarray],
) -> Tuple[np.ndarray, np.ndarray]:
"""
이미지 배열의 평균 및 표준 편차를 계산합니다.
이 함수는 입력된 이미지 배열이 4차원 배열인지 확인하고,
각 차원에 대해 평균 및 표준 편차를 계산하여 반환합니다.
Params:
images (np.ndarray): 4차원 이미지 배열 (B, W, H, C)
Returns:
Tuple[np.ndarray, np.ndarray]: 이미지 배열의 평균과 표준 편차.
Exception:
ValueError: 입력된 이미지 배열이 4차원 배열이 아닐 경우 발생.
"""
if isinstance(images, list):
images = np.stack(images, axis=0)

if images.ndim != 4:
raise ValueError(
f"Input images array must be 4-dimensional, passed images.ndim({images.ndim})"
)

mean = np.mean(images, axis=(0, 1, 2))
std = np.std(images, axis=(0, 1, 2))

return mean, std


def pad_image_tensor(
image_tensor: torch.Tensor, size: tuple = (224, 224)
) -> torch.Tensor:
Expand Down
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ torchvision
torchaudio

segmentation-models-pytorch
albumentations
albumentations

# for CI
pytest==8.2.2

0 comments on commit f395d35

Please sign in to comment.