Random Forest Competing Risks Learner
Source: R/learner_randomForestSRC_cmprsk_rfsrc.R
Random survival forests for competing risks.
Calls randomForestSRC::rfsrc() from package randomForestSRC.
Meta Information
Task type: “cmprsk”
Predict Types: “cif”
Feature Types: “logical”, “integer”, “numeric”, “factor”
Required Packages: mlr3, mlr3proba, mlr3extralearners, randomForestSRC
Parameters
Id | Type | Default | Levels | Range
ntree | integer | 500 | - | \([1, \infty)\)
mtry | integer | - | - | \([1, \infty)\)
mtry.ratio | numeric | - | - | \([0, 1]\)
nodesize | integer | 15 | - | \([1, \infty)\)
nodedepth | integer | - | - | \([1, \infty)\)
splitrule | character | logrankCR | logrankCR, logrank | -
nsplit | integer | 10 | - | \([0, \infty)\)
importance | character | FALSE | FALSE, TRUE, none, anti, permute, random | -
block.size | integer | 10 | - | \([1, \infty)\)
bootstrap | character | by.root | by.root, by.node, none, by.user | -
samptype | character | swor | swor, swr | -
samp | untyped | - | - | -
membership | logical | FALSE | TRUE, FALSE | -
sampsize | untyped | - | - | -
sampsize.ratio | numeric | - | - | \([0, 1]\)
na.action | character | na.omit | na.omit, na.impute | -
nimpute | integer | 1 | - | \([1, \infty)\)
ntime | integer | 150 | - | \([0, \infty)\)
cause | untyped | - | - | -
proximity | character | FALSE | FALSE, TRUE, inbag, oob, all | -
distance | character | FALSE | FALSE, TRUE, inbag, oob, all | -
forest.wt | character | FALSE | FALSE, TRUE, inbag, oob, all | -
xvar.wt | untyped | - | - | -
split.wt | untyped | - | - | -
forest | logical | TRUE | TRUE, FALSE | -
var.used | character | FALSE | FALSE, all.trees | -
split.depth | character | FALSE | FALSE, all.trees, by.tree | -
seed | integer | - | - | \((-\infty, -1]\)
do.trace | logical | FALSE | TRUE, FALSE | -
statistics | logical | FALSE | TRUE, FALSE | -
get.tree | untyped | - | - | -
outcome | character | train | train, test | -
ptn.count | integer | 0 | - | \([0, \infty)\)
cores | integer | 1 | - | \([1, \infty)\)
save.memory | logical | FALSE | TRUE, FALSE | -
perf.type | character | - | none | -
case.depth | logical | FALSE | TRUE, FALSE | -
marginal.xvar | untyped | NULL | - | -
Initial parameter values
ntime: Number of time points to coerce the observed event times for use in the estimated cumulative incidence functions during prediction. We changed the default value of 150 to 0, meaning we now use all the unique event times from the train set across all competing causes.
Custom mlr3 parameters
mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.
sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.
cores: This value is set as the option rf.cores during training and is set to 1 by default (see the sketch after this list).
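A minimal sketch of these custom parameters; the values 0.3, 0.6 and 2 are illustrative choices, not defaults:
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
learner = lrn("cmprsk.rfsrc",
  mtry.ratio     = 0.3, # resolved to mtry = max(ceiling(0.3 * n_features), 1)
  sampsize.ratio = 0.6, # resolved to sampsize = max(ceiling(0.6 * n_obs), 1)
  cores          = 2    # passed to randomForestSRC as option rf.cores during training
)
# Setting both mtry and mtry.ratio (or both sampsize and sampsize.ratio) is an error,
# since each pair is mutually exclusive.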
References
Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange SJ, Lau BM (2014). “Random survival forests for competing risks.” Biostatistics, 15(4), 757–773. doi:10.1093/biostatistics/kxu010.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner
-> mlr3proba::LearnerCompRisks
-> LearnerCompRisksRandomForestSRC
Methods
Method importance()
The importance scores are extracted from the model slot importance and are cause-specific.
Returns
Named numeric().
Method selected_features()
Selected features are extracted from the model slot var.used.
Note: Due to a known issue in randomForestSRC, enabling var.used = "all.trees" causes prediction to fail. Therefore, this setting should be used exclusively for feature selection purposes and not when prediction is required.
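A minimal sketch of this workflow, training a separate forest solely for feature selection (the task and tree count are illustrative):
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
task = tsk("pbc")
# Dedicated learner for feature selection only; do not use it for prediction.
fs_learner = lrn("cmprsk.rfsrc", var.used = "all.trees", ntree = 100)
fs_learner$train(task)
fs_learner$selected_features() # features used for splitting across the forest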
Method oob_error()
Extracts the out-of-bag (OOB) cumulative incidence function (CIF) error
from the model's err.rate
slot.
If cause = "mean"
(default), the function returns a weighted average
of the cause-specific OOB errors, where the weights correspond to the
observed proportion of events for each cause in the training data.
Arguments
cause
Integer (event type) or "mean" (default). Use a specific event type to retrieve its OOB error, or "mean" to compute the weighted average across causes.
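For instance, continuing from a trained learner as in the Examples section below (the cause indices 1 and 2 assume two competing events, as in the pbc task):
learner$oob_error()          # weighted mean across causes (cause = "mean", the default)
learner$oob_error(cause = 1) # OOB CIF error for event type 1
learner$oob_error(cause = 2) # OOB CIF error for event type 2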
Examples
# Define the Learner
learner = lrn("cmprsk.rfsrc", importance = "TRUE")
print(learner)
#>
#> ── <LearnerCompRisksRandomForestSRC> (cmprsk.rfsrc): Competing Risk Survival For
#> • Model: -
#> • Parameters: importance=TRUE, ntime=0, cores=1
#> • Packages: mlr3, mlr3proba, mlr3extralearners, and randomForestSRC
#> • Predict Types: [cif]
#> • Feature Types: logical, integer, numeric, and factor
#> • Encapsulation: none (fallback: -)
#> • Properties: importance, missings, oob_error, selected_features, and weights
#> • Other settings: use_weights = 'use'
# Define a Task
task = tsk("pbc")
# Stratification based on event
task$set_col_roles(cols = "status", add_to = "stratum")
# Create train and test set
ids = partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> Sample size: 184
#> Number of events: 12, 74
#> Number of trees: 500
#> Forest terminal node size: 15
#> Average no. of terminal nodes: 8.648
#> No. of variables tried at each split: 5
#> Total no. of variables: 17
#> Resampling used to grow trees: swor
#> Resample size used to grow trees: 116
#> Analysis: RSF
#> Family: surv-CR
#> Splitting rule: logrankCR *random*
#> Number of random split points: 10
#> (OOB) Requested performance error: 0.16840459, 0.1471629
#>
print(learner$importance(cause = 1)) # VIMP for cause = 1
#> bili edema protime copper age
#> 0.4460523444 0.0751811141 0.0700544612 0.0495429781 0.0425033563
#> sex ascites albumin hepato stage
#> 0.0398590525 0.0392858497 0.0240643199 0.0183598461 0.0081833155
#> chol ast platelet alk.phos trt
#> 0.0073745503 0.0047940418 0.0029641559 0.0017100448 -0.0002916088
#> spiders trig
#> -0.0007002908 -0.0034118768
print(learner$importance(cause = 2)) # VIMP for cause = 2
#> bili edema age ascites protime
#> 0.2533068989 0.1536836539 0.0627128381 0.0622469828 0.0543399949
#> albumin copper trig alk.phos platelet
#> 0.0483440053 0.0468109751 0.0158447604 0.0104082300 0.0095549953
#> chol ast spiders stage sex
#> 0.0073672755 0.0044953987 0.0038724819 0.0032225486 0.0023355943
#> hepato trt
#> 0.0006450588 -0.0001067855
print(learner$oob_error()) # weighted-mean across causes
#> [1] 0.1501269
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> cmprsk.auc
#> 0.888109