NYU
WIRELESS
Lecture 13
Decision Trees and Random Forests
EL-GY 6143/CS-GY 6923: INTRODUCTION TO MACHINE LEARNING
PROF. PEI LIU
Outline
• Decision tree as constrained space partition
• Regression tree design
• Decision tree pruning
• Classification tree design
• Bagging
• Random forest
• Feature ranking from random forest
Decision tree as constrained space partitioning
Fig. 9.2 in ESL
• Each region is regressed/classified to the same value.
• The partition can be specified sequentially by splitting the range of one feature at a time.
• The splitting rule can be described by a tree.
• Each leaf node = one region.
• Size of tree = number of leaf nodes.
• The partition is constrained: only rectangles in the 2D case.
• The top-left partition in Fig. 9.2 cannot be realized by a decision tree.
How to build a decision tree? (Regression case)
• Goal: minimize RSS = \sum_{j=1}^{J} \sum_{i: x_i \in R_j} (y_i - \hat{y}_{R_j})^2
• Greedy algorithm:
  ◦ Start with a single region (the entire feature space) and iterate:
  ◦ For each region R_j, select a feature x_p and a splitting threshold t such that splitting R_j with the criterion x_p < t produces the largest decrease in RSS within R_j.
  ◦ Exhaustive search: for each feature x_p, try all possible thresholds t in the current range of x_p within R_j (a code sketch follows the figure below).
  ◦ Stop splitting a region once it contains no more than a minimum number N_min of samples.
[Figure: all possible split thresholds t of one feature x_j within a region]
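A minimal sketch of the exhaustive split search in Python (assuming NumPy arrays X of shape (N, P) and y of length N; best_split is a hypothetical helper, not code from the lecture):

```python
import numpy as np

def best_split(X, y):
    """Return (feature index, threshold, RSS) of the best single split of one region."""
    best_j, best_t, best_rss = None, None, np.inf
    for j in range(X.shape[1]):                 # try every feature
        for t in np.unique(X[:, j]):            # try every observed value as a threshold
            left, right = y[X[:, j] < t], y[X[:, j] >= t]
            if len(left) == 0 or len(right) == 0:
                continue                        # skip degenerate splits
            # RSS of the two resulting regions, each predicted by its mean
            rss = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
            if rss < best_rss:
                best_j, best_t, best_rss = j, t, rss
    return best_j, best_t, best_rss
```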
Overfitting
• Decision trees are very prone to overfitting.
• A fully grown tree can exactly represent any function defined on the training set by using as many regions (leaf nodes) as needed.
• How do we control overfitting?
  ◦ Find the optimal subtree (under a constraint on the minimum number of samples per leaf node or on the maximum depth) by cross validation: too many possibilities. (See the parameter sketch after this list.)
  ◦ Stop growing once no new cut decreases the RSS by more than a threshold:
    ◦ Not a good idea, because we use greedy search; it is possible to find a good cut after a bad one.
  ◦ Better idea: grow a full tree first, then prune the tree.
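For reference, these stopping constraints map onto scikit-learn's tree parameters; a minimal sketch, assuming scikit-learn is installed:

```python
from sklearn.tree import DecisionTreeRegressor

# No constraints: the greedy splitting can continue until every leaf is pure.
full_tree = DecisionTreeRegressor()
# Constrained trees limit how far the greedy splitting is allowed to go.
shallow_tree = DecisionTreeRegressor(max_depth=4)
big_leaf_tree = DecisionTreeRegressor(min_samples_leaf=10)
```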
Weakest link pruning
• Starting with the initial full tree T0, merge two adjacent leaf nodes (daughter nodes) into a single leaf node (their mother). Select which nodes to merge by minimizing the resulting increase in error. This produces a tree with one fewer region (leaf node).
• Repeat to merge another two nodes, until the minimum-size tree is reached (e.g., a stump with 2 leaf nodes).
• This generates a sequence of trees T0, T1, T2, T3, …
• Which one should we choose? (A code sketch of the pruning sequence follows.)
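A sketch of obtaining this pruning sequence with scikit-learn's minimal cost-complexity pruning path (assuming training arrays X_train, y_train are available):

```python
from sklearn.tree import DecisionTreeRegressor

full_tree = DecisionTreeRegressor(min_samples_leaf=5).fit(X_train, y_train)
# Each ccp_alpha in the path corresponds to one subtree T0, T1, T2, ...
# obtained by successively collapsing the weakest link of the previous tree.
path = full_tree.cost_complexity_pruning_path(X_train, y_train)
subtrees = [DecisionTreeRegressor(ccp_alpha=a).fit(X_train, y_train)
            for a in path.ccp_alphas]
```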
Cost complexity pruning
• Minimize a complexity-regularized loss over all possible subtrees T0, T1, T2, …:
  C(T, α) = \sum_{j=1}^{|T|} \sum_{i: x_i \in R_j} (y_i - \hat{y}_{R_j})^2 + α |T|, where |T| = number of leaf nodes (regions)
  ◦ α = 0: full tree; α → ∞: minimum-size tree
• How to choose α? Cross validation!
  ◦ For each α:
    ◦ For each validation fold: build a sequence of trees using the training set, find the RSS on the validation set for each candidate tree, and pick the tree that minimizes C(T, α).
    ◦ Average C(T, α) over all validation folds.
• When the dataset is very large, we can simply pick the one tree that has minimal RSS on the test set. (A cross-validation sketch for choosing α follows.)
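A sketch of choosing α by cross validation with scikit-learn (X_train, y_train are assumed; the candidate α values come from the pruning path of a full tree):

```python
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

# Candidate alphas: the breakpoints of the weakest-link pruning path of a full tree.
alphas = (DecisionTreeRegressor()
          .cost_complexity_pruning_path(X_train, y_train).ccp_alphas)
search = GridSearchCV(DecisionTreeRegressor(),
                      param_grid={"ccp_alpha": alphas},
                      cv=5, scoring="neg_mean_squared_error")
pruned_tree = search.fit(X_train, y_train).best_estimator_
```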
Example: Predicting baseball player salaries
From http://web.stanford.edu/class/stats202/notes/Tree/Regression-trees.html
Feature importance
• For each feature, find all splits where this feature was used as the split variable, and add up the loss reduction at all such splits.
• The sum reflects the importance of this feature! (A sketch follows.)
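In scikit-learn this accumulated (and normalized) loss reduction is exposed as feature_importances_; a sketch, assuming X_train, y_train and a list feature_names are available:

```python
from sklearn.tree import DecisionTreeRegressor

tree = DecisionTreeRegressor(max_depth=5).fit(X_train, y_train)
# feature_importances_[j]: total loss reduction from all splits on feature j,
# normalized so that the importances sum to 1.
for name, score in sorted(zip(feature_names, tree.feature_importances_),
                          key=lambda pair: -pair[1]):
    print(f"{name}: {score:.3f}")
```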
Demo: weather prediction using decision tree
What about classification?
• The predicted class for each region = the majority class of the training samples in that region.
• How do we design the tree?
  ◦ We can use the same greedy algorithm.
  ◦ Split each region (by picking a feature and a threshold) to minimize a loss.
  ◦ Which loss functions should we use?
Classification loss
• Misclassification rate = \sum_{j=1}^{J} \sum_{i: x_i \in R_j} 1(y_i \ne \hat{y}_j), where \hat{y}_j = majority class of R_j
• Gini index = \sum_{j=1}^{J} N_j \sum_{k=1}^{K} \hat{p}_{j,k} (1 - \hat{p}_{j,k}), where \hat{p}_{j,k} = fraction of samples in R_j that belong to class k
  ◦ Expected error rate if we randomly assign class k with probability \hat{p}_{j,k}, which then has error rate 1 - \hat{p}_{j,k}.
• Cross entropy = -\sum_{j=1}^{J} N_j \sum_{k=1}^{K} \hat{p}_{j,k} \log \hat{p}_{j,k}
  ◦ Smaller entropy means a less uniform class distribution (the region is more pure!).
• The Gini and cross-entropy losses lead to more "pure" regions, with one dominant class in each region. (A numerical sketch follows.)
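A small numerical sketch of the three node-level losses for a single region with class proportions \hat{p}_{j,k} (node_losses is a hypothetical helper, not code from the lecture):

```python
import numpy as np

def node_losses(p):
    """p: class proportions within one region (they should sum to 1)."""
    p = np.asarray(p, dtype=float)
    misclass = 1.0 - p.max()                         # misclassification rate
    gini = np.sum(p * (1.0 - p))                     # Gini index
    entropy = -np.sum(p[p > 0] * np.log2(p[p > 0]))  # cross entropy (bits)
    return misclass, gini, entropy

print(node_losses([0.9, 0.1]))  # nearly pure region: all three losses are small
print(node_losses([0.5, 0.5]))  # most impure two-class region: all three are maximal
```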
Performance metric and pruning
• After a tree is designed, its performance is still measured by the misclassification rate or the accuracy.
• It is typical to use the Gini index or cross entropy when growing a tree, but the misclassification rate when pruning it.
Example: Classifying heart disease
From http://web.stanford.edu/class/stats202/notes/Tree/Classification-trees.html
Advantages of decision trees
• Easy to interpret: doctors like them.
• Closer to human decision making.
• Feature importance can be derived during training.
• Can easily handle mixed feature types (numerical and categorical) and missing features in some samples.
  ◦ Not discussed here.
• Problem:
  ◦ To reduce bias, the tree needs to be grown deeper.
  ◦ Deeper trees tend to overfit the training data (large variance across different training sets).
  ◦ How can this be overcome?
Bagging (Bootstrap Aggregating)
• Idea: generate multiple trees from different training sets, apply all the models to each test sample, and take the average (or majority vote) of the results from all the trees.
• How do we generate different training sets given one dataset?
• Cross validation: use a subset of the data each time for training and the remainder for testing.
• Bootstrap sampling: sampling with replacement; each bootstrap sample contains the same number of samples as the original dataset, but some samples are replicated and others are not included.
• Bagging: generate B models from B bootstrap samples.
  ◦ Regression: average the prediction results from the B models.
  ◦ Classification: take the majority class index.
• Bagging applies to other regressors/classifiers as well. (A sketch follows.)
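A sketch with scikit-learn's BaggingRegressor, which follows this bootstrap-and-average recipe (X_train, y_train, X_test are assumed to be available):

```python
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

# B = 100 trees, each fit on a bootstrap sample of the training set;
# at test time the B predictions are averaged.
bag = BaggingRegressor(DecisionTreeRegressor(), n_estimators=100, bootstrap=True)
bag.fit(X_train, y_train)
y_pred = bag.predict(X_test)
```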
Out of bag (OOB) error
• Each time we draw a bootstrap sample, we only use about 63% of the samples.
  ◦ Probability that a given sample is chosen among the N draws of one bootstrap sample: 1 - (1 - 1/N)^N ≈ 1 - e^{-1} ≈ 0.632
• We can use the remaining samples for testing.
• OOB error:
  ◦ For each sample i, find the models generated from bootstrap samples that do not contain i (there are about 0.37B such models), and average the predictions of these models for sample i.
  ◦ Compute the regression/classification error for sample i.
  ◦ Average the error over all samples.
• We can use the OOB error as an estimate of the test error.
• Unlike cross validation, it does not require designing multiple models for multiple folds; the OOB error can be estimated from the single pass that designs the B trees. (A sketch follows.)
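A sketch of the OOB estimate with scikit-learn (X_train, y_train assumed); with oob_score=True, each training sample is scored only by the trees whose bootstrap sample did not contain it:

```python
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

bag = BaggingRegressor(DecisionTreeRegressor(), n_estimators=200,
                       bootstrap=True, oob_score=True).fit(X_train, y_train)
# oob_score_ is the R^2 of the out-of-bag predictions, usable as a test-error proxy.
print("OOB R^2:", bag.oob_score_)
```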
From ESL Fig. 15.4
Why bagging?
• When a regressor or classifier has a tendency to overfit (i.e., is sensitive to the training set), bagging reduces the variance of the prediction.
  ◦ This reduces the test error.
  ◦ It is particularly useful for decision trees.
• When the number of samples N in a given dataset is large:
  ◦ The empirical distribution is similar to the true distribution.
  ◦ Each bootstrap sample is similar to an independent realization of the true distribution.
  ◦ Bagging amounts to averaging the fits from many identically distributed datasets.
Problems with bagging?
• Trees generated from different bootstrap samples can be very similar.
• The test error therefore decreases slowly as B increases:
  ◦ \hat{f}_b(x): prediction by tree b for test sample x
  ◦ Assume the \hat{f}_b(x) all have the same mean and the same variance σ²
  ◦ Assume these predictions have pairwise correlation ρ
  ◦ Variance of the average prediction \frac{1}{B} \sum_b \hat{f}_b(x) (shown on board, written out below): ρ σ² + \frac{1-ρ}{B} σ²
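The derivation, written out: assuming each prediction has variance σ² and each pair has correlation ρ,

\operatorname{Var}\Big(\frac{1}{B}\sum_{b=1}^{B}\hat f_b(x)\Big)
= \frac{1}{B^2}\Big[\sum_{b}\operatorname{Var}\big(\hat f_b(x)\big) + \sum_{b\neq b'}\operatorname{Cov}\big(\hat f_b(x),\hat f_{b'}(x)\big)\Big]
= \frac{1}{B^2}\big[B\sigma^2 + B(B-1)\rho\sigma^2\big]
= \rho\sigma^2 + \frac{1-\rho}{B}\,\sigma^2 .

As B → ∞ the variance approaches ρσ², so the correlation between trees limits how much averaging can help.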
Random Forest
• As with bagging: fit a different tree for each bootstrap sample.
• Recall that when growing a tree, at each current node (region) we split the region by choosing a particular feature and a threshold; the feature and threshold are chosen among all P features to minimize a certain loss.
• With a random forest, at each node we instead choose among a randomly selected subset of P' < P features.
• The resulting trees are more different from one another.
• Rule of thumb: P' = \sqrt{P} (but P' should be tuned using the test error or the OOB error). A sketch follows.
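A sketch with scikit-learn (X_train, y_train assumed); max_features controls the size P' of the random feature subset tried at each split:

```python
from sklearn.ensemble import RandomForestRegressor

# max_features="sqrt" uses P' = sqrt(P) features per split (the rule of thumb);
# in practice it should be tuned, e.g. against the OOB error reported below.
forest = RandomForestRegressor(n_estimators=500, max_features="sqrt",
                               oob_score=True).fit(X_train, y_train)
print("OOB R^2:", forest.oob_score_)
```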
Bagging vs. RF
• Bagging: Var = ρ σ² + \frac{1-ρ}{B} σ²
• Random forest (assuming ρ = 0): Var = \frac{σ²}{B}
• Recall: Test error = bias² + variance + noise variance
From ESL, Fig. 15.1
Feature importance
• For each feature, add up the loss reduction at the splits where this feature was used, over all trees in the forest.
Demo: Random forest
Problem with bagging and random forest
• The resulting model has many trees!
• We lose interpretability!
• Related methods (not covered):
  ◦ Boosting
  ◦ Gradient boosting
What you should know from this lecture
• How to use/interpret a decision tree?
• How to train a decision tree?
  ◦ Loss function for regression
  ◦ Loss function for classification
• How to reduce overfitting?
• What does bagging mean?
• How to train and use a random forest?
• How to determine feature importance?
References
• [ESL] Hastie, T., Tibshirani, R., & Friedman, J. (2008). The Elements of Statistical Learning: Data Mining, Inference, and Prediction. Sec. 9 (decision trees), Sec. 15.2 (random forests).