-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Delete old tests of truncated normal
- Loading branch information
Showing
3 changed files
with
79 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from ert.config import TransferFunction | ||
|
||
import numpy as np | ||
from hypothesis import given, strategies as st | ||
|
||
|
||
def valid_params(): | ||
_std = st.floats(min_value=0.01, allow_nan=False, allow_infinity=False) | ||
|
||
mean_min_max_strategy = st.floats(allow_nan=False, allow_infinity=False).flatmap( | ||
lambda m: st.tuples( | ||
st.just(m), | ||
st.floats(m - 2, m - 1), | ||
st.floats(m + 1, m + 2).filter( | ||
lambda x: x > m | ||
), # _max, ensuring it's strictly greater than _min | ||
) | ||
) | ||
|
||
return mean_min_max_strategy.flatmap( | ||
lambda triplet: st.tuples( | ||
st.just(triplet[0]), # _mean | ||
_std, # _std | ||
st.just(triplet[1]), # _min | ||
st.just(triplet[2]), # _max | ||
) | ||
) | ||
|
||
|
||
@given(st.floats(allow_nan=False, allow_infinity=False), valid_params()) | ||
def test_that_truncated_normal_stays_within_bounds(x, arg): | ||
result = TransferFunction.trans_truncated_normal(x, arg) | ||
assert arg[2] <= result <= arg[3] | ||
|
||
|
||
@given( | ||
st.floats(allow_nan=False, allow_infinity=False), | ||
st.floats(allow_nan=False, allow_infinity=False), | ||
valid_params(), | ||
) | ||
def test_that_truncated_normal_is_monotonic(x1, x2, arg): | ||
result1 = TransferFunction.trans_truncated_normal(x1, arg) | ||
result2 = TransferFunction.trans_truncated_normal(x2, arg) | ||
|
||
if x1 < x2: | ||
# Results should be different unless clamped | ||
assert ( | ||
result1 < result2 | ||
or (result1 == arg[2] and result2 == arg[2]) | ||
or (result1 == arg[3] and result2 == arg[3]) | ||
) | ||
elif x1 == x2: | ||
assert result1 == result2 | ||
|
||
|
||
@given(valid_params()) | ||
def test_that_truncated_normal_is_standardized(arg): | ||
"""If `x` is 0 (i.e., the mean of the standard normal distribution), | ||
the output should be close to `_mean`. | ||
""" | ||
result = TransferFunction.trans_truncated_normal(0, arg) | ||
assert np.isclose(result, arg[0]) | ||
|
||
|
||
@given(st.floats(allow_nan=False, allow_infinity=False), valid_params()) | ||
def test_that_truncated_normal_stretches(x, arg): | ||
"""If `x` is 1 standard deviation away from 0, the output should be | ||
`_mean + _std` or `_mean - _std`. | ||
""" | ||
if x == 1: | ||
expected = arg[0] + arg[1] | ||
elif x == -1: | ||
expected = arg[0] - arg[1] | ||
else: | ||
return | ||
result = TransferFunction.trans_truncated_normal(x, arg) | ||
assert np.isclose(result, expected) |