Random Forest Competing Risks Learner
Source: R/learner_randomForestSRC_cmprsk_rfsrc.R
mlr_learners_cmprsk.rfsrc.Rd
Random survival forests for competing risks.
Calls randomForestSRC::rfsrc()
from randomForestSRC.
Meta Information
Task type: “cmprsk”
Predict Types: “cif”
Feature Types: “logical”, “integer”, “numeric”, “factor”
Required Packages: mlr3, mlr3cmprsk, mlr3extralearners, randomForestSRC
Parameters
Id | Type | Default | Levels | Range
ntree | integer | 500 | - | \([1, \infty)\)
mtry | integer | - | - | \([1, \infty)\)
mtry.ratio | numeric | - | - | \([0, 1]\)
nodesize | integer | 15 | - | \([1, \infty)\)
nodedepth | integer | - | - | \([1, \infty)\)
splitrule | character | logrankCR | logrankCR, logrank | -
nsplit | integer | 10 | - | \([0, \infty)\)
importance | character | FALSE | FALSE, TRUE, none, anti, permute, random | -
block.size | integer | 10 | - | \([1, \infty)\)
bootstrap | character | by.root | by.root, by.node, none, by.user | -
samptype | character | swor | swor, swr | -
samp | untyped | - | - | -
membership | logical | FALSE | TRUE, FALSE | -
sampsize | untyped | - | - | -
sampsize.ratio | numeric | - | - | \([0, 1]\)
na.action | character | na.omit | na.omit, na.impute | -
nimpute | integer | 1 | - | \([1, \infty)\)
ntime | integer | 150 | - | \([0, \infty)\)
cause | untyped | - | - | -
proximity | character | FALSE | FALSE, TRUE, inbag, oob, all | -
distance | character | FALSE | FALSE, TRUE, inbag, oob, all | -
forest.wt | character | FALSE | FALSE, TRUE, inbag, oob, all | -
xvar.wt | untyped | - | - | -
split.wt | untyped | - | - | -
forest | logical | TRUE | TRUE, FALSE | -
var.used | character | FALSE | FALSE, all.trees | -
split.depth | character | FALSE | FALSE, all.trees, by.tree | -
seed | integer | - | - | \((-\infty, -1]\)
do.trace | logical | FALSE | TRUE, FALSE | -
get.tree | untyped | - | - | -
outcome | character | train | train, test | -
ptn.count | integer | 0 | - | \([0, \infty)\)
cores | integer | 1 | - | \([1, \infty)\)
save.memory | logical | FALSE | TRUE, FALSE | -
perf.type | character | - | none | -
case.depth | logical | FALSE | TRUE, FALSE | -
marginal.xvar | untyped | NULL | - | -
Initial parameter values
ntime
: Number of time points to coerce the observed event times to, for use in the estimated cumulative incidence functions during prediction. We changed the default value of 150 to 0, meaning we now use all the unique event times from the train set across all competing causes.
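With ntime = 0 the learner evaluates the CIF on every distinct event time observed during training, regardless of cause. A minimal base-R sketch of that extraction, using toy data (the vectors below are illustrative, not from the pbc task):

```r
# Toy competing-risks data: 0 = censored, 1/2 = competing causes
time   = c(5, 3, 5, 8, 2, 8)
status = c(1, 2, 0, 1, 2, 2)

# With ntime = 0, all unique event times (any cause) are used
event_times = sort(unique(time[status > 0]))
event_times  # 2 3 5 8
```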
Custom mlr3 parameters
mtry
: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.

sampsize
: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.

cores
: This value is set as the option rf.cores during training and is set to 1 by default.
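The two ratio mappings above can be checked with plain arithmetic. A small sketch, taking n_features = 17 from the pbc example below; the ratio values themselves are illustrative:

```r
# mtry.ratio -> mtry: fraction of features tried at each split
n_features = 17
mtry_ratio = 0.3
mtry = max(ceiling(mtry_ratio * n_features), 1)   # 6

# sampsize.ratio -> sampsize: fraction of observations drawn per tree
n_obs = 184
sampsize_ratio = 0.63
sampsize = max(ceiling(sampsize_ratio * n_obs), 1)  # 116
```

The max(..., 1) guard ensures at least one feature (or observation) is used even for very small ratios.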
References
Ishwaran H, Gerds TA, Kogalur UB, Moore RD, Gange SJ, Lau BM (2014). “Random survival forests for competing risks.” Biostatistics, 15(4), 757–773. doi:10.1093/biostatistics/kxu010, https://doi.org/10.1093/biostatistics/kxu010.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner
-> mlr3cmprsk::LearnerCompRisks
-> LearnerCompRisksRandomForestSRC
Methods
Method importance()
The importance scores are extracted from the model slot importance and are cause-specific.
Returns
Named numeric().
Method selected_features()
Selected features are extracted from the model slot var.used.
Note: Due to a known issue in randomForestSRC
, enabling var.used = "all.trees"
causes prediction to fail. Therefore, this setting should be used exclusively
for feature selection purposes and not when prediction is required.
Method oob_error()
Extracts the out-of-bag (OOB) cumulative incidence function (CIF) error
from the model's err.rate
slot.
If cause = "mean"
(default), the function returns a weighted average
of the cause-specific OOB errors, where the weights correspond to the
observed proportion of events for each cause in the training data.
Arguments
cause
: Integer (event type) or "mean" (default). Use a specific event type to retrieve its OOB error, or "mean" to compute the weighted average across causes.
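The weighted mean can be recomputed by hand from the numbers printed in the example below (cause-specific OOB errors 0.21936148 and 0.18921458, with 12 and 74 training events for causes 1 and 2). A base-R sketch:

```r
# Cause-specific OOB CIF errors, as printed by the fitted model below
err      = c(0.21936148, 0.18921458)
# Event counts per cause in the training data (weights = proportions)
n_events = c(12, 74)
weights  = n_events / sum(n_events)

weighted_err = sum(weights * err)
round(weighted_err, 7)  # 0.1934211, matching learner$oob_error()
```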
Examples
# Define the Learner
learner = lrn("cmprsk.rfsrc", importance = "TRUE")
print(learner)
#>
#> ── <LearnerCompRisksRandomForestSRC> (cmprsk.rfsrc): Competing Risk Survival For
#> • Model: -
#> • Parameters: importance=TRUE, ntime=0, cores=1
#> • Packages: mlr3, mlr3cmprsk, mlr3extralearners, and randomForestSRC
#> • Predict Types: [cif]
#> • Feature Types: logical, integer, numeric, and factor
#> • Encapsulation: none (fallback: -)
#> • Properties: importance, missings, oob_error, selected_features, and weights
#> • Other settings: use_weights = 'use'
# Define a Task
task = tsk("pbc")
# Stratification based on event
task$set_col_roles(cols = "status", add_to = "stratum")
# Create train and test set
ids = partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> Sample size: 184
#> Number of events: 1=12, 2=74
#> Number of trees: 500
#> Forest terminal node size: 15
#> Average no. of terminal nodes: 9.47
#> No. of variables tried at each split: 5
#> Total no. of variables: 17
#> Resampling used to grow trees: swor
#> Resample size used to grow trees: 116
#> Analysis: RSF
#> Family: surv-CR
#> Splitting rule: logrankCR *random*
#> Number of random split points: 10
#> (OOB) Requested performance error: 0.21936148, 0.18921458
#>
print(learner$importance(cause = 1)) # VIMP for cause = 1
#> bili copper hepato ascites ast sex
#> 0.222561092 0.138832935 0.080284630 0.059164257 0.056707277 0.043153928
#> protime chol age stage edema alk.phos
#> 0.043111252 0.039077209 0.038332856 0.033717122 0.027112591 0.005332359
#> albumin platelet trt trig spiders
#> 0.004437528 0.001371270 -0.001248001 -0.003946608 -0.007677854
print(learner$importance(cause = 2)) # VIMP for cause = 2
#> copper bili edema ascites albumin
#> 0.1869919350 0.1319000627 0.1058266846 0.1011116776 0.0870990434
#> chol protime age ast sex
#> 0.0302579649 0.0298101376 0.0213136694 0.0099815401 0.0077290451
#> platelet stage trig alk.phos hepato
#> 0.0056903812 0.0052047887 0.0049728348 0.0041976721 0.0027304176
#> trt spiders
#> 0.0002651914 -0.0002692321
print(learner$oob_error()) # weighted-mean across causes
#> [1] 0.1934211
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> cmprsk.auc
#> 0.8881963