diff --git a/tests/test_precision.py b/tests/test_precision.py index c87985092..cdc454b24 100644 --- a/tests/test_precision.py +++ b/tests/test_precision.py @@ -1,4 +1,5 @@ import functools +import importlib from unittest.mock import patch import naive @@ -8,7 +9,7 @@ from numba import cuda import stumpy -from stumpy import config, core +from stumpy import cache, config, core try: from numba.errors import NumbaPerformanceWarning @@ -146,20 +147,67 @@ def test_snippets(): cmp_regimes, ) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func) - npt.assert_almost_equal( - ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION - ) - npt.assert_almost_equal( - ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION - ) - npt.assert_almost_equal( - ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION - ) - npt.assert_almost_equal( - ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION - ) - npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION) - npt.assert_almost_equal(ref_regimes, cmp_regimes) + # Revise fastmath flag, recompile, and re-calculate snippets, + # and then revert the changes + config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"} + core._calculate_squared_distance.targetoptions["fastmath"] = config.STUMPY_FASTMATH_FLAGS + njit_funcs = cache.get_njit_funcs() + for module_name, func_name in njit_funcs: + code = f"from stumpy.{module_name} import {func_name}; {func_name}.recompile()" + exec(code) + # module = importlib.import_module(f".{module_name}", package="stumpy") + # func = getattr(module, func_name) + # func.recompile() + + ( + cmp_snippets_NOreassoc, + cmp_indices_NOreassoc, + cmp_profiles_NOreassoc, + cmp_fractions_NOreassoc, + cmp_areas_NOreassoc, + cmp_regimes_NOreassoc, + ) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func) + + config._reset("STUMPY_FASTMATH_FLAGS") + for module_name, func_name in njit_funcs: + module = importlib.import_module(f".{module_name}", package="stumpy") + func = getattr(module, func_name) + func.recompile() + + if np.allclose(ref_snippets, cmp_snippets): + npt.assert_almost_equal( + ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal(ref_regimes, cmp_regimes) + else: + npt.assert_almost_equal( + ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal( + ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION + ) + npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc) @pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)