A python library for decision tree visualization and model interpretation.
By Terence Parr and Prince Grover
See How to visualize decision trees for a deeper discussion of our decision tree visualization library and the visual design decisions we made.
Decision trees are the fundamental building block of gradient boosting machines and Random Forests(tm), probably the two most popular machine learning models for structured data. Visualizing decision trees is a tremendous aid when learning how these models work and when interpreting models. Unfortunately, current visualization packages are rudimentary and not immediately helpful to the novice. For example, we couldn't find a library that visualizes how decision nodes split up the feature space. It is also uncommon for libraries to support visualizing a specific feature vector as it weaves down through a tree's decision nodes; we could only find one image showing this.
So, we've created a general package for scikit-learn decision tree visualization and model interpretation, which we'll be using heavily in an upcoming machine learning book (written with Jeremy Howard).
The visualizations are inspired by an educational animation by R2D3: A visual introduction to machine learning. With dtreeviz, you can visualize how the feature space is split up at decision nodes, how the training samples get distributed in leaf nodes, and how the tree makes predictions for a specific observation. These operations are critical for understanding how classification or regression decision trees work. If you're not familiar with decision trees, check out fast.ai's Introduction to Machine Learning for Coders MOOC.
Install anaconda3 on your system.
To install (Python >=3.6 only), do this (from Anaconda Prompt on Windows!):
pip install dtreeviz
This should also pull in the graphviz Python library (>=0.9), which we use for platform-specific stuff.
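Before going further, it can help to confirm the two version requirements mentioned above from Python itself. The following is a minimal sketch, not part of dtreeviz; it assumes Python 3.8+ for `importlib.metadata` and that the graphviz Python library is installed under its PyPI distribution name `graphviz`. The helper names `version_tuple` and `check_requirements` are hypothetical.

```python
import re
import sys
from importlib import metadata  # standard library in Python 3.8+


def version_tuple(version):
    """Turn a version string like '0.20.1' into a comparable tuple (0, 20, 1)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)  # keep only leading digits, e.g. '0rc1' -> 0
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def check_requirements():
    """Return (python_ok, graphviz_version, graphviz_ok) for the install notes above."""
    python_ok = sys.version_info >= (3, 6)
    try:
        gv_version = metadata.version("graphviz")
    except metadata.PackageNotFoundError:
        gv_version = None  # the graphviz Python library is not installed
    graphviz_ok = gv_version is not None and version_tuple(gv_version) >= (0, 9)
    return python_ok, gv_version, graphviz_ok


if __name__ == "__main__":
    print(check_requirements())
```

This only checks the Python library; the dot binary itself still has to be installed separately for your platform.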
Please email Terence with any helpful notes on making dtreeviz work (better) on other platforms. Thanks!
For your specific platform, please see the following subsections.
Mac
You need the graphviz binary for dot installed with librsvg and pango. Make sure you reinstall or install like this:
brew install graphviz --with-librsvg --with-app --with-pango
(The --with-librsvg option is absolutely required because we generate output using dot's -Tsvg:cairo option.)
The OS X version can generate/save images in any format that dot supports with the -T{format}:cairo option, so .svg and .pdf are safe bets.
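Rendering a DOT file through the cairo renderer from Python boils down to building that command line. This is a hedged sketch, not how dtreeviz itself calls dot (it goes through the graphviz Python library); `cairo_cmd` and `render` are hypothetical helpers, and deriving the output filename from the input filename is an assumption.

```python
import shutil
import subprocess


def cairo_cmd(src, fmt="svg"):
    """Build a dot command that renders src as fmt using the cairo renderer."""
    stem = src.rsplit(".", 1)[0]  # 'tree.dot' -> 'tree'
    return ["dot", f"-T{fmt}:cairo", "-o", f"{stem}.{fmt}", src]


def render(src, fmt="svg"):
    """Run dot on src, failing early if the dot binary is not on the PATH."""
    if shutil.which("dot") is None:
        raise RuntimeError("dot not found; check your Graphviz install and PATH")
    subprocess.run(cairo_cmd(src, fmt), check=True)


if __name__ == "__main__":
    print(cairo_cmd("tree.dot", "pdf"))
```

On this platform both "svg" and "pdf" should work as fmt values.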
Limitations. Jupyter notebook has a bug where it does not show .svg files correctly, but Jupyter Lab has no problem.
Linux
To get the dot binary, do:
sudo apt install graphviz
Limitations. The view() method works to pop up a new window, and images appear inline in Jupyter notebook but not in Jupyter Lab (it gets an error parsing the SVG XML). The notebook images also substitute another font for the Arial we use, so some text overlaps. Only .svg files can be generated on this platform.
Windows
Download graphviz-2.38.msi and update your Path environment variable. It's Windows, so you might need a reboot after updating that environment variable. You should see this from the Anaconda Prompt:
(base) C:\Users\Terence Parr>where dot
C:\Program Files (x86)\Graphviz2.38\bin\dot.exe
(Do not use conda install -c conda-forge python-graphviz, as you get an old version of the graphviz Python library.)
Verify from the Anaconda Prompt that this works:
dot -V
If it doesn't work, you have a Path problem. I found the following test programs useful. The first one checks whether Python can find dot:
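import os
import subprocess
# Print the Path that Python sees; the Graphviz bin directory should be in it
print(os.getenv('Path'))
# Try to launch dot; this raises FileNotFoundError if Python cannot find it
proc = subprocess.Popen(['dot', '-V'])
proc.wait()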
The following version does the same thing except it uses the graphviz Python library's backend support utilities, which is what we use in dtreeviz:
import graphviz.backend as be
cmd = ["dot", "-V"]
stdout, stderr = be.run(cmd, capture_output=True, check=True, quiet=False)
print( stderr )
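If you prefer not to spawn a process at all, the standard library can answer the same question. A minimal sketch, assuming only that `shutil.which` searches the same PATH the Anaconda Prompt uses; `find_dot` and `path_entries` are hypothetical names.

```python
import os
import shutil


def find_dot():
    """Return the full path to the dot executable, or None if it is not on the PATH."""
    return shutil.which("dot")


def path_entries():
    """List the directories searched for executables (os.environ is case-insensitive on Windows)."""
    return [p for p in os.environ.get("PATH", "").split(os.pathsep) if p]


if __name__ == "__main__":
    location = find_dot()
    if location is None:
        print("dot not found; Path contains:")
        for entry in path_entries():
            print("   ", entry)
    else:
        print("dot found at", location)
```

On a working Windows install, the reported location should match the where dot output shown earlier.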
Jupyter Lab and Jupyter notebook both show the inline .svg images well.
Limitations. Finally, don't use IE to view .svg files. Use Edge, as they look much better there. I suspect that IE displays them as rasterized rather than vector images. Only .svg files can be generated on this platform.
dtreeviz: Main function to create a decision tree visualization. Given a decision tree regressor or classifier, it creates and returns a tree visualization using the graphviz (DOT) language.
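To make "using the graphviz (DOT) language" concrete, here is a toy sketch that turns a tiny hand-built tree into DOT text. It is not dtreeviz's implementation, which draws histograms and scatterplots inside the nodes. The nested-dict tree format, the `to_dot` name, the hypothetical `TOY_TREE`, and the mapping from an orientation argument like "TD"/"LR" to graphviz's `rankdir` attribute (whose top-down value is "TB") are all assumptions for illustration.

```python
from itertools import count

# Assumed mapping from a dtreeviz-style orientation to graphviz's rankdir attribute
RANKDIR = {"TD": "TB", "LR": "LR"}

# Hypothetical tree: one split on feature x1, two leaves
TOY_TREE = {"feature": "x1", "threshold": 0.5,
            "left": {"value": "A"}, "right": {"value": "B"}}


def to_dot(tree, orientation="TD"):
    """Emit DOT source for a nested-dict tree; leaves are dicts with a 'value' key."""
    ids = count()
    lines = ["digraph tree {", f"    rankdir={RANKDIR[orientation]};"]

    def walk(node):
        nid = next(ids)  # number nodes in depth-first order
        if "value" in node:
            lines.append(f'    n{nid} [shape=box, label="{node["value"]}"];')
        else:
            lines.append(f'    n{nid} [label="{node["feature"]} <= {node["threshold"]}"];')
            left = walk(node["left"])
            right = walk(node["right"])
            lines.append(f"    n{nid} -> n{left};")
            lines.append(f"    n{nid} -> n{right};")
        return nid

    walk(tree)
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(to_dot(TOY_TREE, orientation="LR"))
```

Feeding the printed text to dot (or to graphviz.Source) would draw the two-leaf tree left to right.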
- Required libraries:
Basic libraries and imports that may be needed to generate the sample visualizations shown in the examples below.
import numpy as np  # needed for the random sample in the prediction path example
from sklearn.datasets import *
from sklearn import tree
from dtreeviz.trees import *
- Regression decision tree:
The default orientation of the tree is top down, but you can change it to left to right using orientation="LR". view() gives a pop-up window with the rendered graphviz object.
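regr = tree.DecisionTreeRegressor(max_depth=2)  # limit depth of tree
boston = load_boston()  # note: load_boston was removed in scikit-learn 1.2, so this needs an older version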
regr.fit(boston.data, boston.target)
viz = dtreeviz(regr,
boston.data,
boston.target,
target_name='price',
feature_names=boston.feature_names)
viz.view()
- Classification decision tree:
An additional argument, class_names, giving a mapping of class values to class names, is required for classification trees.
classifier = tree.DecisionTreeClassifier(max_depth=2) # limit depth of tree
iris = load_iris()
classifier.fit(iris.data, iris.target)
viz = dtreeviz(classifier,
iris.data,
iris.target,
target_name='variety',
feature_names=iris.feature_names,
class_names=["setosa", "versicolor", "virginica"] # need class_names for classifier
)
viz.view()
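The class_names list is positional: entry i names target value i, which matches how load_iris encodes its targets as 0, 1, 2. As a sketch of what a classifier leaf summarizes under that assumption, the hypothetical helper below counts the training samples that reach a leaf and labels each count with its class name.

```python
from collections import Counter


def leaf_class_counts(leaf_targets, class_names):
    """Count samples per class in a leaf, keyed by name (class_names[i] names target i)."""
    counts = Counter(leaf_targets)
    # Include every class, even those with no samples in this leaf
    return {name: counts.get(i, 0) for i, name in enumerate(class_names)}


if __name__ == "__main__":
    # Hypothetical targets of the samples that landed in one leaf
    print(leaf_class_counts([0, 0, 2], ["setosa", "versicolor", "virginica"]))
    # prints {'setosa': 2, 'versicolor': 0, 'virginica': 1}
```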
- Prediction path:
Highlights the decision nodes in which the feature values of a single observation, passed in argument X, fall. It shows the observation's feature values and highlights the features the tree uses to traverse the path.
regr = tree.DecisionTreeRegressor(max_depth=2) # limit depth of tree
diabetes = load_diabetes()
regr.fit(diabetes.data, diabetes.target)
X = diabetes.data[np.random.randint(0, len(diabetes.data)),:] # random sample from training
viz = dtreeviz(regr,
diabetes.data,
diabetes.target,
target_name='value',
orientation='LR',  # left-right orientation
feature_names=diabetes.feature_names,
X=X) # need to give single observation for prediction
viz.view()
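The traversal that the prediction path highlights can be sketched directly against a fitted scikit-learn tree's low-level arrays (`tree_.children_left`, `tree_.children_right`, `tree_.feature`, `tree_.threshold`), where a child index of -1 marks a leaf and a sample goes left when `x[feature] <= threshold`. This illustrates the idea only and is not how dtreeviz computes its highlighting; `prediction_path` is a hypothetical name and the small arrays in the example are made up.

```python
def prediction_path(children_left, children_right, feature, threshold, x):
    """Follow observation x from the root to a leaf, recording each decision."""
    node = 0  # the root is node 0
    steps = []
    while children_left[node] != -1:  # -1 means the node has no children (a leaf)
        f, t = feature[node], threshold[node]
        go_left = x[f] <= t
        steps.append((node, f, t, "left" if go_left else "right"))
        node = children_left[node] if go_left else children_right[node]
    return steps, node


if __name__ == "__main__":
    # Hypothetical 5-node tree: root splits feature 2 at 0.0, node 2 splits feature 0 at 1.5
    left = [1, -1, 3, -1, -1]
    right = [2, -1, 4, -1, -1]
    feat = [2, -2, 0, -2, -2]
    thr = [0.0, -2.0, 1.5, -2.0, -2.0]
    print(prediction_path(left, right, feat, thr, [2.0, 9.9, 0.3]))
    # prints ([(0, 2, 0.0, 'right'), (2, 0, 1.5, 'right')], 4)
```

With the regression tree above, the same call would pass the four arrays from regr.tree_ along with X. If you only need the visited node IDs, scikit-learn's decision_path method returns them as a sparse indicator matrix.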
- Decision tree without scatterplot or histograms for decision nodes:
Simple tree without histograms or scatterplots for decision nodes. Use argument fancy=False.
classifier = tree.DecisionTreeClassifier(max_depth=4) # limit depth of tree
cancer = load_breast_cancer()
classifier.fit(cancer.data, cancer.target)
viz = dtreeviz(classifier,
cancer.data,
cancer.target,
target_name='cancer',
feature_names=cancer.feature_names,
class_names=["malignant", "benign"],
fancy=False)  # fancy=False removes histograms/scatterplots from decision nodes
viz.view()
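When fancy is left at its default, each decision node plots how the training values of its split feature fall on either side of the threshold, broken down by class. A rough sketch of the grouping such a plot summarizes, assuming the same go-left-if-less-or-equal rule scikit-learn uses; `split_summary` is hypothetical and dtreeviz's own plotting code is not reproduced here.

```python
from collections import defaultdict


def split_summary(values, targets, threshold):
    """Group one feature's values by side of the split, then by class."""
    sides = {"left": defaultdict(list), "right": defaultdict(list)}
    for v, y in zip(values, targets):
        side = "left" if v <= threshold else "right"
        sides[side][y].append(v)
    # Convert to plain dicts so the result prints cleanly
    return {side: dict(groups) for side, groups in sides.items()}


if __name__ == "__main__":
    # Hypothetical feature values and class labels at one decision node
    print(split_summary([1.0, 2.0, 3.0, 4.0], [0, 0, 1, 1], threshold=2.5))
```

A histogram of each side's values per class gives roughly the picture a fancy node draws.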
For more examples and different implementations, please see the Jupyter notebook full of examples.
Make sure to follow the install guidelines above.
To push the dtreeviz library to your local egg cache (force updates) during development, do this (from Anaconda Prompt on Windows):
python setup.py install -f
E.g., on Terence's box, it adds /Users/parrt/anaconda3/lib/python3.6/site-packages/dtreeviz-0.2-py3.6.egg.