Fits a neural network based on pseudo-conditional survival probabilities. Calls survivalmodels::dnnsurv() from package 'survivalmodels'.

Details

Custom nets can be used in this learner, either built with the survivalmodels::build_keras_net utility function or directly with keras. The network should have a single output channel, and the number of input channels must equal the number of features plus the number of cuts.
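
For illustration, the sketch below builds a small sequential network with keras and passes it via the custom_model parameter. The layer sizes, activations and the cuts value are placeholders, and the input width follows the rule stated above (features plus cuts); whether the learner compiles the supplied model itself or expects an already compiled model is an assumption to verify in ?survivalmodels::dnnsurv.

library(keras)

n_features = 3L  # number of features in your survival task (placeholder)
n_cuts = 5L      # must match the learner's 'cuts' parameter

custom_net = keras_model_sequential() %>%
  layer_dense(units = 32L, activation = "relu",
              input_shape = n_features + n_cuts) %>%
  layer_dense(units = 1L, activation = "sigmoid")  # single output channel (activation illustrative)

learner = mlr3::lrn("surv.dnnsurv", cuts = n_cuts, custom_model = custom_net)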

Dictionary

This Learner can be instantiated via the dictionary mlr_learners or with the associated sugar function lrn():

mlr_learners$get("surv.dnnsurv")
lrn("surv.dnnsurv")

Meta Information

  • Task type: "surv"
  • Predict Types: "crank", "distr"
  • Feature Types: "integer", "numeric"
  • Required Packages: mlr3, mlr3proba, mlr3extralearners, survivalmodels, keras, pseudo, tensorflow, distr6

Parameters

Id | Type | Default | Levels | Range
cuts | integer | 5 | - | \([1, \infty)\)
cutpoints | untyped | - | - | -
custom_model | untyped | - | - | -
optimizer | character | adam | adadelta, adagrad, adamax, adam, nadam, rmsprop, sgd | -
lr | numeric | 0.02 | - | \([0, \infty)\)
beta_1 | numeric | 0.9 | - | \([0, 1]\)
beta_2 | numeric | 0.999 | - | \([0, 1]\)
epsilon | numeric | - | - | \([0, \infty)\)
decay | numeric | 0 | - | \([0, \infty)\)
clipnorm | numeric | - | - | \((-\infty, \infty)\)
clipvalue | numeric | - | - | \((-\infty, \infty)\)
momentum | numeric | 0 | - | \([0, \infty)\)
nesterov | logical | FALSE | TRUE, FALSE | -
loss_weights | untyped | - | - | -
weighted_metrics | untyped | - | - | -
early_stopping | logical | FALSE | TRUE, FALSE | -
min_delta | numeric | 0 | - | \([0, \infty)\)
patience | integer | 0 | - | \([0, \infty)\)
verbose | integer | 0 | - | \([0, 2]\)
baseline | numeric | - | - | \((-\infty, \infty)\)
restore_best_weights | logical | FALSE | TRUE, FALSE | -
batch_size | integer | 32 | - | \([1, \infty)\)
epochs | integer | 10 | - | \([1, \infty)\)
validation_split | numeric | 0 | - | \([0, 1]\)
shuffle | logical | TRUE | TRUE, FALSE | -
sample_weight | untyped | - | - | -
initial_epoch | integer | 0 | - | \([0, \infty)\)
steps_per_epoch | integer | - | - | \([1, \infty)\)
validation_steps | integer | - | - | \([1, \infty)\)
steps | integer | - | - | \([0, \infty)\)
callbacks | untyped | - | - | -
rho | numeric | 0.95 | - | \((-\infty, \infty)\)
global_clipnorm | numeric | - | - | \((-\infty, \infty)\)
use_ema | logical | - | TRUE, FALSE | -
ema_momentum | numeric | 0.99 | - | \((-\infty, \infty)\)
ema_overwrite_frequency | numeric | - | - | \((-\infty, \infty)\)
jit_compile | logical | TRUE | TRUE, FALSE | -
initial_accumultator_value | numeric | 0.1 | - | \((-\infty, \infty)\)
amsgrad | logical | FALSE | TRUE, FALSE | -
lr_power | numeric | -0.5 | - | \((-\infty, \infty)\)
l1_regularization_strength | numeric | 0 | - | \([0, \infty)\)
l2_regularization_strength | numeric | 0 | - | \([0, \infty)\)
l2_shrinkage_regularization_strength | numeric | 0 | - | \([0, \infty)\)
beta | numeric | 0 | - | \((-\infty, \infty)\)
centered | logical | FALSE | TRUE, FALSE | -

Installation

Package 'survivalmodels' is not on CRAN and has to be installed from GitHub via remotes::install_github("RaphaelS1/survivalmodels").
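
The sketch below shows one possible setup. The extra packages mirror those listed under Packages in the example output; the keras::install_keras() step, which provisions the Python/TensorFlow backend, is a suggestion rather than a documented requirement of this learner.

install.packages(c("remotes", "keras", "pseudo"))
remotes::install_github("RaphaelS1/survivalmodels")
keras::install_keras()  # provisions the Python backend (TensorFlow)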

Initial parameter values

  • verbose is initialized to 0.

References

Zhao, Lili, Feng, Dai (2019). “DNNSurv: Deep neural networks for survival analysis using pseudo values.” arXiv preprint arXiv:1908.02337.

See also

Author

RaphaelS1

Super classes

mlr3::Learner -> mlr3proba::LearnerSurv -> LearnerSurvDNNSurv

Methods

Inherited methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerSurvDNNSurv$new()

Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerSurvDNNSurv$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

learner = mlr3::lrn("surv.dnnsurv")
print(learner)
#> <LearnerSurvDNNSurv:surv.dnnsurv>: Neural Network
#> * Model: -
#> * Parameters: verbose=0
#> * Packages: mlr3, mlr3proba, mlr3extralearners, survivalmodels, keras,
#>   pseudo, tensorflow, distr6
#> * Predict Types:  [crank], distr
#> * Feature Types: integer, numeric
#> * Properties: -

# available parameters:
learner$param_set$ids()
#>  [1] "cuts"                                
#>  [2] "cutpoints"                           
#>  [3] "custom_model"                        
#>  [4] "optimizer"                           
#>  [5] "lr"                                  
#>  [6] "beta_1"                              
#>  [7] "beta_2"                              
#>  [8] "epsilon"                             
#>  [9] "decay"                               
#> [10] "clipnorm"                            
#> [11] "clipvalue"                           
#> [12] "momentum"                            
#> [13] "nesterov"                            
#> [14] "loss_weights"                        
#> [15] "weighted_metrics"                    
#> [16] "early_stopping"                      
#> [17] "min_delta"                           
#> [18] "patience"                            
#> [19] "verbose"                             
#> [20] "baseline"                            
#> [21] "restore_best_weights"                
#> [22] "batch_size"                          
#> [23] "epochs"                              
#> [24] "validation_split"                    
#> [25] "shuffle"                             
#> [26] "sample_weight"                       
#> [27] "initial_epoch"                       
#> [28] "steps_per_epoch"                     
#> [29] "validation_steps"                    
#> [30] "steps"                               
#> [31] "callbacks"                           
#> [32] "rho"                                 
#> [33] "global_clipnorm"                     
#> [34] "use_ema"                             
#> [35] "ema_momentum"                        
#> [36] "ema_overwrite_frequency"             
#> [37] "jit_compile"                         
#> [38] "initial_accumultator_value"          
#> [39] "amsgrad"                             
#> [40] "lr_power"                            
#> [41] "l1_regularization_strength"          
#> [42] "l2_regularization_strength"          
#> [43] "l2_shrinkage_regularization_strength"
#> [44] "beta"                                
#> [45] "centered"