Skip to content

Commit e90622b

Browse files
authored
Adding a normalized MSE loss. (#7)
* Adding runs to .gitignore, so that no models/results are pushed to Git. * Adding the loss function. * Adding linting and testing to the NormalizedMSELoss. * Fixed the tests. * Linting tests and fixing them so that they pass. * More linting fixes. * Fixing linting on the losses. * Update README.md --------- Signed-off-by: Gal Egozi <[email protected]>
1 parent f386535 commit e90622b

File tree

6 files changed

+214
-4
lines changed

6 files changed

+214
-4
lines changed

.github/workflows/yoke_install_test_lint.yml

+20
Original file line numberDiff line numberDiff line change
@@ -16,64 +16,84 @@ jobs:
1616

1717
steps:
1818
- uses: actions/checkout@v4
19+
1920
- name: Set up Python 3.9
2021
uses: actions/setup-python@v3
2122
with:
2223
python-version: '3.9'
24+
2325
- name: Install flit
2426
run: pip install flit
27+
2528
- name: Build yoke
2629
run: flit install --deps=all
30+
2731
- name: Test with pytest
2832
run: |
2933
pytest -v --cov-report=lcov:./coverage/lcov.info --cov=yoke -Werror
34+
3035
- name: Upload coverage to Coveralls
3136
uses: coverallsapp/github-action@v1
37+
3238
- name: Lint Yoke
3339
run: |
3440
ruff check
3541
ruff check --preview
3642
ruff format --check --diff
3743
continue-on-error: true
44+
3845
- name: Lint applications/evaluation
3946
run: |
4047
ruff check applications/evaluation
4148
ruff check applications/evaluation --preview
4249
ruff format applications/evaluation --check --diff
4350
continue-on-error: false
51+
4452
- name: Lint applications/filelists
4553
run: |
4654
ruff check applications/filelists
4755
ruff check applications/filelists --preview
4856
ruff format applications/filelists --check --diff
4957
continue-on-error: false
58+
5059
- name: Lint applications/normalization
5160
run: |
5261
ruff check applications/normalization
5362
ruff check applications/normalization --preview
5463
ruff format applications/normalization --check --diff
5564
continue-on-error: false
65+
5666
- name: Lint applications/viewers
5767
run: |
5868
ruff check applications/viewers
5969
ruff check applications/viewers --preview
6070
ruff format applications/viewers --check --diff
6171
continue-on-error: false
72+
6273
- name: Lint tests
6374
run: |
6475
ruff check tests
6576
ruff check tests --preview
6677
ruff format tests --check --diff
6778
continue-on-error: false
79+
6880
- name: Lint datasets
6981
run: |
7082
ruff check src/yoke/datasets
7183
ruff check src/yoke/datasets --preview
7284
ruff format src/yoke/datasets --check --diff
7385
continue-on-error: false
86+
7487
- name: Lint models
7588
run: |
7689
ruff check src/yoke/models
7790
ruff check src/yoke/models --preview
7891
ruff format src/yoke/models --check --diff
7992
continue-on-error: false
93+
94+
- name: lint losses
95+
run: |
96+
ruff check src/yoke/losses
97+
ruff check src/yoke/losses --preview
98+
ruff format src/yoke/losses --check --diff
99+
continue-on-error: false

README.md

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
YOKE: Yielding Optimal Knowledge Enhancement
22
============================================
33

4+
[![Coverage Status](https://coveralls.io/repos/github/lanl/Yoke/badge.svg?branch=main)](https://coveralls.io/github/lanl/Yoke?branch=main)
45
[![pipeline status](https://github.com/lanl/Yoke/actions/workflows/yoke_install_test_lint.yml/badge.svg)](https://github.com/lanl/Yoke/actions)
56
[![Latest Release](https://img.shields.io/github/v/release/lanl/Yoke)](https://github.com/lanl/Yoke/releases)
6-
[![Coverage Status](https://coveralls.io/repos/github/lanl/Yoke/badge.svg?branch=main)](https://coveralls.io/github/lanl/Yoke?branch=main)
77

88
![Get YOKEd!](./YOKE_DALLE_512x512.png)
99

10-
1110
About:
1211
------
1312

@@ -19,11 +18,14 @@ projects.
1918
The YOKE module is divided into submodules, installed in a python environment:
2019

2120
- datasets/
21+
- helpers/
2222
- models/
2323
- metrics/
24-
- torch_training_utils.py
24+
- losses/
25+
- utils/
2526
- lr_schedulers.py
26-
- parallel_utils.py
27+
- parellel_utils.py
28+
- torch_training_utils.py
2729

2830
Helper utilities and examples are under `applications`:
2931

src/yoke/losses/NormMSE.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""This loss function is a per-channel normalized version of mean squared error."""
2+
3+
import torch
4+
import torch.nn as nn
5+
6+
7+
class NormalizedMSELoss(nn.Module):
8+
"""Per-channel normalized mean squared error loss.
9+
10+
This loss function normalizes the input and target tensors per channel
11+
before computing the mean squared error. The normalization is done by
12+
subtracting the mean and dividing by the standard deviation of the target
13+
tensor, with a small epsilon added to the standard deviation to avoid
14+
division by zero.
15+
16+
Args:
17+
eps (float): A small value to avoid division by zero.
18+
reduction (str): Specifies the reduction to apply to the output:
19+
'none' | 'mean' | 'sum'. Default: 'none'.
20+
"""
21+
22+
def __init__(self, eps: float = 1e-8, reduction: str = "none") -> None:
23+
"""Initialize the NormalizedMSELoss.
24+
25+
Args:
26+
eps (float): A small value to avoid division by zero.
27+
reduction (str): Specifies the reduction to apply to the output:
28+
'none' | 'mean' | 'sum'. Default: 'none'.
29+
"""
30+
super().__init__()
31+
self.eps = eps
32+
self.reduction = reduction
33+
34+
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
35+
"""Compute the normalized mean squared error loss.
36+
37+
Args:
38+
pred (torch.Tensor): The predicted tensor.
39+
target (torch.Tensor): The target tensor.
40+
41+
Returns:
42+
torch.Tensor: The computed loss.
43+
"""
44+
target_mean = target.mean(dim=(0, 2, 3), keepdim=True)
45+
target_std = target.std(dim=(0, 2, 3), keepdim=True) + self.eps
46+
47+
pred_norm = (pred - target_mean) / target_std
48+
target_norm = (target - target_mean) / target_std
49+
50+
loss = (pred_norm - target_norm) ** 2
51+
if self.reduction == "mean":
52+
return loss.mean()
53+
elif self.reduction == "sum":
54+
return loss.sum()
55+
else:
56+
return loss

src/yoke/losses/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""All custom loss functions in YOKE live here."""

tests/losses/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""The tests for the src/yoke/losses directory."""

tests/losses/test_normMSEloss.py

+130
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Tests for the NormalizedMSELoss class."""
2+
3+
import pytest
4+
import torch
5+
from yoke.losses.NormMSE import NormalizedMSELoss
6+
7+
8+
@pytest.fixture
9+
def norm_mse() -> NormalizedMSELoss:
10+
"""Fixture for NormalizedMSELoss."""
11+
return NormalizedMSELoss()
12+
13+
14+
def test_norm_mse_loss_zero(norm_mse: NormalizedMSELoss) -> None:
15+
"""Test the NormalizedMSELoss with zero input and target."""
16+
inp = torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]])
17+
target = torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]])
18+
loss = norm_mse(inp, target)
19+
assert torch.all(loss == 0.0)
20+
21+
22+
def test_norm_mse_loss_positive(norm_mse: NormalizedMSELoss) -> None:
23+
"""Test the NormalizedMSELoss with positive input and target."""
24+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32)
25+
target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32)
26+
loss = norm_mse(inp, target)
27+
assert torch.all(loss == 0.0)
28+
29+
30+
def test_norm_mse_loss_non_zero(norm_mse: NormalizedMSELoss) -> None:
31+
"""Test the NormalizedMSELoss with non-zero input and target."""
32+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
33+
target = torch.tensor([[[[4.0, 5.0], [6.0, 7.0]]]])
34+
loss = norm_mse(inp, target)
35+
expected_loss = torch.mean(
36+
(
37+
(inp - target.mean(dim=(0, 2, 3), keepdim=True))
38+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
39+
- (target - target.mean(dim=(0, 2, 3), keepdim=True))
40+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
41+
)
42+
** 2,
43+
dim=(0, 2, 3),
44+
)
45+
assert torch.all(loss == expected_loss)
46+
47+
48+
def test_norm_mse_loss_negative(norm_mse: NormalizedMSELoss) -> None:
49+
"""Test the NormalizedMSELoss with negative input and target."""
50+
inp = torch.tensor([[[[-1.0, -2.0], [-3.0, -4.0]]]])
51+
target = torch.tensor([[[[-4.0, -5.0], [-6.0, -7.0]]]])
52+
loss = norm_mse(inp, target)
53+
expected_loss = torch.mean(
54+
(
55+
(inp - target.mean(dim=(0, 2, 3), keepdim=True))
56+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
57+
- (target - target.mean(dim=(0, 2, 3), keepdim=True))
58+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
59+
)
60+
** 2,
61+
dim=(0, 2, 3),
62+
)
63+
assert torch.all(loss == expected_loss)
64+
65+
66+
def test_norm_mse_loss_mean_reduction() -> None:
67+
"""Test the mean reduction of the NormalizedMSELoss."""
68+
norm_mse = NormalizedMSELoss(reduction="mean")
69+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
70+
target = torch.tensor([[[[4.0, 5.0], [6.0, 7.0]]]])
71+
loss = norm_mse(inp, target)
72+
expected_loss = torch.mean(
73+
(
74+
(inp - target.mean(dim=(0, 2, 3), keepdim=True))
75+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
76+
- (target - target.mean(dim=(0, 2, 3), keepdim=True))
77+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
78+
)
79+
** 2
80+
).mean()
81+
assert torch.all(loss == expected_loss)
82+
83+
84+
def test_norm_mse_loss_sum_reduction() -> None:
85+
"""Test the sum reduction of the NormalizedMSELoss."""
86+
norm_mse = NormalizedMSELoss(reduction="sum")
87+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
88+
target = torch.tensor([[[[4.0, 5.0], [6.0, 7.0]]]])
89+
loss = norm_mse(inp, target)
90+
expected_loss = torch.sum(
91+
(
92+
(inp - target.mean(dim=(0, 2, 3), keepdim=True))
93+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
94+
- (target - target.mean(dim=(0, 2, 3), keepdim=True))
95+
/ (target.std(dim=(0, 2, 3), keepdim=True) + norm_mse.eps)
96+
)
97+
** 2
98+
).sum()
99+
assert loss == expected_loss
100+
101+
102+
def test_norm_mse_loss_different_shapes(norm_mse: NormalizedMSELoss) -> None:
103+
"""Test the NormalizedMSELoss with different input and target shapes."""
104+
inp = torch.tensor([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]])
105+
target = torch.tensor([[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]])
106+
loss = norm_mse(inp, target)
107+
assert torch.all(loss == 0.0)
108+
109+
110+
def test_norm_mse_loss_batch_size(norm_mse: NormalizedMSELoss) -> None:
111+
"""Test the NormalizedMSELoss with different batch sizes."""
112+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]], [[[1.0, 2.0], [3.0, 4.0]]]])
113+
target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]], [[[1.0, 2.0], [3.0, 4.0]]]])
114+
loss = norm_mse(inp, target)
115+
assert torch.all(loss == 0.0)
116+
117+
118+
def test_norm_mse_loss_different_eps() -> None:
119+
"""Test the NormalizedMSELoss with different eps values."""
120+
norm_mse = NormalizedMSELoss(eps=1e-5)
121+
inp = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
122+
target = torch.tensor([[[[4.0, 5.0], [6.0, 7.0]]]])
123+
loss = norm_mse(inp, target)
124+
expected_loss = (
125+
(inp - target.mean(dim=(0, 2, 3), keepdim=True))
126+
/ (target.std(dim=(0, 2, 3), keepdim=True) + 1e-5)
127+
- (target - target.mean(dim=(0, 2, 3), keepdim=True))
128+
/ (target.std(dim=(0, 2, 3), keepdim=True) + 1e-5)
129+
) ** 2
130+
assert torch.all(loss == expected_loss)

0 commit comments

Comments
 (0)