Skip to content

Commit

Permalink
Merge branch 'avc' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
opcode81 committed Jul 20, 2022
2 parents 083ccad + 28fac86 commit 511d9ff
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/sensai/sklearn/sklearn_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ def _predictSkLearn(self, inputs: pd.DataFrame) -> pd.DataFrame:
results[varName] = self.models[varName].predict(inputs)
return pd.DataFrame(results)

def getSkLearnModel(self, predictedVarName=None):
if predictedVarName is None:
if len(self.models) > 1:
raise ValueError(f"Must provide predicted variable name (one of {self.models.keys()})")
return next(iter(self.models.values()))
return self.models[predictedVarName]


class AbstractSkLearnMultiDimVectorRegressionModel(AbstractSkLearnVectorRegressionModel, ABC):
"""
Expand Down
25 changes: 25 additions & 0 deletions src/sensai/sklearn/sklearn_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sklearn.neighbors
import sklearn.neural_network
import sklearn.svm
from matplotlib import pyplot as plt

from .sklearn_base import AbstractSkLearnMultipleOneDimVectorRegressionModel, AbstractSkLearnMultiDimVectorRegressionModel, \
FeatureImportanceProviderSkLearnRegressionMultipleOneDim, FeatureImportanceProviderSkLearnRegressionMultiDim
Expand Down Expand Up @@ -109,3 +110,27 @@ class SkLearnDummyVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegre
def __init__(self, strategy='mean', constant=None, quantile=None):
super().__init__(sklearn.dummy.DummyRegressor,
strategy=strategy, constant=constant, quantile=quantile)


class SkLearnDecisionTreeVectorRegressionModel(AbstractSkLearnMultipleOneDimVectorRegressionModel):
def __init__(self, random_state=42, **modelArgs):
super().__init__(sklearn.tree.DecisionTreeRegressor, random_state=random_state, **modelArgs)

def plot(self, predictedVarName=None, figsize=None) -> plt.Figure:
model = self.getSkLearnModel(predictedVarName)
fig = plt.figure(figsize=figsize)
sklearn.tree.plot_tree(model, feature_names=self.getModelInputVariableNames())
return fig

def plotGraphvizPDF(self, dotPath, predictedVarName=None):
"""
:param path: the path to a .dot file that will be created, alongside which a rendered PDF file (with added suffix ".pdf")
will be placed
:param predictedVarName: the predicted variable name for which to plot the model (if multiple; None is admissible if
there is only one predicted variable)
"""
import graphviz
dot = sklearn.tree.export_graphviz(self.getSkLearnModel(predictedVarName), out_file=None,
feature_names=self.getModelInputVariableNames(), filled=True)
graphviz.Source(dot).render(dotPath)

0 comments on commit 511d9ff

Please sign in to comment.