Deep neural network survival models from the survdnn package, built on torch-based multilayer perceptrons and aimed at tabular data with low- to moderate-dimensional covariates. The learner wraps survdnn::survdnn().

Prediction types

This learner supports the following prediction types:

lp

A numeric vector of linear predictors, one per observation. For losses "cox" and "cox_l2", this is a log-risk score (higher values imply worse prognosis). For "aft", predict.survdnn() returns the predicted log-time location \(\mu(x)\) (higher values imply better prognosis), so the learner negates it internally, making higher values imply higher risk, consistent with mlr3 conventions. For "coxtime", this is \(g(t_0, x)\), i.e. \(g(t, x)\) evaluated at a reference time \(t_0\).
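A minimal base-R sketch of the sign convention for the "aft" loss. The log-time locations below are made-up values for illustration, not output from survdnn: negating \(\mu(x)\) turns "longer predicted survival" into "lower risk", so sorting the negated values in decreasing order ranks subjects from highest to lowest risk, which is the same order as shortest to longest predicted time.

```r
# Hypothetical predicted log-time locations mu(x) for three subjects,
# as an "aft" model might return them (higher = longer expected survival).
mu = c(a = 5.2, b = 6.8, c = 4.1)

# Flip the sign so that higher values mean higher risk (mlr3 lp/crank convention).
lp = -mu

# Subjects ordered from highest to lowest risk.
risk_order = names(sort(lp, decreasing = TRUE))
print(risk_order)
# [1] "c" "a" "b"
```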

crank

Same as lp.

distr

A survival matrix (rows = observations, columns = time points) based on predict(type = "survival"). By default, predictions are evaluated on the unique event times of the training data.
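The layout of the distr output can be sketched in base R. Everything here is hypothetical: the training times are made up, and the curves use a simple proportional-hazards form \(S(t \mid x) = S_0(t)^{\exp(\eta)}\) with an assumed exponential baseline \(S_0(t) = e^{-0.02 t}\). survdnn computes its curves differently; the sketch only shows the evaluation grid (unique training event times) and the matrix shape (one row per observation, one column per time point, non-increasing along each row).

```r
# Hypothetical training data: follow-up times and event indicators (1 = event).
train_time   = c(5, 8, 8, 12, 15, 20, 20, 30)
train_status = c(1, 1, 0, 1,  0,  1,  1,  0)

# Evaluation grid: sorted unique times at which an event was observed.
eval_times = sort(unique(train_time[train_status == 1]))

# Hypothetical risk scores (lp) for three test observations.
lp = c(-0.5, 0, 0.8)

# Illustrative curves S(t | x) = S0(t)^exp(lp) with an assumed baseline.
S0 = exp(-0.02 * eval_times)
surv = t(sapply(lp, function(eta) S0^exp(eta)))
colnames(surv) = eval_times

dim(surv)
# [1] 3 4
```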

Dictionary

This Learner can be instantiated via lrn():

lrn("surv.survdnn")

Meta Information

  • Task type: “surv”

  • Predict Types: “crank”, “distr”, “lp”

  • Feature Types: “integer”, “numeric”, “factor”, “ordered”

  • Required Packages: mlr3, mlr3proba, mlr3extralearners, survdnn, torch

Parameters

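  • hidden (untyped): default c(32L, 16L)

  • activation (character): default relu; levels relu, leaky_relu, tanh, sigmoid, gelu, elu, softplus

  • lr (numeric): default 1e-04; range \([1e-06, 1]\)

  • epochs (integer): default 300; range \([1, \infty)\)

  • loss (character): default cox; levels cox, cox_l2, aft, coxtime

  • optimizer (character): default adam; levels adam, adamw, sgd, rmsprop, adagrad

  • optim_args (untyped): default list()

  • verbose (logical): default FALSE; levels TRUE, FALSE

  • dropout (numeric): default 0.3; range \([0, 1]\)

  • batch_norm (logical): default TRUE; levels TRUE, FALSE

  • callbacks (untyped): default NULL

  • .seed (integer): default NULL; range \((-\infty, \infty)\)

  • .device (character): default auto; levels auto, cpu, cuda

  • na_action (character): default omit; levels omit, fail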
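The hidden parameter is read here as the widths of the hidden layers, so the default c(32L, 16L) would describe two hidden layers with 32 and 16 units. Under that reading, and assuming a fully connected network with a single output unit (an assumption, not taken from the survdnn documentation), the hypothetical helper below counts the weights and biases in the linear layers; batch-normalization parameters are ignored.

```r
# Hypothetical helper: count weights and biases in the fully connected layers
# of an MLP with p inputs, the given hidden widths, and n_out output units.
# Batch-norm parameters are deliberately not counted.
count_linear_params = function(p, hidden, n_out = 1L) {
  widths  = c(p, hidden, n_out)
  fan_in  = head(widths, -1)  # input size of each linear layer
  fan_out = tail(widths, -1)  # output size of each linear layer
  sum(fan_in * fan_out + fan_out)  # weights + biases per layer
}

# Default hidden = c(32L, 16L) with, say, 8 numeric features.
count_linear_params(p = 8L, hidden = c(32L, 16L))
# [1] 833
```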

Author

ielbadisy

Super classes

mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvDNN

Active bindings

marshaled

(logical(1)) Whether the learner has been marshaled.

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerSurvDNN$new()


Method marshal()

Marshal the learner's model.

Usage

LearnerSurvDNN$marshal(...)

Arguments

...

(any)
Additional arguments passed to mlr3::marshal_model().


Method unmarshal()

Unmarshal the learner's model.

Usage

LearnerSurvDNN$unmarshal(...)

Arguments

...

(any)
Additional arguments passed to mlr3::unmarshal_model().


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerSurvDNN$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples


# Define the Learner
learner = lrn("surv.survdnn", epochs = 42L, loss = "cox")
print(learner)
#> 
#> ── <LearnerSurvDNN> (surv.survdnn): SurvDNN (torch-based deep survival models) 
#> • Model: -
#> • Parameters: epochs=42, loss=cox
#> • Packages: mlr3, mlr3proba, mlr3extralearners, survdnn, and torch
#> • Predict Types: [crank], distr, and lp
#> • Feature Types: integer, numeric, factor, and ordered
#> • Encapsulation: none (fallback: -)
#> • Properties: marshal
#> • Other settings: use_weights = 'error'

# Define the task and split it into train/test sets
task = tsk("lung")
set.seed(42)
part = partition(task)

# Train the learner on the training ids
# (requires the torch backend; install it once with torch::install_torch())
learner$train(task, row_ids = part$train)
print(learner$model)

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = part$test)
print(predictions)

# Score the predictions
predictions$score()