Skip to content

Commit

Permalink
Refactor unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
menouar committed Jul 25, 2023
1 parent c66395f commit fef0b89
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 5 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion eventdetector_ts/models/models_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from eventdetector_ts import LSTM, GRU, CNN, RNN_BIDIRECTIONAL, RNN_ENCODER_DECODER, CNN_RNN, FFN, CONV_LSTM1D, \
SELF_ATTENTION, MODELS_DIR
from eventdetector_ts.models import logger_models
from eventdetector_ts.models.helpers import SelfAttention
from eventdetector_ts.models.helpers_models import SelfAttention


class DtwLoss(tf.keras.losses.Loss):
Expand Down
2 changes: 1 addition & 1 deletion eventdetector_ts/models/models_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
META_MODEL_SCALER
from eventdetector_ts.metamodel.utils import DataSplitter
from eventdetector_ts.models import logger_models
from eventdetector_ts.models.helpers import CustomEarlyStopping, custom_cross_val_score
from eventdetector_ts.models.helpers_models import CustomEarlyStopping, custom_cross_val_score
from eventdetector_ts.models.models_builder import ModelBuilder


Expand Down
37 changes: 35 additions & 2 deletions tests/data/test_helpers_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from eventdetector_ts import TimeUnit
from eventdetector_ts.data.helpers_data import overlapping_partitions, compute_middle_event, \
num_columns, convert_dataframe_to_overlapping_partitions, get_timedelta
num_columns, convert_dataframe_to_overlapping_partitions, get_timedelta, get_total_units


def test_overlapping_partitions():
Expand Down Expand Up @@ -155,7 +155,40 @@ def test_year(self):

def test_invalid_unit(self):
with self.assertRaises(ValueError):
get_timedelta(10, "invalid_unit")
get_timedelta(10, None)

def test_microsecond_(self):
td = timedelta(microseconds=123456789)
self.assertEqual(get_total_units(td, TimeUnit.MICROSECOND), 123456789)

def test_millisecond_(self):
td = timedelta(milliseconds=123456)
self.assertEqual(get_total_units(td, TimeUnit.MILLISECOND), 123456)

def test_second_(self):
td = timedelta(seconds=123)
self.assertEqual(get_total_units(td, TimeUnit.SECOND), 123)

def test_minute_(self):
td = timedelta(minutes=2)
self.assertEqual(get_total_units(td, TimeUnit.MINUTE), 2)

def test_hour_(self):
td = timedelta(hours=1)
self.assertEqual(get_total_units(td, TimeUnit.HOUR), 1)

def test_day_(self):
td = timedelta(days=3)
self.assertEqual(get_total_units(td, TimeUnit.DAY), 3)

def test_year_(self):
td = timedelta(days=365.25)
self.assertAlmostEqual(get_total_units(td, TimeUnit.YEAR), 1.0, places=2)

def test_invalid_unit_(self):
td = timedelta(seconds=123)
with self.assertRaises(ValueError):
get_total_units(td, "invalid_unit")


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import tensorflow as tf

from eventdetector_ts.models.helpers import CustomEarlyStopping
from eventdetector_ts.models.helpers_models import CustomEarlyStopping


class TestHelpers(unittest.TestCase):
Expand Down

0 comments on commit fef0b89

Please sign in to comment.