Skip to content

Commit

Permalink
Add second attempt for assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
NimaSarajpoor committed Dec 22, 2024
1 parent 2774302 commit 620dfdf
Showing 1 changed file with 63 additions and 15 deletions.
78 changes: 63 additions & 15 deletions tests/test_precision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import importlib
from unittest.mock import patch

import naive
Expand All @@ -8,7 +9,7 @@
from numba import cuda

import stumpy
from stumpy import config, core
from stumpy import cache, config, core

try:
from numba.errors import NumbaPerformanceWarning
Expand Down Expand Up @@ -146,20 +147,67 @@ def test_snippets():
cmp_regimes,
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)

npt.assert_almost_equal(
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION)
npt.assert_almost_equal(ref_regimes, cmp_regimes)
# Revise fastmath flag, recompile, and re-calculate snippets,
# and then revert the changes
config.STUMPY_FASTMATH_FLAGS = {"nsz", "arcp", "contract", "afn"}
core._calculate_squared_distance.targetoptions["fastmath"] = config.STUMPY_FASTMATH_FLAGS
njit_funcs = cache.get_njit_funcs()
for module_name, func_name in njit_funcs:
code = f"from stumpy.{module_name} import {func_name}; {func_name}.recompile()"
exec(code)
# module = importlib.import_module(f".{module_name}", package="stumpy")
# func = getattr(module, func_name)
# func.recompile()

(
cmp_snippets_NOreassoc,
cmp_indices_NOreassoc,
cmp_profiles_NOreassoc,
cmp_fractions_NOreassoc,
cmp_areas_NOreassoc,
cmp_regimes_NOreassoc,
) = stumpy.snippets(T, m, k, s=s, mpdist_T_subseq_isconstant=isconstant_custom_func)

config._reset("STUMPY_FASTMATH_FLAGS")
for module_name, func_name in njit_funcs:
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
func.recompile()

if np.allclose(ref_snippets, cmp_snippets):
npt.assert_almost_equal(
ref_snippets, cmp_snippets, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_indices, cmp_indices, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_profiles, cmp_profiles, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_fractions, cmp_fractions, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_areas, cmp_areas, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(ref_regimes, cmp_regimes)
else:
npt.assert_almost_equal(
ref_snippets, cmp_snippets_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_indices, cmp_indices_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_profiles, cmp_profiles_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_fractions, cmp_fractions_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(
ref_areas, cmp_areas_NOreassoc, decimal=config.STUMPY_TEST_PRECISION
)
npt.assert_almost_equal(ref_regimes, cmp_regimes_NOreassoc)


@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)
Expand Down

0 comments on commit 620dfdf

Please sign in to comment.