Skip to content

Commit

Permalink
Test properties of truncated normal
Browse files Browse the repository at this point in the history
Delete old tests of truncated normal
  • Loading branch information
dafeda committed Sep 11, 2023
1 parent bbf0db5 commit 087883a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/ert/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .external_ert_script import ExternalErtScript
from .field import Field, field_transform
from .gen_data_config import GenDataConfig
from .gen_kw_config import GenKwConfig, PriorDict
from .gen_kw_config import GenKwConfig, PriorDict, TransferFunction
from .hook_runtime import HookRuntime
from .lint_file import lint_file
from .model_config import ModelConfig
Expand Down Expand Up @@ -45,6 +45,7 @@
"Field",
"GenDataConfig",
"GenKwConfig",
"TransferFunction",
"HookRuntime",
"lint_file",
"ModelConfig",
Expand Down
12 changes: 0 additions & 12 deletions tests/unit_tests/config/test_gen_kw_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@

@pytest.mark.usefixtures("use_tmpdir")
def test_gen_kw_config():
GenKwConfig(
name="KEY",
forward_init=False,
template_file="",
transfer_function_definitions=["KEY UNIFORM 0 1"],
output_file="kw.txt",
)
conf = GenKwConfig(
name="KEY",
forward_init=False,
Expand Down Expand Up @@ -340,11 +333,6 @@ def test_gen_kw_params_parsing(tmpdir, params, error):
("MYNAME LOGNORMAL 0 1", 0.3, 1.34985880757600318347),
("MYNAME LOGNORMAL 0 1", 0.7, 2.01375270747047663278),
("MYNAME LOGNORMAL 0 1", 1.0, 2.71828182845904509080),
("MYNAME TRUNCATED_NORMAL 1 0.25 0 10", -1.0, 0.75000000000000000000),
("MYNAME TRUNCATED_NORMAL 1 0.25 0 10", 0.0, 1.00000000000000000000),
("MYNAME TRUNCATED_NORMAL 1 0.25 0 10", 0.3, 1.07499999999999995559),
("MYNAME TRUNCATED_NORMAL 1 0.25 0 10", 0.7, 1.17500000000000004441),
("MYNAME TRUNCATED_NORMAL 1 0.25 0 10", 1.0, 1.25000000000000000000),
("MYNAME ERRF 1 2 0.1 0.1", -1.0, 1.00000000000000000000),
("MYNAME ERRF 1 2 0.1 0.1", 0.0, 1.84134474606854281475),
("MYNAME ERRF 1 2 0.1 0.1", 0.3, 1.99996832875816688002),
Expand Down
77 changes: 77 additions & 0 deletions tests/unit_tests/config/test_transfer_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from ert.config import TransferFunction

import numpy as np
from hypothesis import given, strategies as st


def valid_params():
_std = st.floats(min_value=0.01, allow_nan=False, allow_infinity=False)

mean_min_max_strategy = st.floats(allow_nan=False, allow_infinity=False).flatmap(
lambda m: st.tuples(
st.just(m),
st.floats(m - 2, m - 1),
st.floats(m + 1, m + 2).filter(
lambda x: x > m
), # _max, ensuring it's strictly greater than _min
)
)

return mean_min_max_strategy.flatmap(
lambda triplet: st.tuples(
st.just(triplet[0]), # _mean
_std, # _std
st.just(triplet[1]), # _min
st.just(triplet[2]), # _max
)
)


@given(st.floats(allow_nan=False, allow_infinity=False), valid_params())
def test_that_truncated_normal_stays_within_bounds(x, arg):
result = TransferFunction.trans_truncated_normal(x, arg)
assert arg[2] <= result <= arg[3]


@given(
st.floats(allow_nan=False, allow_infinity=False),
st.floats(allow_nan=False, allow_infinity=False),
valid_params(),
)
def test_that_truncated_normal_is_monotonic(x1, x2, arg):
result1 = TransferFunction.trans_truncated_normal(x1, arg)
result2 = TransferFunction.trans_truncated_normal(x2, arg)

if x1 < x2:
# Results should be different unless clamped
assert (
result1 < result2
or (result1 == arg[2] and result2 == arg[2])
or (result1 == arg[3] and result2 == arg[3])
)
elif x1 == x2:
assert result1 == result2


@given(valid_params())
def test_that_truncated_normal_is_standardized(arg):
"""If `x` is 0 (i.e., the mean of the standard normal distribution),
the output should be close to `_mean`.
"""
result = TransferFunction.trans_truncated_normal(0, arg)
assert np.isclose(result, arg[0])


@given(st.floats(allow_nan=False, allow_infinity=False), valid_params())
def test_that_truncated_normal_stretches(x, arg):
"""If `x` is 1 standard deviation away from 0, the output should be
`_mean + _std` or `_mean - _std`.
"""
if x == 1:
expected = arg[0] + arg[1]
elif x == -1:
expected = arg[0] - arg[1]
else:
return
result = TransferFunction.trans_truncated_normal(x, arg)
assert np.isclose(result, expected)

0 comments on commit 087883a

Please sign in to comment.