
Random forest for classification. Calls randomForestSRC::rfsrc() from randomForestSRC.

Dictionary

This Learner can be instantiated via lrn():

lrn("classif.rfsrc")

Meta Information

  • Task type: “classif”

  • Predict Types: “response”, “prob”

  • Feature Types: “logical”, “integer”, “numeric”, “factor”

  • Required Packages: mlr3, mlr3extralearners, randomForestSRC

Parameters

Id              Type       Default  Range     Levels
ntree           integer    500      [1, ∞)
mtry            integer    -        [1, ∞)
mtry.ratio      numeric    -        [0, 1]
nodesize        integer    15       [1, ∞)
nodedepth       integer    -        [1, ∞)
splitrule       character  gini     -         gini, auc, entropy
nsplit          integer    10       [0, ∞)
importance      character  FALSE    -         FALSE, TRUE, none, permute, random, anti
block.size      integer    10       [1, ∞)
bootstrap       character  by.root  -         by.root, by.node, none, by.user
samptype        character  swor     -         swor, swr
samp            untyped    -        -
membership      logical    FALSE    -         TRUE, FALSE
sampsize        untyped    -        -
sampsize.ratio  numeric    -        [0, 1]
na.action       character  na.omit  -         na.omit, na.impute
nimpute         integer    1        [1, ∞)
ntime           integer    -        [1, ∞)
cause           integer    -        [1, ∞)
proximity       character  FALSE    -         FALSE, TRUE, inbag, oob, all
distance        character  FALSE    -         FALSE, TRUE, inbag, oob, all
forest.wt       character  FALSE    -         FALSE, TRUE, inbag, oob, all
xvar.wt         untyped    -        -
split.wt        untyped    -        -
forest          logical    TRUE     -         TRUE, FALSE
var.used        character  FALSE    -         FALSE, all.trees, by.tree
split.depth     character  FALSE    -         FALSE, all.trees, by.tree
seed            integer    -        (-∞, -1]
do.trace        logical    FALSE    -         TRUE, FALSE
statistics      logical    FALSE    -         TRUE, FALSE
get.tree        untyped    -        -
outcome         character  train    -         train, test
ptn.count       integer    0        [0, ∞)
cores           integer    1        [1, ∞)
save.memory     logical    FALSE    -         TRUE, FALSE
perf.type       character  -        -         gmean, misclass, brier, none
case.depth      logical    FALSE    -         TRUE, FALSE
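As a minimal sketch of how the ids in this table are used (assuming mlr3, mlr3extralearners and randomForestSRC are installed; the values are illustrative only), hyperparameters can be passed to lrn() at construction time or changed later through the learner's param_set:

```r
library(mlr3)
library(mlr3extralearners)

# Hyperparameter ids and levels follow the table above
learner = lrn("classif.rfsrc",
  ntree = 100,            # fewer trees than the default of 500
  splitrule = "entropy",  # one of: gini, auc, entropy
  nodesize = 5,
  predict_type = "prob"   # an mlr3 Learner field, not an rfsrc argument
)

# Values can also be changed after construction
learner$param_set$values$samptype = "swr"

print(learner$param_set$values)
```

Note that predict_type is a field of the mlr3 Learner rather than a hyperparameter of rfsrc(); lrn() accepts both in the same call.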

Custom mlr3 parameters

  • mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.

  • sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.

  • cores: During training, this value is set as the R option rf.cores, which controls how many cores randomForestSRC uses. It defaults to 1.
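The ratio-to-count conversion described above can be reproduced in base R. The helper ratio_to_count() below is hypothetical and only mirrors the documented formula; in practice you would pass mtry.ratio or sampsize.ratio to lrn() and let the learner perform the conversion.

```r
# Hypothetical helper mirroring the documented formula:
# max(ceiling(ratio * n), 1)
ratio_to_count = function(ratio, n) {
  max(ceiling(ratio * n), 1)
}

# mtry from mtry.ratio, for a task with 60 features (e.g. sonar)
ratio_to_count(0.25, 60)  # 15
# A ratio of 0 is floored at 1, so at least one feature is always tried
ratio_to_count(0, 60)     # 1
# sampsize from sampsize.ratio, for 139 training observations
ratio_to_count(0.5, 139)  # 70 (69.5 rounded up)
```

The max(..., 1) guard is what keeps very small ratios valid, since both mtry and sampsize must be at least 1.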

References

Breiman, Leo (2001). “Random Forests.” Machine Learning, 45(1), 5–32. ISSN 1573-0565, doi:10.1023/A:1010933404324.

Author

RaphaelS1

Super classes

mlr3::Learner -> mlr3::LearnerClassif -> LearnerClassifRandomForestSRC

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerClassifRandomForestSRC$new()


Method importance()

The importance scores are extracted from the model slot importance. The overall scores (the 'all' column) are returned, sorted in decreasing order.

Usage

LearnerClassifRandomForestSRC$importance()

Returns

Named numeric().
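A minimal sketch of requesting and reading importance scores, assuming mlr3, mlr3extralearners and randomForestSRC are installed. Scores exist only if importance is set when the forest is grown; "permute" is one of the documented levels. The last two lines assume that, as in randomForestSRC's classification output, the raw importance is a matrix whose "all" column holds the overall scores.

```r
library(mlr3)
library(mlr3extralearners)

task = tsk("sonar")
# Importance has to be requested when the forest is grown
learner = lrn("classif.rfsrc", ntree = 100, importance = "permute")
learner$train(task)

# Named numeric vector, one score per feature, highest first
imp = learner$importance()
head(imp, 5)

# Raw scores from the fitted rfsrc object (assumed column name "all")
raw = learner$model$importance[, "all"]
head(sort(raw, decreasing = TRUE), 5)
```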


Method selected_features()

Selected features are extracted from the model slot var.used.

Usage

LearnerClassifRandomForestSRC$selected_features()

Returns

character().
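A minimal sketch, assuming the same packages are installed. It also assumes that randomForestSRC fills the var.used slot only when the var.used hyperparameter is set to "all.trees" or "by.tree", so the learner is configured accordingly.

```r
library(mlr3)
library(mlr3extralearners)

task = tsk("sonar")
# Ask rfsrc to record which variables the trees split on
learner = lrn("classif.rfsrc", ntree = 50, var.used = "all.trees")
learner$train(task)

feats = learner$selected_features()
print(feats)

# One possible use: restrict a copy of the task to the selected features
task_small = task$clone()$select(feats)
```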


Method oob_error()

The out-of-bag (OOB) error is extracted from the model slot err.rate.

Usage

LearnerClassifRandomForestSRC$oob_error()

Returns

numeric().
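A minimal sketch of reading the OOB error after training, assuming the same packages are installed. No separate test set is needed, because each tree is evaluated on the observations left out of its own sample. ntree = 100 is an illustrative choice (a multiple of the default block.size of 10).

```r
library(mlr3)
library(mlr3extralearners)

task = tsk("sonar")
learner = lrn("classif.rfsrc", ntree = 100)
learner$train(task)

# Single OOB error value for the fitted forest
oob = learner$oob_error()
print(oob)
```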


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerClassifRandomForestSRC$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
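A minimal sketch of why deep = TRUE matters, assuming mlr3 and mlr3extralearners are installed: a deep clone also copies nested R6 objects such as the param_set, so changing the copy leaves the original untouched.

```r
library(mlr3)
library(mlr3extralearners)

learner = lrn("classif.rfsrc", ntree = 100)

# The deep clone gets its own param_set
copy = learner$clone(deep = TRUE)
copy$param_set$values$ntree = 10

learner$param_set$values$ntree  # still 100
copy$param_set$values$ntree     # 10
```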

Examples

# Define the Learner
learner = mlr3::lrn("classif.rfsrc", importance = "TRUE")
print(learner)
#> <LearnerClassifRandomForestSRC:classif.rfsrc>: Random Forest
#> * Model: -
#> * Parameters: importance=TRUE
#> * Packages: mlr3, mlr3extralearners, randomForestSRC
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric, factor
#> * Properties: importance, missings, multiclass, oob_error, twoclass,
#>   weights

# Define a Task
task = mlr3::tsk("sonar")
# Create train and test set
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
#>                          Sample size: 139
#>            Frequency of class labels: 74, 65
#>                      Number of trees: 500
#>            Forest terminal node size: 1
#>        Average no. of terminal nodes: 16.76
#> No. of variables tried at each split: 8
#>               Total no. of variables: 60
#>        Resampling used to grow trees: swor
#>     Resample size used to grow trees: 88
#>                             Analysis: RF-C
#>                               Family: class
#>                       Splitting rule: gini *random*
#>        Number of random split points: 10
#>                     Imbalanced ratio: 1.1385
#>                    (OOB) Brier score: 0.13325482
#>         (OOB) Normalized Brier score: 0.53301927
#>                            (OOB) AUC: 0.91683992
#>                       (OOB) Log-loss: 0.42236712
#>                         (OOB) PR-AUC: 0.90714404
#>                         (OOB) G-mean: 0.79473318
#>    (OOB) Requested performance error: 0.20143885, 0.16216216, 0.24615385
#> 
#> Confusion matrix:
#> 
#>           predicted
#>   observed  M  R class.error
#>          M 62 12      0.1622
#>          R 16 49      0.2462
#> 
#>       (OOB) Misclassification rate: 0.2014388
print(learner$importance())
#>          V11           V9          V48          V47          V45          V10 
#> 0.1010851367 0.0614457394 0.0373809371 0.0326200423 0.0318032166 0.0315247926 
#>          V17          V12          V16          V46          V49           V5 
#> 0.0264013494 0.0233198607 0.0232601185 0.0229356985 0.0208056778 0.0199518511 
#>          V51          V28          V15          V52          V19          V36 
#> 0.0160038551 0.0158525164 0.0154674949 0.0147749884 0.0143662539 0.0133352673 
#>          V32          V18          V37           V4          V20          V39 
#> 0.0120487966 0.0118954535 0.0116076976 0.0107392220 0.0100327281 0.0099681902 
#>          V27          V24          V33          V43          V21           V1 
#> 0.0097075991 0.0092811252 0.0088621710 0.0088317361 0.0085802992 0.0082939182 
#>          V58          V55          V22           V6          V44          V29 
#> 0.0079993295 0.0072656342 0.0071271807 0.0070869716 0.0069781749 0.0069420285 
#>          V31          V14          V54          V34          V26           V8 
#> 0.0068300555 0.0068191532 0.0066666975 0.0065374383 0.0065256729 0.0063944993 
#>          V53          V23          V30          V42           V2          V40 
#> 0.0063655898 0.0058333821 0.0056532733 0.0055121152 0.0049107098 0.0043715856 
#>           V3          V56          V38          V59           V7          V60 
#> 0.0042073469 0.0039078275 0.0033410306 0.0033344189 0.0031952486 0.0029058894 
#>          V35          V25          V57          V50          V13          V41 
#> 0.0023171589 0.0023168141 0.0019106823 0.0017436146 0.0005965298 0.0004497040 

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> classif.ce 
#>  0.2173913
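As a variation on the single train/test split above, the following minimal sketch (assuming the same packages are installed) predicts class probabilities and evaluates the learner with 3-fold cross-validation. The number of folds and trees are illustrative choices.

```r
library(mlr3)
library(mlr3extralearners)

task = mlr3::tsk("sonar")
# Probabilities are needed for threshold-free measures such as AUC
learner = mlr3::lrn("classif.rfsrc", predict_type = "prob", ntree = 100)

# 3-fold cross-validation instead of a single holdout split
rr = mlr3::resample(task, learner, mlr3::rsmp("cv", folds = 3))

# Classification error and AUC, averaged over the folds
scores = rr$aggregate(mlr3::msrs(c("classif.ce", "classif.auc")))
print(scores)
```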