Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional custom input tests #33

Merged
merged 2 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ jobs:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs:
- test
- typecheck
- check
runs-on: ubuntu-latest

steps:
Expand All @@ -117,7 +117,7 @@ jobs:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs:
- test
- typecheck
- check
runs-on: ubuntu-latest

steps:
Expand Down
6 changes: 6 additions & 0 deletions tests/custom_input_ts/data_types_timeseries.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
ints,strings,bools,floats,nulls
1,a,true,0.187,
2,b,false,4.123,
3,c,false,-7896.1,
4,doc,false,325.1,
-1,ester,true,897.123,
51 changes: 37 additions & 14 deletions tests/test_base_oscillations/test_custom_input.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import unittest
from pathlib import Path
import warnings

import pandas as pd
import numpy as np
import os
from numpy.random import SeedSequence
from numpy.testing import assert_array_equal

Expand All @@ -16,6 +16,9 @@ def setUp(self) -> None:
self.ctx = GenerationContext(SeedSequence(42)).to_bo()
self.input_path1 = Path("tests/custom_input_ts/dummy_timeseries.csv")
self.input_path2 = Path("tests/custom_input_ts/dummy_timeseries_2.csv")
self.input_path_datatypes = Path(
"tests/custom_input_ts/data_types_timeseries.csv"
)
self.column_idx = 1
self.length = 100
self.expected_test = (
Expand Down Expand Up @@ -140,19 +143,39 @@ def test_input_too_short(self):
)
self.assertRegex(str(e.exception).lower(), "less than the desired length")

def test_integer_conversion(self):
df = pd.DataFrame({"data": [1, 2, 3, 4, 5]})
df.to_csv("test_data.csv", index=False)
custom_input = CustomInput("test_data.csv")
# test if warning is raised
with self.assertWarns(UserWarning):
# Generate the time series data
timeseries = custom_input.generate_only_base(
def test_read_floats(self):
for tpe in ["floats", "nulls"]:
# a warning will raise an error!
with warnings.catch_warnings():
timeseries = CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test=self.input_path_datatypes,
use_column_test=tpe,
)
# data is properly converted
self.assertEqual(timeseries.dtype, np.float_)

def test_convert_to_float_with_warning(self):
for tpe in ["ints", "bools"]:
with self.assertWarns(UserWarning) as w:
timeseries = CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test=self.input_path_datatypes,
use_column_test=tpe,
)
# warning is raised
self.assertRegex(str(w.warning), "automatically converted to float")
# data is properly converted
self.assertEqual(timeseries.dtype, np.float_)

def test_error_on_string_type(self):
with self.assertRaises(ValueError) as e:
CustomInput().generate_only_base(
ctx=self.ctx,
length=5,
input_timeseries_path_test="test_data.csv",
use_column_test="data",
input_timeseries_path_test=self.input_path_datatypes,
use_column_test="strings",
)
# test if data is properly converted
self.assertEqual(timeseries.dtype, np.float64)
os.remove("test_data.csv")
self.assertRegex(str(e.exception), "could not convert string to float")
Loading