
Random survival forest. Calls ranger::ranger() from package ranger.

Prediction types

This learner returns two prediction types:

  1. distr: a two-dimensional survival matrix with observations in rows and (unique event) time points in columns, calculated with ranger::predict.ranger().

  2. crank: the expected mortality, derived from the survival matrix via mlr3proba::.surv_return().
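To make the two prediction types concrete, the following base-R sketch builds a toy survival matrix in the `distr` layout and derives a risk score from it. It assumes that expected mortality is the usual random-survival-forest quantity, the sum of the cumulative hazard \(H(t) = -\log S(t)\) over the time points. This is an assumption about what mlr3proba::.surv_return() computes, not its code, and the helper expected_mortality() is hypothetical.

```r
# Toy survival matrix in the `distr` layout:
# rows = observations, columns = time points (values are made up for illustration).
surv = matrix(
  c(0.90, 0.80, 0.50,
    0.95, 0.90, 0.85),
  nrow = 2, byrow = TRUE,
  dimnames = list(NULL, c("1", "5", "10"))
)

# Hypothetical helper (assumed definition, not mlr3proba's code):
# expected mortality = sum over time points of H(t) = -log(S(t)).
expected_mortality = function(surv) {
  rowSums(-log(surv))
}

crank = expected_mortality(surv)
print(crank)  # about 1.0217 and 0.3192: observation 1 is ranked as higher risk
```

A survival probability of exactly 0 would make -log(S) infinite; this sketch does not handle that case.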

Custom mlr3 parameters

  • mtry: This hyperparameter can alternatively be set via the custom hyperparameter mtry.ratio, as mtry = max(ceiling(mtry.ratio * n_features), 1), where n_features is the number of features in the task. Note that mtry and mtry.ratio are mutually exclusive.
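The mtry.ratio conversion can be checked by hand. The sketch below transcribes the documented formula into a hypothetical helper, mtry_from_ratio(); it is an illustration of the rule, not the learner's internal code.

```r
# Hypothetical helper: transcription of the documented rule
# mtry = max(ceiling(mtry.ratio * n_features), 1).
mtry_from_ratio = function(mtry.ratio, n_features) {
  stopifnot(mtry.ratio >= 0, mtry.ratio <= 1)  # documented range of mtry.ratio
  as.integer(max(ceiling(mtry.ratio * n_features), 1))
}

mtry_from_ratio(0.5, 6)  # 3
mtry_from_ratio(0.1, 6)  # 1, since ceiling(0.6) = 1
mtry_from_ratio(0, 6)    # 1, the lower bound applies
```

With 6 features, any ratio yields at least 1 candidate variable per split, and mtry.ratio = 1 uses all 6.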

Initial parameter values

  • num.threads is initialized to 1 to avoid conflicts with parallelization via future.

Dictionary

This Learner can be instantiated via lrn():

lrn("surv.ranger")

Meta Information

  • Task type: “surv”

  • Predict Types: “crank”, “distr”

  • Feature Types: “logical”, “integer”, “numeric”, “character”, “factor”, “ordered”

  • Required Packages: mlr3, mlr3proba, mlr3extralearners, ranger

Parameters

| Id | Type | Default | Levels | Range |
|---|---|---|---|---|
| alpha | numeric | 0.5 | | \((-\infty, \infty)\) |
| always.split.variables | untyped | - | | - |
| holdout | logical | FALSE | TRUE, FALSE | - |
| importance | character | - | none, impurity, impurity_corrected, permutation | - |
| keep.inbag | logical | FALSE | TRUE, FALSE | - |
| max.depth | integer | NULL | | \([0, \infty)\) |
| min.node.size | integer | 5 | | \([1, \infty)\) |
| minprop | numeric | 0.1 | | \((-\infty, \infty)\) |
| mtry | integer | - | | \([1, \infty)\) |
| mtry.ratio | numeric | - | | \([0, 1]\) |
| num.random.splits | integer | 1 | | \([1, \infty)\) |
| num.threads | integer | 1 | | \([1, \infty)\) |
| num.trees | integer | 500 | | \([1, \infty)\) |
| oob.error | logical | TRUE | TRUE, FALSE | - |
| regularization.factor | untyped | 1 | | - |
| regularization.usedepth | logical | FALSE | TRUE, FALSE | - |
| replace | logical | TRUE | TRUE, FALSE | - |
| respect.unordered.factors | character | ignore | ignore, order, partition | - |
| sample.fraction | numeric | - | | \([0, 1]\) |
| save.memory | logical | FALSE | TRUE, FALSE | - |
| scale.permutation.importance | logical | FALSE | TRUE, FALSE | - |
| seed | integer | NULL | | \((-\infty, \infty)\) |
| split.select.weights | numeric | - | | \([0, 1]\) |
| splitrule | character | logrank | logrank, extratrees, C, maxstat | - |
| verbose | logical | TRUE | TRUE, FALSE | - |
| write.forest | logical | TRUE | TRUE, FALSE | - |
| min.bucket | integer | 3 | | \((-\infty, \infty)\) |
| time.interest | integer | NULL | | \([1, \infty)\) |
| node.stats | logical | FALSE | TRUE, FALSE | - |

References

Wright, Marvin N., Ziegler, Andreas (2017). “ranger: A Fast Implementation of Random Forests for High Dimensional Data in C++ and R.” Journal of Statistical Software, 77(1), 1–17. doi:10.18637/jss.v077.i01.

Breiman, Leo (2001). “Random Forests.” Machine Learning, 45(1), 5–32. ISSN 1573-0565, doi:10.1023/A:1010933404324.

Author

be-marc

Super classes

mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvRanger

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerSurvRanger$new()


Method importance()

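The importance scores are extracted from the model slot variable.importance and returned in decreasing order. Scores are only available if the importance parameter was set to a value other than "none" before training.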

Usage

LearnerSurvRanger$importance()

Returns

Named numeric().
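Because importance() returns a named numeric vector, ordinary vector operations apply to it. As a sketch, the hypothetical helper top_features() below picks the names of the k highest scores. The input reuses the scores printed in the Examples section; with a trained learner one would pass learner$importance() instead.

```r
# Importance scores as printed in the Examples section (named numeric vector).
imp = c(revascdays = 0.15240296, revasc = 0.07015780, age = 0.03495694,
        los = 0.02346363, sysbp = 0.01186921, stchange = 0.00554037)

# Hypothetical helper: names of the k most important features.
# min() guards against asking for more features than exist.
top_features = function(imp, k) {
  names(sort(imp, decreasing = TRUE))[seq_len(min(k, length(imp)))]
}

top_features(imp, 2)  # "revascdays" "revasc"
```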


Method oob_error()

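The out-of-bag error is extracted from the model slot prediction.error. For survival forests, ranger reports this as one minus Harrell's C-index, which is why the example output labels it "OOB prediction error (1-C)".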

Usage

LearnerSurvRanger$oob_error()

Returns

numeric(1).
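For survival forests the OOB error is one minus a concordance index. The following plain-R sketch of Harrell's C-index shows what that quantity measures on a toy data set. The helper harrell_c() and the data are hypothetical, and ranger's internal C++ implementation may treat tied times and risks differently.

```r
# Hypothetical helper: Harrell's C-index for right-censored data.
# A pair (i, j) is comparable when i has an observed event (status == 1)
# and a strictly earlier time than j; it is concordant when i also has
# the higher predicted risk. Ties in risk count as half-concordant.
harrell_c = function(time, status, risk) {
  concordant = 0
  comparable = 0
  n = length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {
        comparable = comparable + 1
        if (risk[i] > risk[j]) {
          concordant = concordant + 1
        } else if (risk[i] == risk[j]) {
          concordant = concordant + 0.5
        }
      }
    }
  }
  concordant / comparable
}

# Toy data: observation 3 is censored.
time   = c(2, 4, 6, 8)
status = c(1, 1, 0, 1)
risk   = c(0.9, 0.3, 0.3, 0.4)

c_index = harrell_c(time, status, risk)  # 3.5 / 5 = 0.7
oob_error_like = 1 - c_index              # 0.3
```

A perfect risk ranking gives C = 1 (an error of 0), while a random ranking gives a value around 0.5.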


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerSurvRanger$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

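# Attach the packages that register the "grace" task and the "surv.ranger" learner
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

# Define the Learner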
learner = mlr3::lrn("surv.ranger", importance = "permutation")
print(learner)
#> <LearnerSurvRanger:surv.ranger>: Random Forest
#> * Model: -
#> * Parameters: importance=permutation, num.threads=1
#> * Packages: mlr3, mlr3proba, mlr3extralearners, ranger
#> * Predict Types:  [crank], distr
#> * Feature Types: logical, integer, numeric, character, factor, ordered
#> * Properties: importance, oob_error, weights

# Define a Task
task = mlr3::tsk("grace")

# Create train and test set
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
#> Ranger result
#> 
#> Call:
#>  ranger::ranger(formula = NULL, dependent.variable.name = targets[1L],      status.variable.name = targets[2L], data = task$data(), case.weights = task$weights$weight,      importance = "permutation", num.threads = 1L) 
#> 
#> Type:                             Survival 
#> Number of trees:                  500 
#> Sample size:                      670 
#> Number of independent variables:  6 
#> Mtry:                             3 
#> Target node size:                 3 
#> Variable importance mode:         permutation 
#> Splitrule:                        logrank 
#> Number of unique death times:     89 
#> OOB prediction error (1-C):       0.1635784 
print(learner$importance())
#> revascdays     revasc        age        los      sysbp   stchange 
#> 0.15240296 0.07015780 0.03495694 0.02346363 0.01186921 0.00554037 

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> surv.cindex 
#>   0.8356969