
Random survival forests for competing risks. Calls randomForestSRC::rfsrc() from randomForestSRC.

Dictionary

This Learner can be instantiated via lrn():

lrn("cmprsk.rfsrc")

Meta Information

Parameters

Id              Type       Default    Range     Levels
ntree           integer    500        [1, ∞)    -
mtry            integer    -          [1, ∞)    -
mtry.ratio      numeric    -          [0, 1]    -
nodesize        integer    15         [1, ∞)    -
nodedepth       integer    -          [1, ∞)    -
splitrule       character  logrankCR  -         logrankCR, logrank
nsplit          integer    10         [0, ∞)    -
importance      character  FALSE      -         FALSE, TRUE, none, anti, permute, random
block.size      integer    10         [1, ∞)    -
bootstrap       character  by.root    -         by.root, by.node, none, by.user
samptype        character  swor       -         swor, swr
samp            untyped    -          -         -
membership      logical    FALSE      -         TRUE, FALSE
sampsize        untyped    -          -         -
sampsize.ratio  numeric    -          [0, 1]    -
na.action       character  na.omit    -         na.omit, na.impute
nimpute         integer    1          [1, ∞)    -
ntime           integer    150        [0, ∞)    -
cause           untyped    -          -         -
proximity       character  FALSE      -         FALSE, TRUE, inbag, oob, all
distance        character  FALSE      -         FALSE, TRUE, inbag, oob, all
forest.wt       character  FALSE      -         FALSE, TRUE, inbag, oob, all
xvar.wt         untyped    -          -         -
split.wt        untyped    -          -         -
forest          logical    TRUE       -         TRUE, FALSE
var.used        character  FALSE      -         FALSE, all.trees
split.depth     character  FALSE      -         FALSE, all.trees, by.tree
seed            integer    -          (-∞, -1]  -
do.trace        logical    FALSE      -         TRUE, FALSE
statistics      logical    FALSE      -         TRUE, FALSE
get.tree        untyped    -          -         -
outcome         character  train      -         train, test
ptn.count       integer    0          [0, ∞)    -
cores           integer    1          [1, ∞)    -
save.memory     logical    FALSE      -         TRUE, FALSE
perf.type       character  -          -         none
case.depth      logical    FALSE      -         TRUE, FALSE
marginal.xvar   untyped    NULL       -         -

Initial parameter values

  • ntime: Number of time points to which the observed event times are coerced for use in the estimated cumulative incidence functions (CIFs) during prediction. The randomForestSRC default of 150 is changed to 0, meaning that all unique event times from the training set, across all competing causes, are used.
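The following base-R snippet illustrates which time points ntime = 0 refers to: the sorted, unique event times of any cause, with censored observations excluded. The time and status vectors are made-up toy data, and the status coding (0 = censored, 1 and 2 = competing causes) is an assumption of this sketch. randomForestSRC derives these time points internally, so this is an illustration of the idea, not its actual code.

```r
# Toy training data (hypothetical): observed times and status codes,
# where 0 = censored and 1, 2 = competing causes of the event
time   <- c(5, 8, 8, 12, 15, 20, 20, 31)
status <- c(1, 0, 2,  1,  0,  2,  1,  0)

# Keep every distinct time at which an event of ANY cause occurred;
# censored rows (status == 0) do not contribute time points
event_times <- sort(unique(time[status != 0]))
print(event_times)
#> [1]  5  8 12 20
```

Note that tied event times (here, 20 for both causes) contribute a single time point.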

Custom mlr3 parameters

  • mtry: This hyperparameter can alternatively be set via the added hyperparameter mtry.ratio as mtry = max(ceiling(mtry.ratio * n_features), 1). Note that mtry and mtry.ratio are mutually exclusive.

  • sampsize: This hyperparameter can alternatively be set via the added hyperparameter sampsize.ratio as sampsize = max(ceiling(sampsize.ratio * n_obs), 1). Note that sampsize and sampsize.ratio are mutually exclusive.

  • cores: This value is set as the R option rf.cores during training; it defaults to 1.
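Both ratio conversions above follow the same rule and can be sketched in base R. ratio_to_count() is a hypothetical helper written only for this illustration (it is not exported by mlr3extralearners); it reproduces the documented formula max(ceiling(ratio * n), 1). The feature count (17) and training size (184) match the pbc example on this page; the ratio values are arbitrary.

```r
# Hypothetical helper reproducing the documented conversion
# max(ceiling(ratio * n), 1); not part of the mlr3extralearners API
ratio_to_count <- function(ratio, n) {
  stopifnot(ratio >= 0, ratio <= 1)
  max(ceiling(ratio * n), 1)
}

n_features <- 17   # number of features in the pbc task used in the example
n_obs      <- 184  # training-set size printed in the example

mtry     <- ratio_to_count(0.3, n_features)  # ceiling(5.1) = 6
sampsize <- ratio_to_count(0.5, n_obs)       # 0.5 * 184 = 92 exactly
tiny     <- ratio_to_count(0, n_features)    # ceiling(0) = 0, floored at 1
print(c(mtry = mtry, sampsize = sampsize, tiny = tiny))
```

The outer max(..., 1) guarantees that a ratio of 0 (allowed by the [0, 1] range) still yields a valid count of at least one.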

References

Ishwaran, H., Gerds, T. A., Kogalur, U. B., Moore, R. D., Gange, S. J., Lau, B. M. (2014). “Random survival forests for competing risks.” Biostatistics, 15(4), 757–773. https://doi.org/10.1093/BIOSTATISTICS/KXU010.

See also

Author

bblodfon

Super classes

mlr3::Learner -> mlr3proba::LearnerCompRisks -> LearnerCompRisksRandomForestSRC

Methods

Inherited methods


Method new()

Creates a new instance of this R6 class.


Method importance()

The importance scores are extracted from the model slot importance and are cause-specific.

Usage

LearnerCompRisksRandomForestSRC$importance(cause = 1)

Arguments

cause

Integer value indicating the event of interest (cause).

Returns

Named numeric().


Method selected_features()

Selected features are extracted from the model slot var.used.

Note: Due to a known issue in randomForestSRC, setting var.used = "all.trees" causes prediction to fail. Use this setting only for feature selection, not when predictions are required.

Usage

LearnerCompRisksRandomForestSRC$selected_features()

Returns

character().


Method oob_error()

Extracts the out-of-bag (OOB) cumulative incidence function (CIF) error from the model's err.rate slot.

If cause = "mean" (default), the function returns a weighted average of the cause-specific OOB errors, where the weights correspond to the observed proportion of events for each cause in the training data.

Usage

LearnerCompRisksRandomForestSRC$oob_error(cause = "mean")

Arguments

cause

Integer (event type) or "mean" (default). Use a specific event type to retrieve its OOB error, or "mean" to compute the weighted average across causes.

Returns

numeric().
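The weighting used by oob_error(cause = "mean") can be checked by hand. The sketch below plugs in the cause-specific OOB errors and event counts printed by learner$model in the example on this page (0.16840459 and 0.1471629; 12 and 74 events); these numbers come from one random training run. It is a re-computation of the documented formula in base R, not the learner's internal code, which reads the errors from the model's err.rate slot.

```r
# Cause-specific OOB CIF errors and event counts, copied from the
# example output on this page (cause 1 first, then cause 2)
oob_err  <- c(0.16840459, 0.1471629)
n_events <- c(12, 74)

# Weights are the observed proportion of events for each cause
w <- n_events / sum(n_events)
weighted_oob <- sum(w * oob_err)
print(weighted_oob)
#> [1] 0.1501269
```

The result reproduces the value returned by learner$oob_error() in the example, and it is the same quantity as weighted.mean(oob_err, n_events).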


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerCompRisksRandomForestSRC$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Define the Learner
learner = lrn("cmprsk.rfsrc", importance = "TRUE")
print(learner)
#> 
#> ── <LearnerCompRisksRandomForestSRC> (cmprsk.rfsrc): Competing Risk Survival For
#> • Model: -
#> • Parameters: importance=TRUE, ntime=0, cores=1
#> • Packages: mlr3, mlr3proba, mlr3extralearners, and randomForestSRC
#> • Predict Types: [cif]
#> • Feature Types: logical, integer, numeric, and factor
#> • Encapsulation: none (fallback: -)
#> • Properties: importance, missings, oob_error, selected_features, and weights
#> • Other settings: use_weights = 'use'

# Define a Task
task = tsk("pbc")

# Stratification based on event
task$set_col_roles(cols = "status", add_to = "stratum")

# Create train and test set
ids = partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
#>                          Sample size: 184
#>                     Number of events: 12, 74
#>                      Number of trees: 500
#>            Forest terminal node size: 15
#>        Average no. of terminal nodes: 8.648
#> No. of variables tried at each split: 5
#>               Total no. of variables: 17
#>        Resampling used to grow trees: swor
#>     Resample size used to grow trees: 116
#>                             Analysis: RSF
#>                               Family: surv-CR
#>                       Splitting rule: logrankCR *random*
#>        Number of random split points: 10
#>    (OOB) Requested performance error: 0.16840459, 0.1471629
#> 
print(learner$importance(cause = 1)) # VIMP for cause = 1
#>          bili         edema       protime        copper           age 
#>  0.4460523444  0.0751811141  0.0700544612  0.0495429781  0.0425033563 
#>           sex       ascites       albumin        hepato         stage 
#>  0.0398590525  0.0392858497  0.0240643199  0.0183598461  0.0081833155 
#>          chol           ast      platelet      alk.phos           trt 
#>  0.0073745503  0.0047940418  0.0029641559  0.0017100448 -0.0002916088 
#>       spiders          trig 
#> -0.0007002908 -0.0034118768 
print(learner$importance(cause = 2)) # VIMP for cause = 2
#>          bili         edema           age       ascites       protime 
#>  0.2533068989  0.1536836539  0.0627128381  0.0622469828  0.0543399949 
#>       albumin        copper          trig      alk.phos      platelet 
#>  0.0483440053  0.0468109751  0.0158447604  0.0104082300  0.0095549953 
#>          chol           ast       spiders         stage           sex 
#>  0.0073672755  0.0044953987  0.0038724819  0.0032225486  0.0023355943 
#>        hepato           trt 
#>  0.0006450588 -0.0001067855 
print(learner$oob_error()) # weighted-mean across causes
#> [1] 0.1501269

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> cmprsk.auc 
#>   0.888109