From 3660821d05fae20b0ea251768d86a18110674e3c Mon Sep 17 00:00:00 2001 From: arrrrrmin Date: Tue, 5 Mar 2024 11:48:15 +0100 Subject: [PATCH] Update to test_lengths --- tests/test_generator/test_lengths.py | 51 ++++++++++++---------------- 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/tests/test_generator/test_lengths.py b/tests/test_generator/test_lengths.py index fe4ea02..8cf0aad 100644 --- a/tests/test_generator/test_lengths.py +++ b/tests/test_generator/test_lengths.py @@ -1,10 +1,12 @@ import unittest +from typing import Dict, Any from gutenTAG import GutenTAG class TestLengths(unittest.TestCase): def setUp(self) -> None: + self.seed = 42 self.config = { "timeseries": [ { @@ -15,7 +17,6 @@ def setUp(self) -> None: { "position": "end", "length": 47, - # "creeping-length": 12, # remove to test the default 20% option "channel": 0, "kinds": [{"kind": "amplitude", "amplitude_factor": 2.0}], } @@ -25,35 +26,27 @@ def setUp(self) -> None: } self.creeping_factor = 0.3 - def update_config(self, length: int, include_creeping: bool = False): - self.config["timeseries"][0]["anomalies"][0]["length"] = length + def _update_config( + self, length: int, include_creeping: bool = False + ) -> Dict[str, Any]: + config = self.config.copy() + config["timeseries"][0]["anomalies"][0]["length"] = length cl = int(round(length * self.creeping_factor)) if include_creeping else 0 - self.config["timeseries"][0]["anomalies"][0]["creeping-length"] = cl + config["timeseries"][0]["anomalies"][0]["creeping-length"] = cl + return config - def __run(self): - res = False - gutentag = GutenTAG(seed=42) - gutentag.load_config_dict(self.config) - try: - gutentag.generate(return_timeseries=True) - res = True - except ValueError: - # For details see: https://github.com/TimeEval/GutenTAG/issues/49 - # Catches ValueError: - # 'operands could not be broadcast together with shapes (n,) (n-1,)' - ... - return res + def _run_and_validate(self, config: Dict[str, Any]) -> None: + gutentag = GutenTAG(seed=self.seed) + gutentag.load_config_dict(config) + ts = gutentag.generate(return_timeseries=True) + self.assertIsNotNone(ts) + self.assertEqual(len(ts), 1) - def test_lengths(self): - results = [] + def test_lengths(self) -> None: for i in range(25, 50): - self.update_config(i, include_creeping=False) - results.append(self.__run()) - assert all(results) - - def test_lengths_with_creeping(self): - results = [] - for i in range(25, 50): - self.update_config(i, include_creeping=True) - results.append(self.__run()) - assert all(results) + # Check case anomaly_protocol.creeping_length == 0 + config = self._update_config(i, include_creeping=False) + self._run_and_validate(config) + # Check case anomaly_protocol.creeping_length > 0 + config = self._update_config(i, include_creeping=True) + self._run_and_validate(config)