Design
The R GBM package allows users to train gradient boosting machines, primarily for the purposes of prediction and feature selection. The package is built using Rcpp and is composed of two parts: an R front-end that takes user inputs and performs pre-processing before passing a series of arguments to the second part, the C++ back-end. The main algorithmic work is done in the C++ layer, which on completion passes R objects back to the R front-end for post-processing and reporting. This design document focuses on the C++ back-end, providing descriptions of its scope, functionality and structure; for documentation on the R API please see the R GBM package vignettes.
The C++ back-end, from here on referred to as the "system", provides algorithms to train a GBM model, to make predictions on new data using such a fitted model, and to calculate the marginal effects of variables within a fitted model by "integrating out" the other variables. The bulk of the C++ codebase is devoted to training a GBM model; the tasks involved in this process include setting up datasets, calculating model residuals, fitting a tree to those residuals and determining the optimal node predictions for the fitted tree. The other pieces of functionality, that is prediction and the calculation of marginal effects, are single functions with few or no developer-designed classes associated with them.
The structure imposed on the C++ back-end was selected to:
- Replace a legacy setup with one more amenable to change.
- Make the structure and functionality of the system more transparent via appropriate OO design and encapsulation.
- Implement design patterns where appropriate, in particular to simplify memory management through the application of the RAII idiom.
Within the GBM package the system performs the main algorithmic calculations. It undertakes all of the algorithmic work in training a gradient boosted model, making predictions on new covariate datasets and calculating the marginal effects of covariates given a fitted model. It does not perform any formatting or post-processing beyond converting the fitted trees, errors etc. to an R list object and passing it back to the front-end.
The description of the system will focus on its behaviour and its structure, which are captured by the Functional and Development views respectively. Within the package the system is located in the /src directory. To conclude this introduction, it should be noted that the system has been developed in line with the Google C++ style guide.
The main use cases of the system are:
- Training a gradient boosted model using data and an appropriate statistical model.
- Predicting the response from new covariates using a trained model.
- Calculating marginal effects of covariates using a trained model.
The system is composed of numerous classes used almost exclusively for training, along with an interface to the R front-end. This entry point is appropriately named gbmentry.cpp and contains three functions: gbm, gbm_pred and gbm_plot. These functions perform the training, prediction and marginal effects calculations described in Section 2.1 respectively. The functionality and dynamic behaviour of these methods will now be described.
Figure 1. Component diagram showing the interface between the R layer and the system.
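For orientation, a hedged sketch of what these entry points might look like follows; the parameter lists are illustrative assumptions rather than the actual interface (see gbmentry.cpp for the real signatures).

```cpp
// Hypothetical sketch of the .Call entry points in gbmentry.cpp.
// The parameter lists are assumptions based on the descriptions below.
#include <Rcpp.h>

extern "C" {

// Trains a GBM model from the R-supplied data and configuration.
SEXP gbm(SEXP response, SEXP covariates, SEXP distribution, SEXP num_trees,
         SEXP tree_params, SEXP prev_estimate);

// Predicts the response for new covariates using a fitted model.
SEXP gbm_pred(SEXP covariates, SEXP num_trees, SEXP initial_estimate,
              SEXP fitted_trees, SEXP split_categories, SEXP var_types,
              SEXP single_tree);

// Computes marginal effects of selected variables from a fitted model.
SEXP gbm_plot(SEXP covariates, SEXP vars_of_interest, SEXP num_trees,
              SEXP initial_estimate, SEXP fitted_trees, SEXP var_types);

}
```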
The procedure to train a gbm model follows a sequence of simple steps as shown in Figure 2. Upon entry to the gbm method the user-specified parameters are converted to appropriate GBM configuration objects, see datadistparams.h and treeparams.h. These configuration objects are then used to initialize the GBM engine; this component stores both the tree parameters, which define the tree to be fitted, and the data/distribution container. Upon initialization the GBM engine will also initialize an instance of an appropriate dataset class, a bagged data class and a distribution object from the GBM configuration object. The dataset object contains both the training and validation data; these sets live together in vector containers, and to swap between them the pointer to the first element of the set of interest is shifted by the appropriate amount. The bagged data class defines which elements of the dataset are used in growing an individual tree; how these elements are chosen is distribution dependent. After this, the GbmFit object is initialized. This object contains the current fitted learner, the errors associated with the fit and the function estimate (an Rcpp numeric vector). The function estimate is set to the previous estimates if provided, or to an appropriate value for the selected distribution otherwise.
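A hedged sketch of the fit object's contents is given below; the member names follow the description above, and the actual members of GbmFit in gbm_fit.h may differ.

```cpp
// Illustrative stand-in for the GbmFit object described above.
#include <memory>
#include <vector>
#include <Rcpp.h>

class FittedLearnerSketch {};  // stub standing in for the class in FittedLearner.h

class GbmFitSketch {
 public:
  explicit GbmFitSketch(Rcpp::NumericVector initial_estimate)
      : func_estimate_(initial_estimate) {}

 private:
  std::unique_ptr<FittedLearnerSketch> current_fit_;         // the "current fit"
  std::vector<double> training_errors_, validation_errors_;  // errors per iteration
  Rcpp::NumericVector func_estimate_;  // function estimate (an Rcpp numeric vector)
};
```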
With the GBM engine and the fit object initialized, the algorithm loops over the number of trees to fit and performs the following actions:
- Create a tree object using the stored tree parameters.
- Bag the data to use for fitting the tree.
- Calculate the residuals using the current estimate, bagged data and distribution object.
- Fit the tree to the calculated residuals - this is done by generating all possible splits for a subset of variables and choosing the split that minimizes the variance of the residuals within nodes. The best splits for each node are cached, so if a node is not split on this round, no future splitting of that node is evaluated.
- Set the terminal node predictions which minimize the error.
- Adjust the tree, so non-terminal nodes also have predictions.
- Update the estimate, for both training and validation sets, and calculate the training error, the validation error and the improvement in the out-of-bag error from the update.
- Wrap up the fitted tree and errors, along with a reference to the dataset, in a FittedLearner object (see FittedLearner.h).
The returned fitted learner object defines the "current fit" in the GbmFit object, which is used to update the errors. At the end of each iteration, the tree and its associated errors are converted to an R List representation, which is output to the R layer once all the trees have been fitted. A toy illustration of this iteration is given below Figure 2.
Figure 2. Activity diagram for the process of training a GBM model.
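To make the iteration concrete, the following is a toy, self-contained illustration of the boosting loop for the Gaussian case: the residuals are y - f, a single-split "tree" (a stump) is fitted to them by minimizing the within-node variance of the residuals, and the estimate is updated. This is a didactic sketch only, not the package's implementation; it omits bagging, validation error tracking and the distribution abstraction.

```cpp
// Toy boosting loop: residuals -> fit stump -> update estimate.
#include <cstddef>
#include <iostream>
#include <vector>

struct Stump {  // a depth-one tree: one split on a single covariate
  double threshold;
  double left_pred, right_pred;
  double Predict(double x) const { return x < threshold ? left_pred : right_pred; }
};

// Fit a stump to the residuals by scanning candidate thresholds and taking
// the mean residual on each side (the variance-minimizing node prediction).
Stump FitStump(const std::vector<double>& x, const std::vector<double>& r) {
  Stump best{0.0, 0.0, 0.0};
  double best_sse = 1e300;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double thr = x[i];
    double left_sum = 0.0, right_sum = 0.0;
    std::size_t left_n = 0, right_n = 0;
    for (std::size_t j = 0; j < x.size(); ++j) {
      if (x[j] < thr) { left_sum += r[j]; ++left_n; }
      else            { right_sum += r[j]; ++right_n; }
    }
    if (left_n == 0 || right_n == 0) continue;  // no empty nodes
    const double lm = left_sum / left_n, rm = right_sum / right_n;
    double sse = 0.0;
    for (std::size_t j = 0; j < x.size(); ++j) {
      const double p = x[j] < thr ? lm : rm;
      sse += (r[j] - p) * (r[j] - p);
    }
    if (sse < best_sse) { best_sse = sse; best = {thr, lm, rm}; }
  }
  return best;
}

int main() {
  const std::vector<double> x = {1, 2, 3, 4, 5, 6};
  const std::vector<double> y = {1.2, 0.9, 1.1, 3.8, 4.1, 4.0};
  std::vector<double> f(x.size(), 0.0);  // initial function estimate
  const double shrinkage = 0.1;
  for (int t = 0; t < 200; ++t) {        // loop over the trees to fit
    std::vector<double> r(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) r[i] = y[i] - f[i];  // residuals
    const Stump s = FitStump(x, r);      // fit the learner to the residuals
    for (std::size_t i = 0; i < x.size(); ++i)
      f[i] += shrinkage * s.Predict(x[i]);  // update the estimate
  }
  for (double v : f) std::cout << v << ' ';  // fitted values approach y
  std::cout << '\n';
}
```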
The bagging of the data from the training set and the calculation of residuals are distribution dependent and so are incorporated in the system as appropriate methods of the distribution class and its container. The tree object is comprised of several components, the principal ones being: a root node, a vector of the terminal nodes and a vector mapping the bagged data to terminal nodes. A node splitter, which generates and assigns the best potential split to each terminal node, is used in the tree growing method. This structure is shown below in Figure 3.
Figure 3. Component diagram showing what the GBM Engine component is comprised of.
Using a previously fitted GBM model, the system can predict the response for further covariate data. The R front-end passes the following variables to the gbm_pred function in gbmentry.cpp: the covariates for prediction, the number of trees to use in prediction, the initial prediction estimate, the fitted trees, the categories of the splits, the variable types and a bool indicating whether to return the results from each single fitted tree or not.
The prediction for the new covariates is then initialized, either to the previous tree prediction or, if it is the first tree prediction to be calculated, to the initial estimate value from the model training. The number of trees to be fitted is then looped over, and within each iteration the observations in the covariate dataset are also looped over. Each observation starts in the root node, so a counter tracking the observation's current node is set to 0. The observation is then moved through the tree, updating its current node, until it reaches the appropriate terminal node. Once the algorithm reaches a terminal node, the prediction for that observation and tree is set to the final split value of the terminal node's parent in the tree. This process is repeated over all observations and for all trees; once completed, the prediction vector is wrapped up and passed back to the R front-end, see Figure 4.
Figure 4. Activity diagram showing how the algorithm performs predictions on new covariate data.
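A self-contained sketch of this traversal is shown below, using a flattened tree representation (arrays indexed by node id); the field names and layout are illustrative, not the format actually stored by the package.

```cpp
// Walk each observation from the root (node 0) down to a terminal node.
#include <iostream>
#include <vector>

struct FlatTree {
  std::vector<int> split_var;       // -1 marks a terminal node
  std::vector<double> split_value;  // threshold if non-terminal, prediction if terminal
  std::vector<int> left_child, right_child;
};

double PredictOne(const FlatTree& t, const std::vector<double>& obs) {
  int node = 0;                      // every observation starts in the root node
  while (t.split_var[node] != -1) {  // move through the tree until terminal
    const double x = obs[t.split_var[node]];
    node = (x < t.split_value[node]) ? t.left_child[node] : t.right_child[node];
  }
  return t.split_value[node];        // terminal node's value is the prediction
}

int main() {
  // One split on variable 0 at 2.5: x0 < 2.5 -> predict 1.0, else 4.0.
  const FlatTree tree{{0, -1, -1}, {2.5, 1.0, 4.0}, {1, -1, -1}, {2, -1, -1}};
  std::cout << PredictOne(tree, {3.1}) << '\n';  // prints 4
}
```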
The final piece of functionality offered by the system is calculating the marginal effects of variables by "integrating out" the others. As well as a fitted gbm model and covariate data, the user also specifies at the R layer which variables they are interested in. This method utilises a NodeStack class defined at the beginning of gbmentry.cpp; this is a simple class defining a stack of pairs. These pairs contain a node index, that is which node in the tree we are looking at, and its weight.
With this in mind, the method works as follows (a sketch of the weighted traversal is given below Figure 5):
1. The initial prediction is set to the user-specified initial values.
2. The algorithm loops over the number of trees to fit.
3. Within each tree, the algorithm loops over the observations.
4. In the observation loop it creates a root node, puts it on the stack and sets the observation's current node to 0 with weight 1.0.
5. It takes the top node off the stack and checks whether it is split.
6. (a) If it is not split, the predicted function for this observation and tree is updated with the node's split value times its weight.
(b) If the node is non-terminal, it checks whether the split variable is among the user-specified variables of interest. If so, the observation is moved to the appropriate child node, which is then added to the stack. If not, two nodes (an average of both the left and right splits - this "integrates out" the other variables) are added to the stack. It then returns to step 5.
7. When the observation reaches a terminal node, it takes the next observation and repeats steps 5 and 6.
8. With all trees processed, the predicted function is wrapped up and output.
Figure 5. Activity diagram for calculating the marginal effects of specific variables.
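The weighted traversal can be sketched as follows. The class and field names, and in particular the per-node fraction used to weight the two children, are illustrative assumptions rather than the package's actual bookkeeping.

```cpp
// A stack of (node index, weight) pairs drives the "integrating out":
// splits on variables of interest follow the observation, while splits on
// other variables push both children with proportionally reduced weights.
#include <iostream>
#include <stack>
#include <utility>
#include <vector>

struct PlotNode {
  int split_var;     // -1 marks a terminal node
  double value;      // threshold if non-terminal, prediction if terminal
  int left, right;   // child indices
  double left_frac;  // fraction of training weight that goes left
};

double MarginalPredict(const std::vector<PlotNode>& tree,
                       const std::vector<double>& obs,
                       const std::vector<bool>& of_interest) {
  std::stack<std::pair<int, double>> node_stack;  // (node index, weight)
  node_stack.push({0, 1.0});                      // start at the root, weight 1.0
  double total = 0.0;
  while (!node_stack.empty()) {
    const auto [node, w] = node_stack.top();
    node_stack.pop();
    const PlotNode& n = tree[node];
    if (n.split_var == -1) {
      total += w * n.value;  // terminal: accumulate the weighted prediction
    } else if (of_interest[n.split_var]) {
      // Variable of interest: follow the observation to one child.
      const int child = obs[n.split_var] < n.value ? n.left : n.right;
      node_stack.push({child, w});
    } else {
      // Integrate out: visit both children with the weight split between them.
      node_stack.push({n.left, w * n.left_frac});
      node_stack.push({n.right, w * (1.0 - n.left_frac)});
    }
  }
  return total;
}

int main() {
  // Root splits on variable 1 (not of interest); both children are terminal.
  const std::vector<PlotNode> tree{{1, 0.5, 1, 2, 0.4},
                                   {-1, 1.0, -1, -1, 0.0},
                                   {-1, 3.0, -1, -1, 0.0}};
  std::cout << MarginalPredict(tree, {0.0, 0.7}, {true, false}) << '\n';
  // prints 0.4 * 1.0 + 0.6 * 3.0 = 2.2
}
```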
The near entirety of the system is devoted to the task of training a gbm model, so this section focuses on describing the classes and design patterns implemented to meet this end. Starting at a high level, the primary objects are the "gbm engine", found in gbm_engine.cpp, and the fit, defined in gbm_fit.h. The engine component has the data and distribution container, see gbm_datacontainer.cpp, and a reference to the tree parameters (treeparams.h) as private members; this is shown in Figure 3. The "gbm engine" generates the initial function estimate and, through its FitLearner method, runs the methods of gbm_datacontainer.cpp and tree.cpp to perform tasks such as growing trees, bagging data and calculating errors. The system is built on the RAII idiom: these container classes are initialized on construction of the gbm engine object and released on its destruction. This idiom is applied across the system to simplify the task of memory management and to encourage appropriate object design. Beyond this, encapsulation and const correctness are implemented where possible within the system.
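The RAII idiom as applied here can be illustrated with a minimal sketch; the names are stand-ins, not the real classes.

```cpp
// RAII: the engine acquires its container at construction and releases it
// automatically at destruction, so no explicit cleanup calls are needed.
#include <memory>

class DataContainerSketch { /* data and distribution objects */ };
class TreeParamsSketch { /* depth, minimum observations per node, ... */ };

class EngineSketch {
 public:
  EngineSketch() : container_(std::make_unique<DataContainerSketch>()) {}
  // No manual Destroy()/Free() method: ~EngineSketch() releases container_.

 private:
  std::unique_ptr<DataContainerSketch> container_;
  TreeParamsSketch tree_params_;
};

void Train() {
  EngineSketch engine;  // resources acquired here
  // ... training work ...
}                       // resources released here, even if an exception is thrown
```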
The R layer passes a number of parameters/R objects to the system. These are converted to appropriate parameters and stored in simple objects used to initialize the data/distribution container and the tree objects. After construction of the gbm engine, the DataDistParams object is not used again during training, while the TreeParams object is stored within the engine and used to initialize a tree for fitting on each iteration.
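A hedged sketch of the two configuration objects follows; the field lists here are assumptions made for illustration, and datadistparams.h and treeparams.h hold the real contents.

```cpp
// Illustrative stand-ins for the configuration objects passed by the R layer.
#include <string>
#include <Rcpp.h>

struct DataDistParamsSketch {
  Rcpp::NumericMatrix response;    // training + validation response
  Rcpp::NumericMatrix covariates;  // training + validation covariates
  std::string distribution;        // e.g. "gaussian", "bernoulli", "coxph"
  double bag_fraction;             // share of observations bagged per tree
  unsigned long num_train;         // rows belonging to the training set
};

struct TreeParamsSketch {
  unsigned long max_depth;         // maximum interaction depth of each tree
  unsigned long min_obs_in_node;   // smallest allowed terminal node
  double shrinkage;                // learning rate applied to each fitted tree
};
```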
The dataset class, dataset.h, is part of the data/distribution container and is initialized using the DataDistParams object described in the previous section. This class stores both appropriate R objects, e.g. the response is stored as an Rcpp::NumericMatrix, and standard C++ datatypes describing the data, e.g. unsigned long num_traindata_. To iterate over these R vectors/arrays the dataset object has appropriate pointers to them, which can be accessed via the corresponding getter methods. The R objects stored here contain both training and validation data; to access a particular set the appropriate pointers are shifted by the length of the relevant set so that they point to the beginning of the validation/training set. Most of this class is defined within the header because the number of calls to the dataset's getter functions means that inlining these methods has a significant impact on the system's overall performance. It also has a RandomOrder() method which takes the predictor variables and shuffles them; this is used during the tree growing phase, where all splits are generated for a random subset of features in the data. Finally, an associated class is the "Bag"; this class contains a vector of integers (0s and 1s) called the "bag", which defines the training data to be used in growing the tree.
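The pointer-shifting scheme can be sketched as follows; the names are illustrative.

```cpp
// Training and validation rows live in one contiguous container; switching
// between the two sets just moves the working pointer.
#include <utility>
#include <vector>

class DatasetSketch {
 public:
  DatasetSketch(std::vector<double> response, unsigned long num_train)
      : response_(std::move(response)),
        num_train_(num_train),
        response_ptr_(response_.data()) {}

  // Point at the start of the training rows (the front of the container).
  void ShiftToTrain() { response_ptr_ = response_.data(); }
  // Point at the start of the validation rows (num_train_ elements in).
  void ShiftToValidation() { response_ptr_ = response_.data() + num_train_; }

  // Getter defined in the header so it can be inlined at the call sites.
  const double* response_ptr() const { return response_ptr_; }

 private:
  std::vector<double> response_;  // training rows followed by validation rows
  unsigned long num_train_;
  const double* response_ptr_;    // currently selected set
};
```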
The distribution class, distribution.h, is an abstract class which defines the interface that any new distribution added to the system must implement. The distribution object performs numerous tasks using an instance of the dataset class described in the previous section. Importantly, it constructs the bag for the data; this method, appropriately titled BagData, differs only for the pairwise distribution, where data groupings affect the bagging procedure. Before constructing the bags, the distribution object needs to be initialized with the data. This initialization takes the data rows and constructs a multi-map which maps those rows to their corresponding observation ids; often each row is a unique observation, but this is not always the case. This multi-map ensures bagging occurs on a by-observation basis as opposed to a by-row basis, as the latter would result in overly optimistic errors and potentially a failure of the training method, since data from a single observation could appear both in the bag and out of it. The distribution object is also responsible for calculating the residuals for tree growing, the errors on the training and validation sets, initializing the function estimate and calculating the out-of-bag error improvement on fitting a new tree. These methods are all accessed via the container class stored within the gbm engine object.
Figure 6. Diagram showing the distributions currently implemented within the system.
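A sketch of the abstract interface a new distribution might implement is given below; the method names follow the responsibilities just listed but are assumptions, not the declarations in distribution.h.

```cpp
// Illustrative abstract base class for distributions.
#include <vector>

class DistributionSketch {
 public:
  virtual ~DistributionSketch() = default;

  // Select the in-bag observations (overridden by pairwise, where data
  // groupings affect the bagging procedure).
  virtual void BagData() = 0;

  // Residuals of the loss with respect to the current function estimate,
  // used as the target when growing the next tree.
  virtual void ComputeWorkingResponse(const double* estimate,
                                      std::vector<double>* residuals) = 0;

  // Initial value of the function estimate for this loss.
  virtual double InitF() const = 0;

  // Error on the training or the validation rows.
  virtual double Deviance(const double* estimate, bool use_validation) const = 0;

  // Improvement in the out-of-bag error from the latest fitted tree.
  virtual double BagImprovement(const double* estimate) const = 0;
};
```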
The construction of a distribution object is done using a dynamic factory pattern, see Figure 7. In this design, the constructors of the concrete distribution classes are private and each class contains a static Create method. This method is registered to a map in the factory, distribution_factory.cpp, where the name of the distribution, a string, is its key. On creation of a gbm engine object, a distribution factory is created and the appropriate distribution is generated from the factory using the user-selected distribution, a string which acts as the key to the same map, and the DataDistParams struct.
Figure 7. Dynamic Factory pattern used to control the creation of distribution objects. The distribution factory itself is a singleton.
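A self-contained sketch of this pattern follows; the names are stand-ins and the real registration mechanics in distribution_factory.cpp may differ.

```cpp
// Dynamic factory: concrete classes register a static Create function under
// a string key, and the singleton factory builds objects by name.
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

class DistBase {
 public:
  virtual ~DistBase() = default;
};

class DistFactory {
 public:
  using Creator = std::function<std::unique_ptr<DistBase>()>;

  static DistFactory& Get() {  // the factory itself is a singleton
    static DistFactory instance;
    return instance;
  }
  void Register(const std::string& name, Creator create) {
    creators_[name] = std::move(create);
  }
  std::unique_ptr<DistBase> Create(const std::string& name) const {
    const auto it = creators_.find(name);
    if (it == creators_.end()) throw std::invalid_argument("unknown distribution");
    return it->second();
  }

 private:
  DistFactory() = default;
  std::map<std::string, Creator> creators_;  // distribution name -> Create method
};

class GaussianSketch : public DistBase {
 public:
  static std::unique_ptr<DistBase> Create() {  // registered with the factory
    return std::unique_ptr<DistBase>(new GaussianSketch());
  }

 private:
  GaussianSketch() = default;  // private constructor: objects come from the factory
};

int main() {
  DistFactory::Get().Register("gaussian", &GaussianSketch::Create);
  const auto dist = DistFactory::Get().Create("gaussian");
}
```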
To conclude this subsection, there are some distribution-specific details to describe. The pairwise distribution, CPairwise.h, uses several objects which specify the information retrieval (IR) measure that this distribution uses in its implementation. These classes all inherit from an abstract IRMeasure class and are defined in the CPairwise.h and CPairwise.cpp files.
Figure 8. Diagram displaying the information ratio classes used by the pairwise distribution.
The other distribution which has a unique design is the Cox partial hazards model. This is a survival model whose implementation depends on whether the response matrix provided to the system contains "censored" data or "start-stop" data. In essence, if the response matrix has 2 columns it is censored data, and if it has more than 2 it is start-stop data. To incorporate this added complexity, the class in CCoxPH.h follows a state pattern, whereby the Cox PH object contains a pointer to an object implementing the correct methods for the specific scenario, and those objects contain a pointer back to the Cox PH object. This design is shown in Figure 9.
Figure 9. The Cox Partial Hazards distribution employs a state design pattern to accommodate the dependency of its implementation on the response matrix.
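A minimal sketch of this state pattern is given below; the class and method names are illustrative, and the real classes in CCoxPH.h carry considerably more state (including the back-pointer to the Cox PH object mentioned above, which is omitted here).

```cpp
// State pattern: the Cox PH object delegates to a method object chosen from
// the shape of the response matrix (2 columns: censored; more: start-stop).
#include <memory>

class CoxStateSketch {
 public:
  virtual ~CoxStateSketch() = default;
  virtual double Deviance() const = 0;  // implementation differs per data layout
};

class CensoredSketch : public CoxStateSketch {
 public:
  double Deviance() const override { /* right-censored computation */ return 0.0; }
};

class StartStopSketch : public CoxStateSketch {
 public:
  double Deviance() const override { /* start-stop computation */ return 0.0; }
};

class CoxPHSketch {
 public:
  explicit CoxPHSketch(int response_cols)
      : state_(response_cols == 2
                   ? std::unique_ptr<CoxStateSketch>(new CensoredSketch())
                   : std::unique_ptr<CoxStateSketch>(new StartStopSketch())) {}

  double Deviance() const { return state_->Deviance(); }  // delegate to the state

 private:
  std::unique_ptr<CoxStateSketch> state_;
};
```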
The final component of the gbm engine to consider is the tree class defined in tree.h. As described in Section 2.1, the tree class contains: a root node, of class CNode, a vector of pointers to terminal nodes, a vector assigning data to those terminal nodes, a node searcher object (see node_searcher.h) and various data describing the tree (such as its maximum depth). The tree class has methods for growing the tree, resetting it for another growth, predicting on validation data, adjusting the tree once predictions are assigned to the terminal nodes and outputting it in a suitable format for the R layer to use.
The most important part of this class is the method for growing the tree. A key component in growing a tree is identifying appropriate splits; this is the responsibility of the node searcher object, which stores the best split identified for each terminal node. It uses vectors of variable splitters (see vec_varsplitters.h and varsplitter.h) to determine the best splits generated for each node and for each variable. Within the GenerateAllSplits() method, a vector of NodeParams (described below) copies the current best splits stored in the node searcher object. The variables to split on are then looped over, and a vector of variable splitters, one for each terminal node, is initialized to calculate the best split for the current variable in each node.
The best split for the current variable is extracted from this vector and used to update the copy of the best splits. At the end of this loop, the updated copy of the best splits is assigned to the best splits stored in the node searcher object. This design follows a MapReduce structure and allows for parallelisation of the search process; the details of this parallelisation (see parallel_details.h) are also stored as a private member of the node searcher class.
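The map/reduce shape of this search can be sketched as follows; the names and the improvement bookkeeping are illustrative, and the per-variable search body is stubbed out.

```cpp
// "Map": find the best split of each variable for every terminal node.
// "Reduce": keep the best split seen so far for each node.
#include <cstddef>
#include <vector>

struct SplitSketch {
  int var = -1;
  double value = 0.0;
  double improvement = 0.0;  // variance reduction of the proposed split
};

// Stub: a real implementation would scan the sorted values of `var`
// within each node and track the best improvement.
SplitSketch BestSplitForVariable(int var, std::size_t node) {
  return SplitSketch{var, 0.0, 0.0};
}

void GenerateAllSplitsSketch(const std::vector<int>& shuffled_vars,
                             std::vector<SplitSketch>* best_per_node) {
  std::vector<SplitSketch> candidate = *best_per_node;  // working copy
  for (const int var : shuffled_vars) {                 // map over variables
    for (std::size_t node = 0; node < candidate.size(); ++node) {
      const SplitSketch s = BestSplitForVariable(var, node);
      if (s.improvement > candidate[node].improvement)  // reduce: keep the best
        candidate[node] = s;
    }
  }
  *best_per_node = candidate;  // commit the updated best splits
}
```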
The node class itself contains pointers to its children, and its methods, such as calculating predictions, depend on the type of split present. A node can either be terminal, so it has no split, or have a continuous/categorical split. To account for this dependency on the split, the node class is implemented with a state design pattern, similar to the Cox PH distribution, but in this instance it possesses a SetStrategy() method so the implementation of a node can change as a tree grows. The variable splitter objects contain "node parameters", see node_parameters.h, which encapsulate the properties of the proposed split nodes. These NodeParams objects use NodeDef structs, located in node_parameters.h, to define the proposed node splits, and provide methods to evaluate the quality of the proposed splits (a sketch of such an evaluation is given below Figure 11). Finally, the varsplitter class has a state design pattern where the methods for how the variable of interest should be split are set on construction, see Figure 11.
Figure 10. State pattern for the node object, it possesses a "SetStrategy()" method so nodes can change type as the tree grows.
Figure 11. State pattern for the variable splitter object. How it splits on a variable is set on construction.
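The split-quality evaluation can be illustrated with the standard least-squares improvement formula; the struct and field names below are stand-ins for NodeDef, not copies of it.

```cpp
// Each proposed child accumulates the sum of residuals and the total weight
// falling into it; the split quality is the reduction in weighted variance
// relative to the unsplit parent.
#include <iostream>

struct NodeDefSketch {
  double sum = 0.0;      // sum of (weighted) residuals in the node
  double weight = 0.0;   // total observation weight in the node
  double num_obs = 0.0;  // observation count, for minimum-obs checks
};

// Variance reduction from splitting a parent into (left, right):
// sum_l^2 / w_l + sum_r^2 / w_r - (sum_l + sum_r)^2 / (w_l + w_r).
double SplitImprovement(const NodeDefSketch& left, const NodeDefSketch& right) {
  const double total_sum = left.sum + right.sum;
  const double total_weight = left.weight + right.weight;
  return left.sum * left.sum / left.weight +
         right.sum * right.sum / right.weight -
         total_sum * total_sum / total_weight;
}

int main() {
  const NodeDefSketch l{6.0, 3.0, 3.0}, r{-4.0, 2.0, 2.0};
  std::cout << SplitImprovement(l, r) << '\n';  // 12 + 8 - 0.8 = 19.2
}
```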
To conclude this section, formal system class and interaction diagrams are presented in Figures 12 and 13 respectively. These provide a complete and more detailed picture of the classes, and the interactions between them, which define the system, in particular the process of training a gbm.
Figure 12. System class diagram showing the classes and how they interface in the training method.
Figure 13. Sequence diagram for training a single tree when fitting a gradient boosted model.
To use the GBM package it is necessary to install a version of R on your system, preferably 3.2 or higher. With R installed, the GBM package may be installed by opening an R terminal and running the install.packages("gbm") command.
The system is subject to two types of testing: black-box system and acceptance tests, and unit tests. These tests are designed to run with the testthat package and ship with the GBM package in the /tests/ folder. The package can be checked using R CMD check <built package>, which will automatically run all of the tests within it and report any warnings or errors.
Currently the system has a series of black-box system tests which check that the package is operating as expected, using what can be considered 'sanity test cases' as well as testing for errors. These tests focus primarily on the Gaussian, Bernoulli and Cox PH distributions at this time. Some tests, such as those checking the effects of the offset, touch all distributions implemented in the system. Using the covr package, the Travis CI server automatically collects code coverage data and sends it to codecov.io. As of 03/06/2016 this coverage is at 66.07%, with coverage of the R layer at 42.15% and coverage of the system at 79.28%.
This coverage tool is not perfect: when measuring the exercising of inlined code it is likely to report that the inlined functionality has been missed when in fact it has been exercised. Beyond this, the tests do not exercise the counting Cox PH functionality, miss node functionality such as PrintSubTree and GetVarRelativeInfluence, do not touch many unhappy/error-throwing paths and completely miss the gbm_plot functionality. That being said, even with only the higher-level black-box testing in place, most of the important paths within the system are exercised. By contrast, the R layer has a few well-tested pieces of functionality, such as gbm.fit.R, but a large number of functions have no coverage at all.
Finally, thanks to a recent update, the system itself can now be unit tested using R's testthat package. These unit tests have yet to be written.