Deep neural network survival models from the survdnn package, aimed at
tabular (low- to moderate-dimensional) covariate settings and built on torch-based
multilayer perceptrons. The learner wraps survdnn::survdnn().
Prediction types
This learner supports the following prediction types:
lp: A numeric vector of linear predictors, one per observation. For loss "cox"/"cox_l2" this is a log-risk score (higher implies worse prognosis). For "aft", predict.survdnn() returns the predicted log-time location \(\mu(x)\) (higher implies better prognosis); the learner therefore negates it internally so that higher values imply higher risk, consistent with mlr3 conventions. For "coxtime", this is \(g(t_0, x)\) evaluated at a reference time.
crank: Same as lp.
distr: A survival matrix (rows = observations, columns = time points) based on predict(type = "survival"). By default, predictions are evaluated on the unique event times of the training data.
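The sign convention for "aft" and the layout of the survival matrix can be illustrated with a small base R sketch. All numbers and object names here (mu, rate, times, surv) are hypothetical and chosen for illustration only; nothing is produced by survdnn, and the exponential survival curves merely stand in for whatever the fitted network would return.

```r
# Hypothetical log-time locations mu(x) for three observations
# (illustrative values, not survdnn output).
mu = c(a = 5.2, b = 6.8, c = 4.1)

# Negating mu turns "higher = longer survival" into "higher = higher risk".
lp = -mu
crank = lp  # for this learner, crank is the same as lp

# The observation with the shortest predicted time has the largest risk score.
names(which.max(lp))

# Hypothetical survival matrix: rows = observations, columns = time points.
# Exponential curves S(t) = exp(-rate * t) are an assumption for this sketch.
times = c(30, 90, 180)
rate = c(0.002, 0.005)
surv = exp(-outer(rate, times))
colnames(surv) = times
dim(surv)  # 2 observations x 3 time points
```

Each row of such a matrix is non-increasing in time, which is the shape a "distr" prediction is expected to have.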
Meta Information
Task type: “surv”
Predict Types: “crank”, “distr”, “lp”
Feature Types: “integer”, “numeric”, “factor”, “ordered”
Required Packages: mlr3, mlr3proba, mlr3extralearners, survdnn, torch
Parameters
| Id | Type | Default | Levels | Range |
| --- | --- | --- | --- | --- |
| hidden | untyped | c(32L, 16L) | | - |
| activation | character | relu | relu, leaky_relu, tanh, sigmoid, gelu, elu, softplus | - |
| lr | numeric | 1e-04 | | \([1e-06, 1]\) |
| epochs | integer | 300 | | \([1, \infty)\) |
| loss | character | cox | cox, cox_l2, aft, coxtime | - |
| optimizer | character | adam | adam, adamw, sgd, rmsprop, adagrad | - |
| optim_args | untyped | list() | | - |
| verbose | logical | FALSE | TRUE, FALSE | - |
| dropout | numeric | 0.3 | | \([0, 1]\) |
| batch_norm | logical | TRUE | TRUE, FALSE | - |
| callbacks | untyped | NULL | | - |
| .seed | integer | NULL | | \((-\infty, \infty)\) |
| .device | character | auto | auto, cpu, cuda | - |
| na_action | character | omit | omit, fail | - |
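As a minimal sketch of how these parameters might be set, the snippet below passes some of them at construction and changes others afterwards through the parameter set. The specific values (layer sizes, learning rate, and so on) are arbitrary illustrations, not recommended settings, and the code assumes that mlr3, mlr3proba, and mlr3extralearners are installed.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

# Set parameters at construction (values are illustrative only).
learner = lrn("surv.survdnn",
  hidden = c(64L, 32L),   # two hidden layers
  activation = "gelu",
  loss = "aft",
  lr = 1e-3,
  epochs = 100L,
  dropout = 0.1
)

# Update parameters later via the paradox parameter set.
learner$param_set$set_values(optimizer = "adamw", batch_norm = FALSE)

# Inspect the values that are currently set.
learner$param_set$values
```

Values outside the ranges or levels listed in the table are rejected by the parameter set when they are assigned.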
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvDNN
Methods
Method marshal()
Marshal the learner's model.
Arguments
... (any)
Additional arguments passed to mlr3::marshal_model().
Method unmarshal()
Unmarshal the learner's model.
Arguments
... (any)
Additional arguments passed to mlr3::unmarshal_model().
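Marshaling is typically needed before a trained model is written to disk or sent to another R process, since torch-based models generally hold references that do not survive plain serialization. The following is a hedged sketch of such a round trip, assuming a working Torch backend and the packages listed under Meta Information; the file path and the reduced epochs value are illustrative choices.

```r
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)

task = tsk("lung")
learner = lrn("surv.survdnn", epochs = 10L)  # few epochs, for illustration
learner$train(task)

# Convert the model into a serializable form, then save the learner.
learner$marshal()
mlr3::is_marshaled_model(learner$model)
path = tempfile(fileext = ".rds")
saveRDS(learner, path)

# Restore the learner and make the model usable again.
restored = readRDS(path)
restored$unmarshal()
predictions = restored$predict(task)
```

Calling unmarshal() on the restored learner is the step that makes predict() usable again after loading.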
Examples
# Define the Learner
learner = lrn("surv.survdnn", epochs = 42L, loss = "cox")
print(learner)
#>
#> ── <LearnerSurvDNN> (surv.survdnn): SurvDNN (torch-based deep survival models) ─
#> • Model: -
#> • Parameters: epochs=42, loss=cox
#> • Packages: mlr3, mlr3proba, mlr3extralearners, survdnn, and torch
#> • Predict Types: [crank], distr, and lp
#> • Feature Types: integer, numeric, factor, and ordered
#> • Encapsulation: none (fallback: -)
#> • Properties: marshal
#> • Other settings: use_weights = 'error'
# Define the task, split to train/test set
task = tsk("lung")
set.seed(42)
part = partition(task)
# Train the learner on the training ids (requires the Torch backend, installed via torch::install_torch())
learner$train(task, row_ids = part$train)
#> Error: The Torch backend is not installed.
#> Please run: torch::install_torch().
print(learner$model)
#> NULL
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = part$test)
#> Error:
#> ✖ Cannot predict, Learner 'surv.survdnn' has not been trained yet
#> → Class: Mlr3ErrorInput
print(predictions)
#> Error: object 'predictions' not found
# Score the predictions
predictions$score()
#> Error: object 'predictions' not found