Survival Random Forest SRC Learner
mlr_learners_surv.rfsrc.Rd
Random survival forest.
Calls randomForestSRC::rfsrc() from randomForestSRC.
Prediction types
This learner returns two prediction types:
distr: a two-dimensional survival matrix, with observations in rows and (unique event) time points in columns. Calculated using the internal randomForestSRC::predict.rfsrc() function.
crank: the expected mortality, computed using mlr3proba::.surv_return().
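Ishwaran et al. (2008) define expected mortality as the cumulative hazard summed over the event time points. Below is a minimal sketch of that quantity computed from a survival matrix laid out like the distr prediction (rows are observations, columns are time points), using H(t) = -log S(t). The helper name expected_mortality() is hypothetical; this sketch is not claimed to be the exact code inside mlr3proba::.surv_return().

```r
# Hypothetical helper (not the mlr3proba implementation): expected mortality
# as the sum of the cumulative hazard H(t) = -log S(t) over all time points.
expected_mortality <- function(surv) {
  # Clamp S(t) away from 0 so that -log() stays finite
  cumhaz <- -log(pmax(surv, .Machine$double.eps))
  # One value per row (observation): sum of H(t_j) across columns (time points)
  rowSums(cumhaz)
}

# Two observations evaluated at three event times
surv <- rbind(
  low_risk  = c(0.95, 0.90, 0.85),
  high_risk = c(0.80, 0.60, 0.40)
)
expected_mortality(surv)
```

The observation whose survival curve drops faster accumulates more hazard and therefore receives the larger crank, so a higher crank means a higher predicted risk.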
Meta Information
Task type: “surv”
Predict Types: “crank”, “distr”
Feature Types: “logical”, “integer”, “numeric”, “factor”
Required Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma
Parameters
Id | Type | Default | Levels | Range |
ntree | integer | 500 | | \([1, \infty)\) |
mtry | integer | - | | \([1, \infty)\) |
mtry.ratio | numeric | - | | \([0, 1]\) |
nodesize | integer | 15 | | \([1, \infty)\) |
nodedepth | integer | - | | \([1, \infty)\) |
splitrule | character | logrank | logrank, bs.gradient, logrankscore | - |
nsplit | integer | 10 | | \([0, \infty)\) |
importance | character | FALSE | FALSE, TRUE, none, permute, random, anti | - |
block.size | integer | 10 | | \([1, \infty)\) |
bootstrap | character | by.root | by.root, by.node, none, by.user | - |
samptype | character | swor | swor, swr | - |
samp | untyped | - | | - |
membership | logical | FALSE | TRUE, FALSE | - |
sampsize | untyped | - | | - |
sampsize.ratio | numeric | - | | \([0, 1]\) |
na.action | character | na.omit | na.omit, na.impute | - |
nimpute | integer | 1 | | \([1, \infty)\) |
ntime | integer | 150 | | \([0, \infty)\) |
cause | integer | - | | \([1, \infty)\) |
proximity | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
distance | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
forest.wt | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
xvar.wt | untyped | - | | - |
split.wt | untyped | - | | - |
forest | logical | TRUE | TRUE, FALSE | - |
var.used | character | FALSE | FALSE, all.trees, by.tree | - |
split.depth | character | FALSE | FALSE, all.trees, by.tree | - |
seed | integer | - | | \((-\infty, -1]\) |
do.trace | logical | FALSE | TRUE, FALSE | - |
statistics | logical | FALSE | TRUE, FALSE | - |
get.tree | untyped | - | | - |
outcome | character | train | train, test | - |
ptn.count | integer | 0 | | \([0, \infty)\) |
estimator | character | nelson | nelson, kaplan | - |
cores | integer | 1 | | \([1, \infty)\) |
save.memory | logical | FALSE | TRUE, FALSE | - |
perf.type | character | - | none | - |
case.depth | logical | FALSE | TRUE, FALSE | - |
Custom mlr3 parameters
estimator: Hidden parameter that controls the type of estimator used to derive the survival function during prediction. The default value is "nelson", which uses a bootstrapped Nelson-Aalen estimator for the cumulative hazard function \(H(t)\) (Ishwaran et al., 2008), from which we calculate \(S(t) = \exp(-H(t))\); "kaplan" instead uses a bootstrapped Kaplan-Meier estimator to estimate \(S(t)\) directly.
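To make the difference between the two estimator options concrete, here is a single-sample sketch of both estimators on right-censored data. It is an assumption-laden simplification: randomForestSRC computes such estimates within terminal nodes and aggregates them across the tree ensemble, which this sketch does not attempt. The helper name na_km() is hypothetical.

```r
# Single-sample Nelson-Aalen and Kaplan-Meier estimates for right-censored data.
# Hypothetical helper for illustration; the per-node estimation and ensemble
# aggregation done by randomForestSRC are omitted.
na_km <- function(time, status) {
  event_times <- sort(unique(time[status == 1]))
  # Number of events d_j and number at risk n_j at each event time t_j
  d <- vapply(event_times, function(t) sum(time == t & status == 1), integer(1))
  n <- vapply(event_times, function(t) sum(time >= t), integer(1))
  H <- cumsum(d / n)                         # Nelson-Aalen estimate of H(t)
  data.frame(
    time        = event_times,
    surv_nelson = exp(-H),                   # S(t) = exp(-H(t))
    surv_kaplan = cumprod(1 - d / n)         # Kaplan-Meier estimate of S(t)
  )
}

na_km(time = c(2, 3, 3, 5, 8), status = c(1, 1, 0, 1, 0))
```

Because exp(-x) >= 1 - x, the Nelson-Aalen based curve never lies below the Kaplan-Meier curve on the same data; the gap grows where few observations remain at risk.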
mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.
sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.
cores: This value is set as the option rf.cores during training and defaults to 1.
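The mtry.ratio and sampsize.ratio conversions share the same rule: scale the count, round up, and never go below 1. A minimal sketch of that rule, with a hypothetical helper name and arbitrary example sizes:

```r
# Hypothetical helper applying the documented conversions
#   mtry     = max(ceiling(mtry.ratio * n_features), 1)
#   sampsize = max(ceiling(sampsize.ratio * n_obs), 1)
ratio_to_count <- function(ratio, n) {
  stopifnot(is.numeric(ratio), ratio >= 0, ratio <= 1)  # both ratios live in [0, 1]
  max(ceiling(ratio * n), 1)                            # round up, never below 1
}

ratio_to_count(0.5, 6)    # mtry for 6 features: 3
ratio_to_count(0, 6)      # a ratio of 0 still yields 1
ratio_to_count(0.5, 670)  # sampsize for 670 observations: 335
```

Rounding up with a floor of 1 guarantees that any ratio in [0, 1] maps to a valid setting, which is why the ratio parameters can be tuned on a fixed range independent of the task size.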
Initial parameter values
ntime: Number of time points to which the observed event times are coerced for the estimated survival function during prediction. We changed the default value of 150 to 0, in line with other random survival forest learners, so that all unique event times from the training set are used.
References
Ishwaran H, Kogalur UB, Blackstone EH, Lauer MS (2008). “Random survival forests.” The Annals of Applied Statistics, 2(3). doi:10.1214/08-aoas169.
Breiman L (2001). “Random Forests.” Machine Learning, 45(1), 5–32. ISSN 1573-0565, doi:10.1023/A:1010933404324.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvRandomForestSRC
Methods
Method importance()
The importance scores are extracted from the model slot importance.
Returns
Named numeric().
Examples
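# Load the packages: mlr3extralearners registers the "surv.rfsrc" learner,
# mlr3proba provides survival tasks such as "grace"
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
# Define the Learner
learner = mlr3::lrn("surv.rfsrc", importance = "TRUE")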
print(learner)
#> <LearnerSurvRandomForestSRC:surv.rfsrc>: Random Forest
#> * Model: -
#> * Parameters: importance=TRUE, ntime=0
#> * Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma
#> * Predict Types: [crank], distr
#> * Feature Types: logical, integer, numeric, factor
#> * Properties: importance, missings, oob_error, weights
# Define a Task
task = mlr3::tsk("grace")
# Create train and test set
ids = mlr3::partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> Sample size: 670
#> Number of deaths: 218
#> Number of trees: 500
#> Forest terminal node size: 15
#> Average no. of terminal nodes: 26.59
#> No. of variables tried at each split: 3
#> Total no. of variables: 6
#> Resampling used to grow trees: swor
#> Resample size used to grow trees: 423
#> Analysis: RSF
#> Family: surv
#> Splitting rule: logrank *random*
#> Number of random split points: 10
#> (OOB) CRPS: 15.09158359
#> (OOB) stand. CRPS: 0.08623762
#> (OOB) Requested performance error: 0.15421667
#>
print(learner$importance())
#> revascdays revasc age sysbp los stchange
#> 0.441541973 0.301779352 0.109764012 0.057528102 0.045678476 0.001599396
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> surv.cindex
#> 0.809909
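The score above is a concordance index computed from crank. As a rough illustration only, here is a plain Harrell's C: a pair is comparable when the observation with the shorter time had an event, and concordant when that observation also has the higher crank. This is a hypothetical helper, not the mlr3proba implementation, and its tie handling may differ from surv.cindex.

```r
# Hypothetical helper: plain Harrell's concordance index from crank values
harrell_c <- function(time, status, crank) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next           # only an observed event anchors a pair
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {         # j was still at risk when i had its event
        comparable <- comparable + 1
        if (crank[i] > crank[j]) {
          concordant <- concordant + 1   # higher risk failed earlier
        } else if (crank[i] == crank[j]) {
          concordant <- concordant + 0.5 # tied risk counts as half
        }
      }
    }
  }
  concordant / comparable
}

harrell_c(time = c(1, 2, 3, 4), status = c(1, 1, 1, 0), crank = c(4, 3, 2, 1))  # 1
```

A value of 1 means the risk ranking perfectly matches the observed event order, 0.5 corresponds to random ranking, and 0 to a perfectly reversed ranking.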