Class RandomForest
java.lang.Object
smile.regression.RandomForest
- All Implemented Interfaces:
Serializable, ToDoubleFunction<Tuple>, SHAP<Tuple>, TreeSHAP, DataFrameRegression, Regression<Tuple>
Random forest for regression. Random forest is an ensemble method that
consists of many regression trees and outputs the average of the individual
trees. The method combines the bagging idea with the random selection of features.
Each tree is constructed using the following algorithm:
- If the number of cases in the training set is N, randomly sample N cases with replacement from the original data. This sample will be the training set for growing the tree.
- If there are M input variables, a number m << M is specified such that at each node, m variables are selected at random out of the M and the best split on these m is used to split the node. The value of m is held constant during the forest growing.
- Each tree is grown to the largest extent possible. There is no pruning.
The advantages of random forests are:
- For many data sets, it produces a highly accurate model.
- It runs efficiently on large data sets.
- It can handle thousands of input variables without variable deletion.
- It gives estimates of which variables are important in the prediction.
- It generates an internal unbiased estimate of the generalization error as the forest building progresses.
- It has an effective method for estimating missing data and maintains accuracy when a large proportion of the data are missing.
The disadvantages are:
- Random forests are prone to over-fitting for some datasets. This is even more pronounced in noisy classification/regression tasks.
- For data including categorical variables with different numbers of levels, random forests are biased in favor of attributes with more levels. Therefore, the variable importance scores from random forests are not reliable for this type of data.
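For orientation, a minimal usage sketch follows. The file name "housing.csv" and the response column "price" are placeholders, and the sketch assumes the data frame is loaded with smile.io.Read.csv; only fit and predict from this class are exercised.

    import smile.data.DataFrame;
    import smile.data.formula.Formula;
    import smile.io.Read;
    import smile.regression.RandomForest;

    public class RandomForestQuickStart {
        public static void main(String[] args) throws Exception {
            // "housing.csv" and the response column "price" are placeholders.
            DataFrame data = Read.csv("housing.csv");

            // Fit a forest with default hyperparameters; the formula names the response.
            RandomForest model = RandomForest.fit(Formula.lhs("price"), data);

            // Predict the dependent variable for the first row of the frame.
            System.out.println(model.predict(data.get(0)));
        }
    }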
Nested Class Summary
Nested Classes:
- static final record RandomForest.Model: The base model.
- static final record RandomForest.Options: Random forest hyperparameters.
- static final record: Training status per tree.
Nested classes/interfaces inherited from interface DataFrameRegression:
- DataFrameRegression.Trainer<M>
-
Constructor Summary
Constructors:
- RandomForest(Formula formula, RandomForest.Model[] models, RegressionMetrics metrics, double[] importance): Constructor.
-
Method Summary
Modifier and Type / Method / Description:
- static RandomForest fit(Formula formula, DataFrame data): Fits a random forest for regression.
- static RandomForest fit(Formula formula, DataFrame data, RandomForest.Options options): Fits a random forest for regression.
- Formula formula(): Returns the model formula.
- double[] importance(): Returns the variable importance.
- RandomForest merge(RandomForest other): Merges two random forests.
- RegressionMetrics metrics(): Returns the overall out-of-bag metric estimations.
- RandomForest.Model[] models(): Returns the base models.
- double predict(Tuple x): Predicts the dependent variable of an instance.
- StructType schema(): Returns the schema of predictors.
- int size(): Returns the number of trees in the model.
- double[][] test(DataFrame data): Tests the model on a validation dataset.
- trees(): Returns the decision trees.
- RandomForest trim(int ntrees): Trims the tree model set to a smaller size in case of over-fitting.
Methods inherited from class Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface DataFrameRegression:
predict
Methods inherited from interface Regression:
applyAsDouble, online, predict, predict, predict, update, update, update
-
Constructor Details
-
RandomForest
public RandomForest(Formula formula, RandomForest.Model[] models, RegressionMetrics metrics, double[] importance)
Constructor.
Parameters:
formula - a symbolic description of the model to be fitted.
models - the base models.
metrics - the overall out-of-bag metric estimations.
importance - the feature importance.
-
-
Method Details
-
fit
public static RandomForest fit(Formula formula, DataFrame data)
Fits a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
Returns:
the model.
-
fit
public static RandomForest fit(Formula formula, DataFrame data, RandomForest.Options options)
Fits a random forest for regression.
Parameters:
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
options - the hyperparameters.
Returns:
the model.
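A sketch of supplying hyperparameters. The positional arguments passed to RandomForest.Options below (number of trees, mtry, max depth, max nodes, node size, subsample) are assumptions about the record's components and their order, made for illustration only; verify them against the actual record definition.

    import smile.data.DataFrame;
    import smile.data.formula.Formula;
    import smile.regression.RandomForest;

    public class RandomForestWithOptions {
        // Fits a forest with explicit hyperparameters. The Options arguments
        // (500 trees, mtry = 0 for the default m, max depth 20, max nodes 500,
        // node size 5, subsample 1.0) assume a positional record constructor
        // and are illustrative only.
        public static RandomForest fitTuned(Formula formula, DataFrame data) {
            var options = new RandomForest.Options(500, 0, 20, 500, 5, 1.0);
            return RandomForest.fit(formula, data, options);
        }
    }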
-
formula
public Formula formula()
Description copied from interface: DataFrameRegression
Returns the model formula.
Specified by:
formula in interface DataFrameRegression
formula in interface TreeSHAP
Returns:
the model formula.
-
schema
public StructType schema()
Description copied from interface: DataFrameRegression
Returns the schema of predictors.
Specified by:
schema in interface DataFrameRegression
Returns:
the schema of predictors.
-
metrics
public RegressionMetrics metrics()
Returns the overall out-of-bag metric estimations. The OOB estimate is quite accurate given that enough trees have been grown. Otherwise, the OOB error estimate can be biased upward.
Returns:
the overall out-of-bag metric estimations.
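A small sketch of inspecting the OOB estimates; "model" stands for any previously fitted forest, and printing the returned record is used because its toString() lists all estimations.

    import smile.regression.RandomForest;

    public class OobReport {
        // Prints the out-of-bag estimations of a previously fitted forest. Since
        // RegressionMetrics is a record, its toString() lists all estimations.
        public static void print(RandomForest model) {
            var oob = model.metrics();
            System.out.println(oob);
        }
    }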
-
importance
public double[] importance()
Returns the variable importance. Every time a split of a node is made on a variable, the impurity criterion for the two descendant nodes is less than that of the parent node. Adding up the decreases for each individual variable over all trees in the forest gives a fast measure of variable importance that is often very consistent with the permutation importance measure.
Returns:
the variable importance.
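A sketch of ranking predictors by this importance measure; it only uses importance() and the standard library, and mapping indices back to column names (e.g. via schema()) is deliberately left out.

    import java.util.stream.IntStream;
    import smile.regression.RandomForest;

    public class ImportanceReport {
        // Prints predictor indices ranked by decreasing impurity-based importance.
        public static void print(RandomForest model) {
            double[] imp = model.importance();
            int[] order = IntStream.range(0, imp.length)
                    .boxed()
                    .sorted((i, j) -> Double.compare(imp[j], imp[i]))
                    .mapToInt(Integer::intValue)
                    .toArray();
            for (int i : order) {
                System.out.printf("feature %d: %.4f%n", i, imp[i]);
            }
        }
    }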
-
size
public int size()
Returns the number of trees in the model.
Returns:
the number of trees in the model.
-
models
public RandomForest.Model[] models()
Returns the base models.
-
trees
Description copied from interface: TreeSHAP
Returns the decision trees.
-
trim
public RandomForest trim(int ntrees)
Trims the tree model set to a smaller size in case of over-fitting. If the extra decision trees in the model do not improve performance, removing them reduces the model size and speeds up prediction.
Parameters:
ntrees - the new (smaller) size of the tree model set.
Returns:
the trimmed model.
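A short sketch; the cutoff of 100 trees is arbitrary and "model" is a previously fitted forest whose OOB error is assumed to stop improving beyond that ensemble size.

    import smile.regression.RandomForest;

    public class TrimExample {
        // Keeps only the first 100 trees of a previously fitted forest,
        // e.g. when extra trees no longer improve the OOB error.
        public static RandomForest shrink(RandomForest model) {
            return model.trim(100);
        }
    }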
-
merge
public RandomForest merge(RandomForest other)
Merges two random forests.
Parameters:
other - the model to merge with.
Returns:
the merged model.
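A sketch of combining forests trained separately, for example on different workers over the same data; the column name "y" is a placeholder and data loading is omitted.

    import smile.data.DataFrame;
    import smile.data.formula.Formula;
    import smile.regression.RandomForest;

    public class MergeExample {
        // Trains two forests independently and merges them into one larger ensemble.
        public static RandomForest combine(DataFrame data) {
            Formula formula = Formula.lhs("y");
            RandomForest a = RandomForest.fit(formula, data);
            RandomForest b = RandomForest.fit(formula, data);
            RandomForest merged = a.merge(b);
            System.out.println(merged.size()); // expected to be a.size() + b.size()
            return merged;
        }
    }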
-
predict
public double predict(Tuple x)
Description copied from interface: Regression
Predicts the dependent variable of an instance.
Specified by:
predict in interface Regression<Tuple>
Parameters:
x - an instance.
Returns:
the predicted value of the dependent variable.
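A sketch of scoring a single row and a whole frame; the single-row call uses this Tuple overload, while the frame-level call uses the predict(DataFrame) overload inherited from DataFrameRegression.

    import smile.data.DataFrame;
    import smile.regression.RandomForest;

    public class PredictExample {
        // Scores the first row (Tuple overload) and the whole frame
        // (DataFrameRegression overload) with a previously fitted forest.
        public static void score(RandomForest model, DataFrame data) {
            double one = model.predict(data.get(0));
            double[] all = model.predict(data);
            System.out.println(one + " / " + all.length + " predictions");
        }
    }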
-
test
public double[][] test(DataFrame data)
Tests the model on a validation dataset.
Parameters:
data - the test data set.
Returns:
the predictions with the first 1, 2, ..., ntrees regression trees.
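A sketch of turning these per-ensemble-size predictions into an error curve; it assumes the first index of the returned matrix runs over the number of trees used, and the truth vector y is supplied by the caller.

    import smile.data.DataFrame;
    import smile.regression.RandomForest;

    public class LearningCurve {
        // Prints RMSE as a function of the number of trees used for prediction,
        // assuming prediction[t] holds the predictions made with the first t+1 trees.
        public static void print(RandomForest model, DataFrame test, double[] y) {
            double[][] prediction = model.test(test);
            for (int t = 0; t < prediction.length; t++) {
                double sse = 0.0;
                for (int i = 0; i < y.length; i++) {
                    double e = prediction[t][i] - y[i];
                    sse += e * e;
                }
                System.out.printf("%d trees: RMSE = %.4f%n", t + 1, Math.sqrt(sse / y.length));
            }
        }
    }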
-