Skip to content

Commit

Permalink
Update to test_lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
arrrrrmin committed Mar 5, 2024
1 parent a4e6612 commit 3660821
Showing 1 changed file with 22 additions and 29 deletions.
51 changes: 22 additions & 29 deletions tests/test_generator/test_lengths.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import unittest
from typing import Dict, Any

from gutenTAG import GutenTAG


class TestLengths(unittest.TestCase):
def setUp(self) -> None:
self.seed = 42
self.config = {
"timeseries": [
{
Expand All @@ -15,7 +17,6 @@ def setUp(self) -> None:
{
"position": "end",
"length": 47,
# "creeping-length": 12, # remove to test the default 20% option
"channel": 0,
"kinds": [{"kind": "amplitude", "amplitude_factor": 2.0}],
}
Expand All @@ -25,35 +26,27 @@ def setUp(self) -> None:
}
self.creeping_factor = 0.3

def update_config(self, length: int, include_creeping: bool = False):
self.config["timeseries"][0]["anomalies"][0]["length"] = length
def _update_config(
self, length: int, include_creeping: bool = False
) -> Dict[str, Any]:
config = self.config.copy()
config["timeseries"][0]["anomalies"][0]["length"] = length
cl = int(round(length * self.creeping_factor)) if include_creeping else 0
self.config["timeseries"][0]["anomalies"][0]["creeping-length"] = cl
config["timeseries"][0]["anomalies"][0]["creeping-length"] = cl
return config

def __run(self):
res = False
gutentag = GutenTAG(seed=42)
gutentag.load_config_dict(self.config)
try:
gutentag.generate(return_timeseries=True)
res = True
except ValueError:
# For details see: https://github.com/TimeEval/GutenTAG/issues/49
# Catches ValueError:
# 'operands could not be broadcast together with shapes (n,) (n-1,)'
...
return res
def _run_and_validate(self, config: Dict[str, Any]) -> None:
gutentag = GutenTAG(seed=self.seed)
gutentag.load_config_dict(config)
ts = gutentag.generate(return_timeseries=True)
self.assertIsNotNone(ts)
self.assertEqual(len(ts), 1)

def test_lengths(self):
results = []
def test_lengths(self) -> None:
for i in range(25, 50):
self.update_config(i, include_creeping=False)
results.append(self.__run())
assert all(results)

def test_lengths_with_creeping(self):
results = []
for i in range(25, 50):
self.update_config(i, include_creeping=True)
results.append(self.__run())
assert all(results)
# Check case anomaly_protocol.creeping_length == 0
config = self._update_config(i, include_creeping=False)
self._run_and_validate(config)
# Check case anomaly_protocol.creeping_length > 0
config = self._update_config(i, include_creeping=True)
self._run_and_validate(config)

0 comments on commit 3660821

Please sign in to comment.