Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved plot_splits for time series splits #1113

Merged
merged 12 commits into from
Nov 8, 2024
Merged

Conversation

d-kleine
Copy link
Contributor

@d-kleine d-kleine commented Nov 6, 2024

Description

  • added title to data splits legend
  • adjusted xticklabels for Sample Index not to intersect with each other anymore
  • added legend for groups (setting the groups as labels on the xaxis would also be possible, but the samples index info would be gone then)

grafik

Related issues or pull requests

fixes #1094

Pull Request Checklist

  • Added a note about the modification or contribution to the ./docs/sources/CHANGELOG.md file (if applicable)
  • Added appropriate unit test functions in the ./mlxtend/*/tests directories (if applicable)
  • Modify documentation in the corresponding Jupyter Notebook under mlxtend/docs/sources/ (if applicable)
  • Ran PYTHONPATH='.' pytest ./mlxtend -sv and make sure that all unit tests pass (for small modifications, it might be sufficient to only run the specific test file, e.g., PYTHONPATH='.' pytest ./mlxtend/classifier/tests/test_stacking_cv_classifier.py -sv)
  • Checked for style issues by running flake8 ./mlxtend

@d-kleine d-kleine changed the title Improved plot_splits for time series splits Improved plot_splits for time series splits Nov 6, 2024
@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 6, 2024

Test code:

import pandas as pd
import numpy as np

from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor
from sklearn.metrics import root_mean_squared_error, make_scorer
from sklearn.model_selection import cross_val_score

from mlxtend.evaluate import GroupTimeSeriesSplit
from mlxtend.evaluate.time_series import plot_splits

X_test, y_test = [], []

start_year = 2010
end_year = 2020

for year in np.arange(start_year, end_year):
    X_year, y_year = make_regression(n_samples=5, n_features=2, bias=0, noise=1, random_state=year)
    X_year = pd.DataFrame(X_year).rename(columns={0:'X1', 1:'X2'})
    X_year['year'] = year
    y_year = pd.Series(y_year)
    X_test.append(X_year)
    y_test.append(y_year)

X, y = pd.concat(X_test), pd.concat(y_test)

# modelisation
model = DummyRegressor(strategy="mean")
metric = root_mean_squared_error
cv_args = {"test_size": 1, 'n_splits': len(np.unique(X['year'])) - 1, 'window_type': 'rolling'}
cv = GroupTimeSeriesSplit(**cv_args)

scores = cross_val_score(model, X, y, cv=cv, groups=X['year'], scoring=make_scorer(metric))

plot_splits(X, y, X['year'], **cv_args)

@d-kleine d-kleine marked this pull request as ready for review November 6, 2024 17:59
@d-kleine d-kleine marked this pull request as draft November 6, 2024 22:24
@d-kleine d-kleine marked this pull request as ready for review November 8, 2024 17:05
@d-kleine
Copy link
Contributor Author

d-kleine commented Nov 8, 2024

@rasbt What do you think about these improvements?

Copy link
Owner

@rasbt rasbt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks really really good, thanks for the PR

@rasbt rasbt merged commit ec40b75 into rasbt:master Nov 8, 2024
2 checks passed
@d-kleine d-kleine deleted the splits branch November 8, 2024 18:54
@d-kleine d-kleine restored the splits branch November 8, 2024 18:54
@d-kleine d-kleine deleted the splits branch November 8, 2024 18:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Enhance plot
2 participants