Package smile.classification
Class RandomForest
- All Implemented Interfaces:
Serializable, ToDoubleFunction<Tuple>, ToIntFunction<Tuple>, Classifier<Tuple>, DataFrameClassifier, SHAP<Tuple>, TreeSHAP
public class RandomForest
extends AbstractClassifier<Tuple>
implements DataFrameClassifier, TreeSHAP
Random forest for classification. Random forest is an ensemble classifier
that consists of many decision trees and outputs the majority vote of the
individual trees. The method combines the bagging idea with random
selection of features.
Each tree is constructed using the following algorithm:
- If the number of cases in the training set is N, randomly sample N cases with replacement from the original data. This sample will be the training set for growing the tree.
- If there are M input variables, a number m << M is specified such that at each node, m variables are selected at random out of the M and the best split on these m is used to split the node. The value of m is held constant during forest growing.
- Each tree is grown to the largest extent possible. There is no pruning.
The advantages of random forest are:
- For many data sets, it produces a highly accurate classifier.
- It runs efficiently on large data sets.
- It can handle thousands of input variables without variable deletion.
- It gives estimates of which variables are important in the classification.
- It generates an internal unbiased estimate of the generalization error as the forest building progresses.
- It has an effective method for estimating missing data and maintains accuracy when a large proportion of the data are missing.
The disadvantages are:
- Random forests are prone to over-fitting on some datasets. This is even more pronounced with noisy data.
- For data including categorical variables with different numbers of levels, random forests are biased in favor of attributes with more levels. The variable importance scores from random forest are therefore not reliable for this type of data.
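For example, a minimal fit and out-of-bag check (a sketch; it assumes a DataFrame named data whose response column is "class"):

    import smile.classification.RandomForest;
    import smile.data.DataFrame;
    import smile.data.formula.Formula;

    // Assumption: `data` is a DataFrame whose "class" column holds the labels.
    // Formula.lhs("class") makes "class" the response and all other columns predictors.
    Formula formula = Formula.lhs("class");
    RandomForest model = RandomForest.fit(formula, data);

    // Out-of-bag metrics accumulated while the forest was grown.
    System.out.println(model.metrics());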
Nested Class Summary
Nested classes/interfaces inherited from interface smile.classification.Classifier
Classifier.Trainer<T, M extends Classifier<T>>
Nested classes/interfaces inherited from interface smile.classification.DataFrameClassifier
DataFrameClassifier.Trainer<M extends DataFrameClassifier>
-
Field Summary
Fields inherited from class smile.classification.AbstractClassifier
classes
-
Constructor Summary
Constructor | Description
RandomForest(Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance)
Constructor.
RandomForest(Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance, IntSet labels)
Constructor.
-
Method Summary
Modifier and Type | Method | Description
static RandomForest fit(Formula formula, DataFrame data)
Fits a random forest for classification.
static RandomForest fit(Formula formula, DataFrame data, Properties params)
Fits a random forest for classification.
static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample)
Fits a random forest for classification.
static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight)
Fits a random forest for classification.
static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, LongStream seeds)
Fits a random forest for classification.
Formula formula()
Returns the formula associated with the model.
double[] importance()
Returns the variable importance.
RandomForest merge(RandomForest other)
Merges two random forests.
ClassificationMetrics metrics()
Returns the overall out-of-bag metric estimations.
RandomForest.Model[] models()
Returns the base models.
int predict(Tuple x)
Predicts the class label of an instance.
int predict(Tuple x, double[] posteriori)
Predicts the class label of an instance and also calculates the a posteriori probabilities.
RandomForest prune(DataFrame test)
Returns a new random forest by reduced error pruning.
StructType schema()
Returns the predictor schema.
int size()
Returns the number of trees in the model.
boolean soft()
Returns true if this is a soft classifier that can estimate the posteriori probabilities of classification.
int[][] test(DataFrame data)
Tests the model on a validation dataset.
DecisionTree[] trees()
Returns the decision trees.
RandomForest trim(int ntrees)
Trims the tree model set to a smaller size in case of over-fitting.
int vote(Tuple x, double[] posteriori)
Predicts and estimates the probability by voting.
Methods inherited from class smile.classification.AbstractClassifier
classes, numClasses
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface smile.classification.Classifier
applyAsDouble, applyAsInt, classes, numClasses, online, predict, predict, predict, predict, predict, predict, score, update, update, update
Methods inherited from interface smile.classification.DataFrameClassifier
predict, predict
-
Constructor Details
-
RandomForest
public RandomForest(Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance)
Constructor.
Parameters:
formula - a symbolic description of the model to be fitted.
k - the number of classes.
models - the forest of decision trees.
metrics - the overall out-of-bag metric estimation.
importance - the feature importance.
-
RandomForest
public RandomForest(Formula formula, int k, RandomForest.Model[] models, ClassificationMetrics metrics, double[] importance, IntSet labels)
Constructor.
Parameters:
formula - a symbolic description of the model to be fitted.
k - the number of classes.
models - the base models.
metrics - the overall out-of-bag metric estimation.
importance - the feature importance.
labels - the class label encoder.
-
-
Method Details
-
fit
public static RandomForest fit(Formula formula, DataFrame data)
Fits a random forest for classification.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
Returns:
the model.
-
fit
public static RandomForest fit(Formula formula, DataFrame data, Properties params)
Fits a random forest for classification.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
params - the hyper-parameters.
Returns:
the model.
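A sketch of the Properties-based overload; the property key names below are assumptions following Smile's naming convention, so verify them against the library documentation before relying on them:

    import java.util.Properties;

    Properties params = new Properties();
    // Key names are assumptions, not verified against this Smile version.
    params.setProperty("smile.random.forest.trees", "300");
    params.setProperty("smile.random.forest.node.size", "5");
    RandomForest model = RandomForest.fit(Formula.lhs("class"), data, params);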
-
fit
public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample)
Fits a random forest for classification.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
rule - the decision tree split rule.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
Returns:
the model.
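For instance, a fully specified fit (a sketch; the hyperparameter values are illustrative, not recommendations):

    import smile.base.cart.SplitRule;

    RandomForest model = RandomForest.fit(
        Formula.lhs("class"), data,
        500,            // ntrees
        4,              // mtry ~ floor(sqrt(p)), assuming p = 16 predictors
        SplitRule.GINI, // split rule
        20,             // maxDepth
        500,            // maxNodes
        5,              // nodeSize
        1.0);           // subsample = 1.0: sample N cases with replacement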
-
fit
public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight)
Fits a random forest for classification.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
rule - the decision tree split rule.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
classWeight - the priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
Returns:
the model.
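Continuing the documented 400-positive/100-negative example, a sketch of passing class weights:

    // 100 negative samples (label 0), 400 positive samples (label 1),
    // so classWeight = {1, 4} per the ratio described above.
    int[] classWeight = {1, 4};
    RandomForest model = RandomForest.fit(
        Formula.lhs("class"), data,
        500, 4, SplitRule.GINI, 20, 500, 5, 1.0,
        classWeight);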
-
fit
public static RandomForest fit(Formula formula, DataFrame data, int ntrees, int mtry, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, double subsample, int[] classWeight, LongStream seeds)
Fits a random forest for classification.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables to be used to determine the decision at a node of the tree. floor(sqrt(p)) generally gives good performance, where p is the number of variables.
rule - the decision tree split rule.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
classWeight - the priors of the classes. The weight of each class is roughly the ratio of samples in each class. For example, if there are 400 positive samples and 100 negative samples, the classWeight should be [1, 4] (assuming label 0 is negative and label 1 is positive).
seeds - optional RNG seeds for each decision tree.
Returns:
the model.
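A sketch of a reproducible fit: a fixed-seed LongStream supplies each of the ntrees trees a deterministic RNG seed.

    import java.util.Random;
    import java.util.stream.LongStream;

    // One seed per tree, generated deterministically from a master seed.
    LongStream seeds = new Random(42).longs(500);
    RandomForest model = RandomForest.fit(
        Formula.lhs("class"), data,
        500, 4, SplitRule.GINI, 20, 500, 5, 1.0,
        null,   // classWeight: assuming null falls back to the default priors
        seeds);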
-
formula
Description copied from interface: DataFrameClassifier
Returns the formula associated with the model.
Specified by:
formula in interface DataFrameClassifier
Specified by:
formula in interface TreeSHAP
Returns:
the formula associated with the model.
-
schema
Description copied from interface: DataFrameClassifier
Returns the predictor schema.
Specified by:
schema in interface DataFrameClassifier
Returns:
the predictor schema.
-
metrics
public ClassificationMetrics metrics()
Returns the overall out-of-bag metric estimations. The OOB estimate is quite accurate given that enough trees have been grown. Otherwise, the OOB error estimate can be biased upward.
Returns:
the out-of-bag metric estimations.
-
importance
public double[] importance()
Returns the variable importance. Every time a node is split on a variable, the impurity criterion (Gini, information gain, etc.) for the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over all trees in the forest gives a fast measure of variable importance that is often very consistent with the permutation importance measure.
Returns:
the variable importance.
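For example, a sketch that prints each importance score against its column position (mapping positions to names via schema() is left aside, since the exact accessor is not shown here):

    double[] importance = model.importance();
    for (int i = 0; i < importance.length; i++) {
        System.out.printf("predictor %d: %.4f%n", i, importance[i]);
    }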
-
size
public int size()
Returns the number of trees in the model.
Returns:
the number of trees in the model.
-
models
public RandomForest.Model[] models()
Returns the base models.
Returns:
the base models.
-
trees
Description copied from interface: TreeSHAP
Returns the decision trees.
-
trim
Trims the tree model set to a smaller size in case of over-fitting. If the extra decision trees don't improve performance, removing them reduces the model size and speeds up prediction.
Parameters:
ntrees - the new (smaller) size of the tree model set.
Returns:
a new trimmed forest.
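For example, keeping only the first 200 trees of a 500-tree forest:

    // trim() returns a new forest; the original model is unchanged.
    RandomForest smaller = model.trim(200);
    System.out.println(smaller.size()); // 200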
-
merge
Merges two random forests.
Parameters:
other - the other forest to merge with.
Returns:
the merged forest.
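A sketch of one common use: growing two forests separately (e.g., on different machines or data partitions) and combining them.

    // Assumption: dataA and dataB are partitions of the same training data,
    // so both forests share the same formula and schema.
    RandomForest forestA = RandomForest.fit(Formula.lhs("class"), dataA);
    RandomForest forestB = RandomForest.fit(Formula.lhs("class"), dataB);
    RandomForest combined = forestA.merge(forestB);
    System.out.println(combined.size()); // sum of both tree counts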
-
predict
Description copied from interface: Classifier
Predicts the class label of an instance.
Specified by:
predict in interface Classifier<Tuple>
Parameters:
x - the instance to be classified.
Returns:
the predicted class label.
-
soft
public boolean soft()
Description copied from interface: Classifier
Returns true if this is a soft classifier that can estimate the posteriori probabilities of classification.
Specified by:
soft in interface Classifier<Tuple>
Returns:
true if soft classifier.
-
predict
Description copied from interface: Classifier
Predicts the class label of an instance and also calculates the a posteriori probabilities. Classifiers may NOT support this method since not all classification algorithms are able to calculate such a posteriori probabilities.
Specified by:
predict in interface Classifier<Tuple>
Parameters:
x - an instance to be classified.
posteriori - the a posteriori probabilities on output.
Returns:
the predicted class label.
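For example (a sketch; x is assumed to be a Tuple conforming to the model's schema):

    // The posteriori array must have one slot per class.
    double[] posteriori = new double[model.numClasses()];
    int label = model.predict(x, posteriori);
    System.out.printf("label %d with probability %.3f%n", label, posteriori[label]);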
-
vote
Predicts and estimates the probability by voting.
Parameters:
x - the instance to be classified.
posteriori - the a posteriori probabilities on output.
Returns:
the predicted class label.
-
test
Tests the model on a validation dataset.
Parameters:
data - the test data set.
Returns:
the predictions with the first 1, 2, ..., decision trees.
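A sketch of using test() to choose a smaller forest size: compare the accuracy of the first k trees for each k, then trim() to the best k. Accuracy.of is assumed to be Smile's usual validation helper.

    import smile.validation.metric.Accuracy;

    // predictions[k] holds the predictions made by the first k+1 trees.
    int[][] predictions = model.test(testData);
    // Assumption: y holds the true labels of testData, in row order.
    for (int k = 0; k < predictions.length; k++) {
        double acc = Accuracy.of(y, predictions[k]);
        System.out.printf("%d trees: accuracy %.4f%n", k + 1, acc);
    }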
-
prune
Returns a new random forest by reduced error pruning.
Parameters:
test - the test data set to evaluate the errors of nodes.
Returns:
a new pruned random forest.
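For example, pruning against a held-out validation set (a sketch; valid is assumed to share the training schema):

    // Reduced error pruning removes subtrees that do not improve
    // accuracy on the validation set; the original forest is unchanged.
    RandomForest pruned = model.prune(valid);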
-