Deep neural network survival models from package survdnn, built on torch-based
multilayer perceptrons and aimed at tabular covariate settings of low to moderate
dimension. The learner wraps survdnn::survdnn().
Prediction types
This learner supports the following prediction types:
lp: A numeric vector of linear predictors, one per observation. For loss "cox" or "cox_l2", this is a log-risk score (higher implies worse prognosis). For "aft", predict.survdnn() returns the predicted log-time location \(\mu(x)\) (higher implies better prognosis), so the learner internally negates it such that higher values imply higher risk, consistent with mlr3 conventions. For "coxtime", this is \(g(t_0, x)\) evaluated at a reference time \(t_0\).
crank: Same as lp.
distr: A survival matrix (rows = observations, columns = time points) based on predict(type = "survival"). By default, predictions are evaluated at the unique event times of the training data.
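The sign convention for "aft" and the default evaluation grid for distr can be illustrated in base R. This is a minimal sketch on made-up numbers: the values of mu, the training times and status, and the toy survival formula are all hypothetical assumptions used only to show the orientation of lp and the rows-by-times layout of the survival matrix; it is not how survdnn estimates survival.

```r
# Hypothetical AFT output: predicted log-time locations mu(x) for four subjects
mu = c(2.1, 0.5, 3.7, 1.2)

# Negate so that a larger lp means higher risk (shorter predicted survival)
lp = -mu
crank = lp  # crank is the same as lp

# Hypothetical training data: follow-up times and event indicators (1 = event)
train_time   = c(5, 8, 8, 12, 15, 20, 20)
train_status = c(1, 0, 1, 1, 0, 1, 1)

# Default evaluation grid: sorted unique times at which events occurred
eval_times = sort(unique(train_time[train_status == 1]))

# Toy survival matrix: rows = observations, columns = eval_times.
# Assumption for illustration only: S(t | x) = exp(-0.05 * t * exp(lp)),
# which decreases over time and is lower for higher-risk subjects.
surv = outer(exp(lp), eval_times, function(r, t) exp(-0.05 * r * t))
colnames(surv) = eval_times

print(eval_times)  # -> 5 8 12 20
print(dim(surv))   # -> 4 4
```

The subject with the smallest predicted log-time (subject 2) receives the largest lp and the lowest survival curve, which is the ordering mlr3 expects from crank and lp.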
Meta Information
Task type: “surv”
Predict Types: “crank”, “distr”, “lp”
Feature Types: “integer”, “numeric”, “factor”, “ordered”
Required Packages: mlr3, mlr3proba, mlr3extralearners, survdnn, torch
Parameters
| Id | Type | Default | Levels | Range |
|---|---|---|---|---|
| hidden | untyped | c(32L, 16L) | - | - |
| activation | character | relu | relu, leaky_relu, tanh, sigmoid, gelu, elu, softplus | - |
| lr | numeric | 1e-04 | - | \([1e-06, 1]\) |
| epochs | integer | 300 | - | \([1, \infty)\) |
| loss | character | cox | cox, cox_l2, aft, coxtime | - |
| optimizer | character | adam | adam, adamw, sgd, rmsprop, adagrad | - |
| optim_args | untyped | list() | - | - |
| verbose | logical | FALSE | TRUE, FALSE | - |
| dropout | numeric | 0.3 | - | \([0, 1]\) |
| batch_norm | logical | TRUE | TRUE, FALSE | - |
| callbacks | untyped | NULL | - | - |
| .seed | integer | NULL | - | \((-\infty, \infty)\) |
| .device | character | auto | auto, cpu, cuda | - |
| na_action | character | omit | omit, fail | - |
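A non-default configuration can be passed directly to lrn(). The sketch below only uses parameter ids from the table above; the specific values are arbitrary, and the assumption that optim_args is forwarded to the underlying torch optimizer constructor (so that weight_decay reaches torch's AdamW) is not confirmed by this page. The block is skipped when the packages or the learner key are unavailable.

```r
# Skip gracefully if the mlr3 packages are not installed
pkgs = c("mlr3", "mlr3proba", "mlr3extralearners")
ok = all(vapply(pkgs, requireNamespace, logical(1), quietly = TRUE))
if (ok) {
  library(mlr3)
  library(mlr3proba)
  library(mlr3extralearners)
  ok = "surv.survdnn" %in% mlr_learners$keys()
}

if (ok) {
  learner_cfg = lrn("surv.survdnn",
    hidden     = c(64L, 32L),              # two hidden layers with 64 and 32 units
    activation = "gelu",
    dropout    = 0.2,                      # must lie in [0, 1]
    batch_norm = TRUE,
    loss       = "aft",                    # lp is returned negated, so higher = higher risk
    optimizer  = "adamw",
    lr         = 1e-3,                     # must lie in [1e-06, 1]
    optim_args = list(weight_decay = 1e-4), # assumption: forwarded to the torch optimizer
    epochs     = 200L,
    .seed      = 123L,                     # for reproducible weight initialization
    .device    = "cpu"
  )

  # Hyperparameters can also be changed after construction
  learner_cfg$param_set$values$epochs = 100L
  print(learner_cfg$param_set$values$epochs)
}
```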
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/chapters/chapter2/data_and_basic_modeling.html#sec-learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvDNN
Methods
Method marshal()
Marshal the learner's model.
Arguments
... (any)
Additional arguments passed to mlr3::marshal_model().
Method unmarshal()
Unmarshal the learner's model.
Arguments
... (any)
Additional arguments passed to mlr3::unmarshal_model().
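Marshaling matters because a fitted torch model holds external pointers that do not survive a plain saveRDS()/readRDS() round trip. The sketch below shows one plausible workflow, assuming that marshal() converts the model into a serializable form and unmarshal() restores it before prediction; the file path, epoch count, and seed are illustrative choices. The block is skipped unless all required packages and a working torch installation are present.

```r
pkgs = c("mlr3", "mlr3proba", "mlr3extralearners", "survdnn", "torch")
ok = all(vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)) &&
  torch::torch_is_installed()
if (ok) {
  library(mlr3)
  library(mlr3proba)
  library(mlr3extralearners)
  ok = "surv.survdnn" %in% mlr_learners$keys()
}

if (ok) {
  task = tsk("lung")
  learner = lrn("surv.survdnn", epochs = 20L, .seed = 1L)
  learner$train(task)

  # Convert the fitted torch model into a serializable form before saving
  learner$marshal()
  path = tempfile(fileext = ".rds")
  saveRDS(learner, path)

  # Later (e.g. in another session): restore, unmarshal, then predict
  restored = readRDS(path)
  restored$unmarshal()
  pred = restored$predict(task)
  print(pred)
}
```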
Examples
if (requireNamespace("survdnn", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
# Attach mlr3, the survival extension, and the learner collection
library(mlr3)
library(mlr3proba)
library(mlr3extralearners)
# Define the Learner
learner = lrn("surv.survdnn", epochs = 42L, loss = "cox")
print(learner)
# Define the task, split to train/test set
task = tsk("lung")
set.seed(42)
part = partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = part$train)
print(learner$model)
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = part$test)
print(predictions)
# Score the predictions (default survival measure)
predictions$score()
}