Survival Random Forest SRC Learner
Random survival forest.
Calls randomForestSRC::rfsrc() from package randomForestSRC.
Prediction types
This learner returns two prediction types:
distr: a survival matrix in two dimensions, where observations are represented in rows and (unique event) time points in columns. Calculated internally using randomForestSRC::predict.rfsrc().
crank: the expected mortality, computed using mlr3proba::.surv_return().
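As a rough illustration of how a crank can be derived from the distr matrix, the base-R sketch below computes one expected-mortality score per observation. It assumes that expected mortality is the sum of the cumulative hazard \(H(t) = -\log S(t)\) over the time points (the definition used by Ishwaran, 2008) and that every survival probability is strictly positive. The helper expected_mortality() is hypothetical and is not the source code of mlr3proba::.surv_return().

```r
# Hypothetical helper (not mlr3proba code): expected mortality per row of a
# survival matrix with observations in rows and time points in columns.
# Assumes mortality = sum over time points of H(t) = -log(S(t)).
expected_mortality <- function(surv) {
  stopifnot(is.matrix(surv), all(surv > 0 & surv <= 1))
  cumhaz <- -log(surv)  # cumulative hazard at each time point
  rowSums(cumhaz)       # one score per observation; higher means riskier
}

# Toy survival matrix: 2 observations x 3 time points
surv <- rbind(
  c(0.9, 0.8, 0.7),  # observation with higher survival probabilities
  c(0.6, 0.4, 0.2)   # observation with lower survival probabilities
)
mort <- expected_mortality(surv)
print(mort)
```

The second observation receives the larger score, so sorting by this value orders observations by predicted risk, which is how a crank is used.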
Meta Information
Task type: “surv”
Predict Types: “crank”, “distr”
Feature Types: “logical”, “integer”, “numeric”, “factor”
Required Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma
Id | Type | Default | Levels | Range |
ntree | integer | 500 | - | \([1, \infty)\) |
mtry | integer | - | - | \([1, \infty)\) |
mtry.ratio | numeric | - | - | \([0, 1]\) |
nodesize | integer | 15 | - | \([1, \infty)\) |
nodedepth | integer | - | - | \([1, \infty)\) |
splitrule | character | logrank | logrank, bs.gradient, logrankscore | - |
nsplit | integer | 10 | - | \([0, \infty)\) |
importance | character | FALSE | FALSE, TRUE, none, permute, random, anti | - |
block.size | integer | 10 | - | \([1, \infty)\) |
bootstrap | character | by.root | by.root, by.node, none, by.user | - |
samptype | character | swor | swor, swr | - |
samp | untyped | - | - | - |
membership | logical | FALSE | TRUE, FALSE | - |
sampsize | untyped | - | - | - |
sampsize.ratio | numeric | - | - | \([0, 1]\) |
na.action | character | na.omit | na.omit, na.impute | - |
nimpute | integer | 1 | - | \([1, \infty)\) |
ntime | integer | 150 | - | \([0, \infty)\) |
cause | integer | - | - | \([1, \infty)\) |
proximity | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
distance | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
forest.wt | character | FALSE | FALSE, TRUE, inbag, oob, all | - |
xvar.wt | untyped | - | - | - |
split.wt | untyped | - | - | - |
forest | logical | TRUE | TRUE, FALSE | - |
var.used | character | FALSE | FALSE, all.trees, by.tree | - |
split.depth | character | FALSE | FALSE, all.trees, by.tree | - |
seed | integer | - | - | \((-\infty, -1]\) |
do.trace | logical | FALSE | TRUE, FALSE | - |
statistics | logical | FALSE | TRUE, FALSE | - |
get.tree | untyped | - | - | - |
outcome | character | train | train, test | - |
ptn.count | integer | 0 | - | \([0, \infty)\) |
estimator | character | nelson | nelson, kaplan | - |
cores | integer | 1 | - | \([1, \infty)\) |
save.memory | logical | FALSE | TRUE, FALSE | - |
perf.type | character | - | none | - |
case.depth | logical | FALSE | TRUE, FALSE | - |
Custom mlr3 parameters
estimator: Hidden parameter that controls the type of estimator used to derive the survival function during prediction. The default value, "nelson", uses a bootstrapped Nelson-Aalen estimator for the cumulative hazard function \(H(t)\) (Ishwaran, 2008), from which we calculate \(S(t) = \exp(-H(t))\), whereas "kaplan" uses a bootstrapped Kaplan-Meier estimator to estimate \(S(t)\) directly.
mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.
sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.
cores: This value is set as the option rf.cores during training and is 1 by default.
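The ratio-to-count rule for mtry and sampsize, together with the requirement that each pair of parameters is mutually exclusive, can be sketched in base R as follows. The helper resolve_count() and its argument names are hypothetical; the sketch only mirrors the documented formula and is not the learner's internal code.

```r
# Hypothetical helper mirroring the documented rule
#   value = max(ceiling(ratio * n), 1)
# where n is n_features (for mtry) or n_obs (for sampsize).
resolve_count <- function(count = NULL, ratio = NULL, n) {
  if (!is.null(count) && !is.null(ratio)) {
    stop("'count' and 'ratio' are mutually exclusive")
  }
  if (is.null(ratio)) {
    return(count)  # in this sketch, NULL simply means "not set"
  }
  stopifnot(ratio >= 0, ratio <= 1, n >= 1)
  max(ceiling(ratio * n), 1)
}

mtry <- resolve_count(ratio = 0.5, n = 6)         # ceiling(3) = 3
sampsize <- resolve_count(ratio = 0.25, n = 670)  # ceiling(167.5) = 168
floor_one <- resolve_count(ratio = 0, n = 6)      # max(0, 1) = 1
```

Clamping at 1 guarantees that a very small ratio never produces zero candidate variables or an empty subsample.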
Initial parameter values
ntime: Number of time points onto which the observed event times are coerced for the estimated survival function during prediction. We changed the default value of 150 to 0 in order to be in line with other random survival forest learners and use all the unique event times from the train set.
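To make the default "nelson" estimator and the ntime = 0 setting more concrete, the base-R sketch below computes a Nelson-Aalen estimate of \(H(t)\) at every unique event time and converts it with \(S(t) = \exp(-H(t))\). It is a single, non-bootstrapped estimate on toy data (status 1 = event, 0 = censored); randomForestSRC instead computes such estimates in the terminal nodes of each tree and averages them over the forest. The function name nelson_aalen_surv() is hypothetical.

```r
# Hypothetical helper: Nelson-Aalen cumulative hazard at each unique event
# time, H(t) = sum over event times t_j <= t of d_j / n_j, then S(t) = exp(-H(t)).
nelson_aalen_surv <- function(time, status) {
  event_times <- sort(unique(time[status == 1]))
  # d_j: number of events at t_j; n_j: number still at risk just before t_j
  d <- vapply(event_times, function(t) sum(time == t & status == 1), numeric(1))
  n_risk <- vapply(event_times, function(t) sum(time >= t), numeric(1))
  cumhaz <- cumsum(d / n_risk)
  data.frame(time = event_times, cumhaz = cumhaz, surv = exp(-cumhaz))
}

time   <- c(2, 3, 3, 5, 7, 8)
status <- c(1, 1, 0, 1, 0, 1)
fit <- nelson_aalen_surv(time, status)
print(fit)
```

With ntime = 0, the time grid is exactly this set of unique event times from the training data, rather than a coarser grid of 150 points.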
Examples
# Define the Learner
learner = mlr3::lrn("surv.rfsrc", importance = "TRUE")
print(learner)
#> <LearnerSurvRandomForestSRC:surv.rfsrc>: Random Forest
#> * Model: -
#> * Parameters: importance=TRUE, ntime=0
#> * Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC, pracma
#> * Predict Types: [crank], distr
#> * Feature Types: logical, integer, numeric, factor
#> * Properties: importance, missings, oob_error, weights
# Define a Task
task = mlr3::tsk("grace")
# Create train and test set
ids = mlr3::partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> Sample size: 670
#> Number of deaths: 213
#> Number of trees: 500
#> Forest terminal node size: 15
#> Average no. of terminal nodes: 26.722
#> No. of variables tried at each split: 3
#> Total no. of variables: 6
#> Resampling used to grow trees: swor
#> Resample size used to grow trees: 423
#> Analysis: RSF
#> Family: surv
#> Splitting rule: logrank *random*
#> Number of random split points: 10
#> (OOB) CRPS: 16.64868647
#> (OOB) stand. CRPS: 0.09406038
#> (OOB) Requested performance error: 0.16563417
print(learner$importance())
#> revascdays revasc age sysbp los stchange
#> 0.42597100 0.27735197 0.11553466 0.07151100 0.04637084 0.00350123
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> surv.cindex
#> 0.8410948