
Random survival forest. Calls randomForestSRC::rfsrc() from randomForestSRC.

Prediction types

This learner returns two prediction types:

  1. distr: a two-dimensional survival matrix, with observations in rows and (unique event) time points in columns, calculated with randomForestSRC::predict.rfsrc().

  2. crank: the expected mortality using mlr3proba::.surv_return().
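To make the two prediction types concrete, the sketch below builds a small made-up survival matrix in the distr layout (observations in rows, time points in columns) and derives an expected-mortality ranking from it. It assumes, following the expected mortality of Ishwaran et al. (2008), that the ranking is the sum of the cumulative hazard \(H(t) = -\log S(t)\) over the time points. This is not the source code of mlr3proba::.surv_return(). The helper expected_mortality() is hypothetical, and the survival values must be strictly positive for the logarithm to be finite.

```r
# Made-up survival matrix in the distr layout:
# rows = observations, columns = time points
times <- c(1, 2, 3)
surv <- matrix(
  c(0.9, 0.8, 0.7,   # observation 1: slowly decreasing survival
    0.6, 0.4, 0.2),  # observation 2: quickly decreasing survival
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, times)
)

# Hypothetical helper: expected mortality as the row-wise sum of H(t) = -log(S(t))
expected_mortality <- function(surv) {
  rowSums(-log(surv))
}

crank <- expected_mortality(surv)
print(crank)
```

Higher values mean higher predicted risk, so the second observation, whose survival curve drops faster, is ranked as riskier.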

Dictionary

This Learner can be instantiated via lrn():

lrn("surv.rfsrc")

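Meta Information

  • Task type: "surv"

  • Predict Types: "crank", "distr"

  • Feature Types: "logical", "integer", "numeric", "factor"

  • Required Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma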

Parameters

Id              Type       Default  Levels                                    Range
ntree           integer    500      -                                         \([1, \infty)\)
mtry            integer    -        -                                         \([1, \infty)\)
mtry.ratio      numeric    -        -                                         \([0, 1]\)
nodesize        integer    15       -                                         \([1, \infty)\)
nodedepth       integer    -        -                                         \([1, \infty)\)
splitrule       character  logrank  logrank, bs.gradient, logrankscore        -
nsplit          integer    10       -                                         \([0, \infty)\)
importance      character  FALSE    FALSE, TRUE, none, permute, random, anti  -
block.size      integer    10       -                                         \([1, \infty)\)
bootstrap       character  by.root  by.root, by.node, none, by.user           -
samptype        character  swor     swor, swr                                 -
samp            untyped    -        -                                         -
membership      logical    FALSE    TRUE, FALSE                               -
sampsize        untyped    -        -                                         -
sampsize.ratio  numeric    -        -                                         \([0, 1]\)
na.action       character  na.omit  na.omit, na.impute                        -
nimpute         integer    1        -                                         \([1, \infty)\)
ntime           integer    150      -                                         \([0, \infty)\)
cause           integer    -        -                                         \([1, \infty)\)
proximity       character  FALSE    FALSE, TRUE, inbag, oob, all              -
distance        character  FALSE    FALSE, TRUE, inbag, oob, all              -
forest.wt       character  FALSE    FALSE, TRUE, inbag, oob, all              -
xvar.wt         untyped    -        -                                         -
split.wt        untyped    -        -                                         -
forest          logical    TRUE     TRUE, FALSE                               -
var.used        character  FALSE    FALSE, all.trees, by.tree                 -
split.depth     character  FALSE    FALSE, all.trees, by.tree                 -
seed            integer    -        -                                         \((-\infty, -1]\)
do.trace        logical    FALSE    TRUE, FALSE                               -
statistics      logical    FALSE    TRUE, FALSE                               -
get.tree        untyped    -        -                                         -
outcome         character  train    train, test                               -
ptn.count       integer    0        -                                         \([0, \infty)\)
estimator       character  nelson   nelson, kaplan                            -
cores           integer    1        -                                         \([1, \infty)\)
save.memory     logical    FALSE    TRUE, FALSE                               -
perf.type       character  -        none                                      -
case.depth      logical    FALSE    TRUE, FALSE                               -

Custom mlr3 parameters

  • estimator: Hidden parameter that controls the type of estimator used to derive the survival function during prediction. The default value, "chf", uses a bootstrapped Nelson-Aalen estimator of the cumulative hazard function \(H(t)\) (Ishwaran et al., 2008), from which \(S(t) = \exp(-H(t))\) is calculated; "surv" instead uses a bootstrapped Kaplan-Meier estimator to estimate \(S(t)\) directly.

  • mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.

  • sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.

  • cores: This value is set as the option rf.cores during training and is set to 1 by default.
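The two estimator options described above differ only in how \(S(t)\) is obtained. The base-R sketch below compares them on a single made-up sample. It uses plain (non-bootstrapped) Nelson-Aalen and Kaplan-Meier estimators and ignores the tree ensemble and the bootstrapping that randomForestSRC performs. The function estimate_survival() is hypothetical.

```r
# Made-up right-censored data: follow-up times and event indicators
# (1 = event observed, 0 = censored)
time   <- c(2, 3, 3, 5, 7, 8)
status <- c(1, 1, 0, 1, 0, 1)

# Hypothetical helper: single-sample survival estimates at the unique event times
estimate_survival <- function(time, status) {
  event_times <- sort(unique(time[status == 1]))
  # Number at risk just before each event time, and number of events at it
  n_risk  <- sapply(event_times, function(t) sum(time >= t))
  n_event <- sapply(event_times, function(t) sum(time == t & status == 1))
  hazard  <- n_event / n_risk
  data.frame(
    time           = event_times,
    # Nelson-Aalen H(t) = cumulative sum of hazards, then S(t) = exp(-H(t))
    S_nelson_aalen = exp(-cumsum(hazard)),
    # Kaplan-Meier product-limit estimate of S(t)
    S_kaplan_meier = cumprod(1 - hazard)
  )
}

est <- estimate_survival(time, status)
print(est)
```

Because \(\exp(-h) \ge 1 - h\) for every hazard increment \(h\), the Nelson-Aalen-based curve never lies below the Kaplan-Meier curve; the gap is largest where few observations remain at risk.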
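The mtry.ratio and sampsize.ratio alternatives reduce to the formulas quoted above. As a minimal base-R sketch, the two helpers below apply those formulas directly; mtry_from_ratio() and sampsize_from_ratio() are hypothetical names, not part of mlr3extralearners.

```r
# Hypothetical helpers that apply the documented conversion formulas
mtry_from_ratio <- function(mtry.ratio, n_features) {
  max(ceiling(mtry.ratio * n_features), 1)
}
sampsize_from_ratio <- function(sampsize.ratio, n_obs) {
  max(ceiling(sampsize.ratio * n_obs), 1)
}

mtry_from_ratio(0.5, 6)         # 3 of 6 features tried at each split
mtry_from_ratio(0, 6)           # 1, because of the lower bound
sampsize_from_ratio(0.5, 670)   # 335 observations drawn per tree
```

The max(..., 1) term guarantees at least one feature or observation even when the ratio is 0, which the \([0, 1]\) range of mtry.ratio and sampsize.ratio allows.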

Initial parameter values

  • ntime: Number of time points to which the observed event times are coarsened for the survival function estimated during prediction. The initial value is changed from the randomForestSRC default of 150 to 0, which uses all unique event times from the training set, in line with other random survival forest learners.
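The effect of ntime on the time axis of the distr matrix can be illustrated with a hypothetical base-R helper, time_grid(). Only the ntime = 0 case (all unique event times from the training set) comes from the text above; keeping roughly evenly spaced unique event times when ntime > 0 is an assumption, and randomForestSRC's own grid construction may differ.

```r
# Hypothetical helper: choose the time points used for the survival curve.
# ntime = 0 keeps every unique event time; otherwise (assumption) keep
# about ntime evenly spaced unique event times.
time_grid <- function(time, status, ntime) {
  event_times <- sort(unique(time[status == 1]))
  if (ntime == 0 || ntime >= length(event_times)) {
    return(event_times)
  }
  idx <- unique(round(seq(1, length(event_times), length.out = ntime)))
  event_times[idx]
}

# Made-up training data (1 = event, 0 = censored)
time   <- c(5, 3, 8, 3, 10, 12)
status <- c(1, 1, 0, 1, 1, 1)

time_grid(time, status, ntime = 0)  # 3 5 10 12
time_grid(time, status, ntime = 2)  # 3 12
```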

References

Ishwaran H, Kogalur UB, Blackstone EH, Lauer MS (2008). “Random survival forests.” The Annals of Applied Statistics, 2(3). doi:10.1214/08-aoas169.

Breiman L (2001). “Random Forests.” Machine Learning, 45(1), 5–32. doi:10.1023/A:1010933404324.

Author

RaphaelS1

Super classes

mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvRandomForestSRC

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerSurvRandomForestSRC$new()


Method importance()

The importance scores are extracted from the model slot importance.

Usage

LearnerSurvRandomForestSRC$importance()

Returns

Named numeric().


Method selected_features()

Selected features are extracted from the model slot var.used.

Usage

LearnerSurvRandomForestSRC$selected_features()

Returns

character().


Method oob_error()

The OOB error is extracted from the model slot err.rate.

Usage

LearnerSurvRandomForestSRC$oob_error()

Returns

numeric().


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerSurvRandomForestSRC$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Define the Learner
learner = mlr3::lrn("surv.rfsrc", importance = "TRUE")
print(learner)
#> <LearnerSurvRandomForestSRC:surv.rfsrc>: Random Forest
#> * Model: -
#> * Parameters: importance=TRUE, ntime=0
#> * Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma
#> * Predict Types:  [crank], distr
#> * Feature Types: logical, integer, numeric, factor
#> * Properties: importance, missings, oob_error, weights

# Define a Task
task = mlr3::tsk("grace")

# Create train and test set
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
#>                          Sample size: 670
#>                     Number of deaths: 206
#>                      Number of trees: 500
#>            Forest terminal node size: 15
#>        Average no. of terminal nodes: 27.05
#> No. of variables tried at each split: 3
#>               Total no. of variables: 6
#>        Resampling used to grow trees: swor
#>     Resample size used to grow trees: 423
#>                             Analysis: RSF
#>                               Family: surv
#>                       Splitting rule: logrank *random*
#>        Number of random split points: 10
#>                           (OOB) CRPS: 16.73000411
#>                    (OOB) stand. CRPS: 0.09560002
#>    (OOB) Requested performance error: 0.16039056
#> 
print(learner$importance())
#>  revascdays      revasc         age       sysbp         los    stchange 
#> 0.418851933 0.269610078 0.178833080 0.070602211 0.051400184 0.005023546 

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> surv.cindex 
#>   0.8326797