From d45e07a159dbd9edc4d2ca42038c782c02ce52ba Mon Sep 17 00:00:00 2001 From: christopherbunn Date: Tue, 23 Jan 2024 13:45:05 -0500 Subject: [PATCH] Added coverage for column labels --- .../decomposer_tests/test_stl_decomposer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/evalml/tests/component_tests/decomposer_tests/test_stl_decomposer.py b/evalml/tests/component_tests/decomposer_tests/test_stl_decomposer.py index 6e17067d59..b102dceb66 100644 --- a/evalml/tests/component_tests/decomposer_tests/test_stl_decomposer.py +++ b/evalml/tests/component_tests/decomposer_tests/test_stl_decomposer.py @@ -148,8 +148,15 @@ def test_stl_fit_transform_in_sample( stl = STLDecomposer(period=period) + series_id_columns = ["series_1", "series_2"] + if variateness == "multivariate": + y.columns = series_id_columns + X_t, y_t = stl.fit_transform(X, y) + if variateness == "multivariate": + assert all(y_t.columns == series_id_columns) + # If y_t is a pd.Series, give it columns if isinstance(y_t, pd.Series): y_t = y_t.to_frame() @@ -179,7 +186,11 @@ def test_stl_fit_transform_in_sample( # Check the trend to make sure STL worked properly pd.testing.assert_series_equal( pd.Series(expected_trend), - pd.Series(stl.trends[0]), + pd.Series( + stl.trends["series_1"] + if variateness == "multivariate" + else stl.trends[0], + ), check_exact=False, check_index=False, check_names=False,