diff --git a/src/sensai/sklearn/sklearn_base.py b/src/sensai/sklearn/sklearn_base.py index c7ef911d..2d773daf 100644 --- a/src/sensai/sklearn/sklearn_base.py +++ b/src/sensai/sklearn/sklearn_base.py @@ -161,6 +161,13 @@ def _predictSkLearn(self, inputs: pd.DataFrame) -> pd.DataFrame: results[varName] = self.models[varName].predict(inputs) return pd.DataFrame(results) + def getSkLearnModel(self, predictedVarName=None): + if predictedVarName is None: + if len(self.models) > 1: + raise ValueError(f"Must provide predicted variable name (one of {self.models.keys()})") + return next(iter(self.models.values())) + return self.models[predictedVarName] + class AbstractSkLearnMultiDimVectorRegressionModel(AbstractSkLearnVectorRegressionModel, ABC): """ diff --git a/src/sensai/sklearn/sklearn_regression.py b/src/sensai/sklearn/sklearn_regression.py index 23c4b5e4..9d57efb6 100644 --- a/src/sensai/sklearn/sklearn_regression.py +++ b/src/sensai/sklearn/sklearn_regression.py @@ -6,6 +6,7 @@ import sklearn.neighbors import sklearn.neural_network import sklearn.svm +from matplotlib import pyplot as plt from .sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnMultiDimVectorRegressionModel, \ FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnRegressionMultiDim @@ -109,3 +110,27 @@ class SkLearnDummyVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegre def __init__(self, strategy='mean', constant=None, quantile=None): super().__init__(sklearn.dummy.DummyRegressor, strategy=strategy, constant=constant, quantile=quantile) + + +class SkLearnDecisionTreeVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel): + def __init__(self, random_state=42, **modelArgs): + super().__init__(sklearn.tree.DecisionTreeRegressor, random_state=random_state, **modelArgs) + + def plot(self, predictedVarName=None, figsize=None) -> plt.Figure: + model = self.getSkLearnModel(predictedVarName) + fig = 
plt.figure(figsize=figsize) + sklearn.tree.plot_tree(model, feature_names=self.getModelInputVariableNames()) + return fig + + def plotGraphvizPDF(self, dotPath, predictedVarName=None): + """ + :param dotPath: the path to a .dot file that will be created, alongside which a rendered PDF file (with added suffix ".pdf") + will be placed + :param predictedVarName: the predicted variable name for which to plot the model (if multiple; None is admissible if + there is only one predicted variable) + """ + import graphviz + dot = sklearn.tree.export_graphviz(self.getSkLearnModel(predictedVarName), out_file=None, + feature_names=self.getModelInputVariableNames(), filled=True) + graphviz.Source(dot).render(dotPath) +