Skip to content

Commit

Permalink
add spectral angle mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Feb 29, 2024
1 parent 2a609ec commit 6b367a8
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 5 deletions.
59 changes: 59 additions & 0 deletions k3_addons/metrics/image/sam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from keras import backend, ops

from k3_addons.utils.checks import _check_same_shape
from k3_addons.utils.distributed import reduce
from k3_addons.api_export import k3_export

get_channel_axis = lambda data_format: 1 if data_format == "channels_first" else -1


def _sam_update(preds, target, data_format=None):
if preds.dtype != target.dtype:
raise TypeError(
"Expected `preds` and `target` to have the same data type."
f" Got preds: {preds.dtype} and target: {target.dtype}."
)
_check_same_shape(preds, target)
if len(ops.shape(preds)) != 4:
raise ValueError(
"Expected `preds` and `target` to have BxCxHxW or BxHxWxC shape."
f" Got preds: {ops.shape(preds)} and target: {ops.shape(target)}."
)
channel_axis = get_channel_axis(data_format)
if (preds.shape[channel_axis] <= 1) or (target.shape[channel_axis] <= 1):
raise ValueError(
"Expected channel dimension of `preds` and `target` to be larger than 1."
f" Got preds: {preds.shape[channel_axis]} and target: {target.shape[channel_axis]}."
)
return preds, target


def _sam_compute(
preds,
target,
reduction="elementwise_mean",
data_format=None,
):
if data_format is None:
data_format = backend.image_data_format()
channel_axis = get_channel_axis(data_format)
print(channel_axis)
dot_product = ops.sum((preds * target), axis=channel_axis)
preds_norm = ops.norm(preds, axis=channel_axis)
target_norm = ops.norm(target, axis=channel_axis)
denom = preds_norm * target_norm
sam_score = ops.clip(dot_product / denom, -1, 1)
sam_score = ops.arccos(sam_score)
return reduce(sam_score, reduction)


@k3_export(
[
"k3_addons.metrics.spectral_angle_mapper",
"k3_addons.metrics.functional.spectral_angle_mapper",
"k3_addons.metrics.image.spectral_angle_mapper",
]
)
def spectral_angle_mapper(preds, target, reduction, data_format=None):
preds, target = _sam_update(preds, target, data_format=data_format)
return _sam_compute(preds, target, reduction=reduction, data_format=data_format)
40 changes: 40 additions & 0 deletions k3_addons/metrics/image/sam_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
import keras
from keras import ops
import numpy as np
import torch

from k3_addons.metrics.image.sam import (
spectral_angle_mapper as spectral_angle_mapper_keras,
)
from torchmetrics.functional.image.sam import (
spectral_angle_mapper as spectral_angle_mapper_torch,
)


# parametrize the test
@pytest.mark.parametrize(
"input_shape, reduction, data_format",
[
((4, 3, 32, 32), "sum", "channels_first"),
((4, 3, 32, 32), "elementwise_mean", "channels_first"),
((4, 32, 32, 3), "none", "channels_first"),
((4, 32, 32, 3), "sum", "channels_last"),
((4, 32, 32, 3), "elementwise_mean", "channels_last"),
((4, 32, 32, 3), "none", "channels_last"),
],
)
def test_total_variation(input_shape, reduction, data_format):
inputs = keras.random.uniform(input_shape)
labels = keras.random.uniform(input_shape)
tv_keras = spectral_angle_mapper_keras(
inputs, labels, data_format=data_format, reduction=reduction
)
if data_format == "channels_last":
inputs = ops.transpose(inputs, (0, 3, 1, 2))
labels = ops.transpose(labels, (0, 3, 1, 2))
inputs = torch.tensor(ops.convert_to_numpy(inputs))
labels = torch.tensor(ops.convert_to_numpy(labels))
tv_torch = spectral_angle_mapper_torch(inputs, labels, reduction=reduction).numpy()

assert np.allclose(tv_keras, tv_torch, atol=1e-4)
6 changes: 2 additions & 4 deletions k3_addons/metrics/image/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,9 @@ def _total_variation_compute(score, num_elements, reduction):
[
"k3_addons.metrics.total_variation",
"k3_addons.metrics.functional.total_variation",
"k3_addons.metrics.image.total_variation",
]
)
def total_variation(img, reduction="sum", data_format=None):
score, num_elements = _total_variation_update(img, data_format=data_format)
return _total_variation_compute(
score, num_elements, reduction
)

return _total_variation_compute(score, num_elements, reduction)
5 changes: 4 additions & 1 deletion k3_addons/metrics/image/tv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchmetrics.functional.image.tv import total_variation as total_variation_torch
from torchmetrics.image.tv import TotalVariation as TotalVariationTorch


# parametrize the test
@pytest.mark.parametrize(
"input_shape, reduction, data_format",
Expand All @@ -18,7 +19,9 @@
)
def test_total_variation(input_shape, reduction, data_format):
inputs = keras.random.uniform(input_shape)
tv_keras = total_variation_keras(inputs, data_format=data_format, reduction=reduction)
tv_keras = total_variation_keras(
inputs, data_format=data_format, reduction=reduction
)
if data_format == "channels_last":
inputs = ops.transpose(inputs, (0, 3, 1, 2))
inputs = torch.tensor(ops.convert_to_numpy(inputs))
Expand Down
9 changes: 9 additions & 0 deletions k3_addons/utils/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from keras import ops


def _check_same_shape(preds, target):
"""Check that predictions and target have the same shape, else raise error."""
if ops.shape(preds) != ops.shape(target):
raise RuntimeError(
f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}."
)
11 changes: 11 additions & 0 deletions k3_addons/utils/distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from keras import ops


def reduce(x, reduction):
if reduction == "elementwise_mean":
return ops.mean(x)
if reduction == "none" or reduction is None:
return x
if reduction == "sum":
return ops.sum(x)
raise ValueError("Reduction parameter unknown.")

0 comments on commit 6b367a8

Please sign in to comment.