Skip to content

Commit

Permalink
Additional custom input tests (#33)
Browse files Browse the repository at this point in the history
* fix: CI configuration 'build'

* feat: improve custom input tests
  • Loading branch information
CodeLionX authored Jul 16, 2023
1 parent 9ebc74e commit 958ba10
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ jobs:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs:
- test
- typecheck
- check
runs-on: ubuntu-latest

steps:
Expand All @@ -117,7 +117,7 @@ jobs:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs:
- test
- typecheck
- check
runs-on: ubuntu-latest

steps:
Expand Down
6 changes: 6 additions & 0 deletions tests/custom_input_ts/data_types_timeseries.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
ints,strings,bools,floats,nulls
1,a,true,0.187,
2,b,false,4.123,
3,c,false,-7896.1,
4,doc,false,325.1,
-1,ester,true,897.123,
51 changes: 37 additions & 14 deletions tests/test_base_oscillations/test_custom_input.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
from pathlib import Path
import warnings

import pandas as pd
import numpy as np
import os
from numpy.random import SeedSequence
from numpy.testing import assert_array_equal

Expand All @@ -16,6 +16,9 @@ def setUp(self) -> None:
self.ctx = GenerationContext(SeedSequence(42)).to_bo()
self.input_path1 = Path("tests/custom_input_ts/dummy_timeseries.csv")
self.input_path2 = Path("tests/custom_input_ts/dummy_timeseries_2.csv")
self.input_path_datatypes = Path(
"tests/custom_input_ts/data_types_timeseries.csv"
)
self.column_idx = 1
self.length = 100
self.expected_test = (
Expand Down Expand Up @@ -140,19 +143,39 @@ def test_input_too_short(self):
)
self.assertRegex(str(e.exception).lower(), "less than the desired length")

def test_integer_conversion(self):
df = pd.DataFrame({"data": [1, 2, 3, 4, 5]})
df.to_csv("test_data.csv", index=False)
custom_input = CustomInput("test_data.csv")
# test if warning is raised
with self.assertWarns(UserWarning):
# Generate the time series data
timeseries = custom_input.generate_only_base(
def test_read_floats(self):
for tpe in ["floats", "nulls"]:
# a warning will raise an error!
with warnings.catch_warnings():
timeseries = CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test=self.input_path_datatypes,
use_column_test=tpe,
)
# data is properly converted
self.assertEqual(timeseries.dtype, np.float_)

def test_convert_to_float_with_warning(self):
for tpe in ["ints", "bools"]:
with self.assertWarns(UserWarning) as w:
timeseries = CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test=self.input_path_datatypes,
use_column_test=tpe,
)
# warning is raised
self.assertRegex(str(w.warning), "automatically converted to float")
# data is properly converted
self.assertEqual(timeseries.dtype, np.float_)

def test_error_on_string_type(self):
with self.assertRaises(ValueError) as e:
CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test="test_data.csv",
use_column_test="data",
input_timeseries_path_test=self.input_path_datatypes,
use_column_test="strings",
)
# test if data is properly converted
self.assertEqual(timeseries.dtype, np.float64)
os.remove("test_data.csv")
self.assertRegex(str(e.exception), "could not convert string to float")

0 comments on commit 958ba10

Please sign in to comment.