Many learners are already included in the mlr3 ecosystem, but many more have not been implemented yet. In this vignette we look at how to add a new Learner to mlr3. We assume a lot of prior knowledge here, so it is best to read the mlr3book first! If you want to contribute a new Learner to mlr3extralearners, it is recommended to first open an issue to discuss the learner. As another side note, the mlr3extralearners::create_learner() function can be used to generate files containing templates for the learner and the test files required to create and test a Learner. We will not use it in this vignette for the sake of clarity, but recommend using it in practice to avoid writing a lot of boilerplate code.

We will first demonstrate what the final class looks like and then we will explain it line by line (and method by method). As a working example, we will implement a slightly stripped-down version of regr.rpart.

library(mlr3)
library(paradox)
library(R6)
library(mlr3misc)

LearnerRegrRpartSimple = R6Class("LearnerRegrRpartSimple",
  inherit = LearnerRegr,
  public = list(
    initialize = function() {
      param_set = ps(
        cp             = p_dbl(0, 1, default = 0.01, tags = "train"),
        maxcompete     = p_int(0L, default = 4L, tags = "train"),
        maxdepth       = p_int(1L, 30L, default = 30L, tags = "train"),
        maxsurrogate   = p_int(0L, default = 5L, tags = "train"),
        minbucket      = p_int(1L, tags = "train"),
        minsplit       = p_int(1L, default = 20L, tags = "train"),
        surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
        usesurrogate   = p_int(0L, 2L, default = 2L, tags = "train"),
        xval           = p_int(0L, default = 10L, tags = "train")
      )
      param_set$values$xval = 0L
      super$initialize(
        id = "regr.rpart_simple",
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        predict_types = "response",
        packages = "rpart",
        param_set = param_set,
        properties = c("weights", "missings", "importance"),
        label = "Regression Tree",
        man = "mlr3::mlr_learners_regr.rpart"
      )
    },
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      # importance is only present if there is at least one split
      if (is.null(self$model$variable.importance)) {
        importance = set_names(numeric())
      } else {
        importance = sort(self$model$variable.importance, decreasing = TRUE)
      }
      return(importance)
    }
  ),
  private = list(
    .train = function(task) {
      pv = self$param_set$get_values(tags = "train")
      if ("weights" %in% task$properties) {
        pv$weights = task$weights$weight
      }
      invoke(
        rpart::rpart,
        formula = task$formula(),
        data = task$data(),
        .args = pv
      )
    },
    .predict = function(task) {
      pv = self$param_set$get_values(tags = "predict")

      # ensure same column order in train and predict
      newdata = mlr3extralearners:::ordered_features(task, self)
      response = invoke(predict, self$model, newdata = newdata, .args = pv)
      list(response = unname(response))
    }
  )
)

We first specify the class name according to the mlr3 naming convention as "LearnerRegrRpartSimple". In the second line, we select the appropriate parent class from which we want to inherit. Because we are implementing a regression learner, we have to inherit from the LearnerRegr class. For classification we would inherit from LearnerClassif, for survival from mlr3proba::LearnerSurv, and for clustering from mlr3cluster::LearnerClust.

In the constructor, the constructor of the super class (in this case LearnerRegr) is called with meta-information about the learner which we are defining. We strongly recommend reading the relevant class documentation to understand its arguments (see e.g. the learner documentation). These include:

  • id: The ID of the new learner.
  • packages: The upstream package name(s) of the implemented learner.
  • param_set: A set of hyperparameters and their descriptions provided as a paradox::ParamSet.
  • predict_types: Set of predict types the learner supports.
  • feature_types: Set of feature types the learner is able to handle.
  • properties: Set of properties of the learner.
  • man: The roxygen identifier of the learner. This is used within the $help() method of the super class to open the help page of the learner.
  • label: The label of the learner. This should briefly describe the learner (similar to the description’s title) and is for example used for printing.

To fill out some of this meta-information, one has to go through the manual pages of the upstream packages. In order to know e.g. which predict types or properties are available, reading the manual pages of the corresponding superclass(es) is indispensable. Furthermore, it is a good idea to look at already existing learner implementations for inspiration.

Three properties in mlr3 exist that deserve special consideration, which are:

  • "validation" - the learner can make use of an internal validation to e.g. track the performance
  • "internal_tuning" - the learner can internally optimize hyperparameters, e.g. via early stopping
  • "marshal" - the learner’s model cannot be serialized without loss of information unless it is marshaled first (e.g. learners from mlr3torch have this property)

While these properties are rarely needed, it is still important to know how to implement them if they do apply. For the first two properties, see the documentation of Learner or for a concrete implementation the XGBoost learners. The documentation on marshaling can be found by running ?marshaling, and for an example on how to implement this see mlr3torch::LearnerTorch. In any case, if you suspect your learner to have any of these properties it is best to first discuss the implementation with the mlr-org team.

As a final note, we want to mention that a Learner should be constructable even when the packages that it uses are not loaded, so the $initialize() method should not use any objects from the package.

Defining the Parameter Set of a Learner

The parameter set of a learner is the set of hyperparameters used in model training and predicting. It is given as a paradox::ParamSet. The set consists of a list of hyperparameters, where each has a specific class for the hyperparameter type. For an explanation of parameters and the ParamSet class we refer to the mlr3book, where we show how to create parameter sets to define search spaces for hyperparameter tuning. However, there are some aspects of parameters that are mostly relevant when creating new Learners, so we cover them here. These are:

  • tags that organize the parameters, e.g. whether they are used during training or prediction.
  • default values which give information about the value that is used by the upstream function when no value is set. Note that this is not necessarily the default of the function signature but what is used internally: e.g. the default of parameter a in the function signature is NULL, but internally the value for a is set to 1 if it is left as NULL by the caller.
  • initial values are the initial $values of the parameter set that are set during construction.
  • custom checks which allow specifying custom constraints for untyped parameters (ParamUty).

Tags

Each parameter has one or more tags that determine in which method it is used: parameters that are used during training must be tagged with "train" and those that are used during prediction with "predict". In our case there are only training parameters. The values of parameters with certain tags can be retrieved by calling the $get_values() method of a ParamSet with the tags argument.

Furthermore, there are other tags that serve specific purposes:

  • The tag "threads" should be used (if applicable) to tag the parameter that determines the number of threads used for the learner’s internal parallelization. Once a learner is constructed, this parameter can be set using set_threads().
  • The tag "required" should be used to tag parameters that must be provided for the algorithm to be executable.
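As a minimal sketch of how tags filter parameter values, consider the following parameter set. The parameters nthreads and type are hypothetical and chosen only for illustration; they do not belong to rpart.

```r
library(paradox)

# hypothetical parameter set, not tied to a real upstream function
param_set = ps(
  cp       = p_dbl(0, 1, default = 0.01, tags = "train"),
  nthreads = p_int(1L, default = 1L, tags = c("train", "threads")),
  type     = p_fct(c("vector", "matrix"), default = "vector", tags = "predict")
)
param_set$values = list(cp = 0.05, type = "matrix")

# only the set values of parameters tagged "train" are returned
train_values = param_set$get_values(tags = "train")
# only the set values of parameters tagged "predict" are returned
predict_values = param_set$get_values(tags = "predict")

names(train_values)
names(predict_values)
```

Parameters without a set value (here nthreads) do not appear in the result, so the upstream function falls back to its own default for them.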

Default Values

The default value of each parameter indicates which value the upstream function (in this case rpart::rpart()) uses if no value is specified. These have to be retrieved from the package’s documentation. If parameters have no default, they have to be tagged with the "required" tag. For such parameters, we recommend setting the parameter to a value in the $initialize() method, so that the learner can be executed after creation. Note that setting a parameter value in the $initialize() method does not change the parameter’s default value (as the default denotes what happens when no value is set).

Parameters expect the default values to be of their respective type, e.g. the default of ParamInt must be an integer. Some packages have R expressions as default values which cannot be properly expressed in paradox. In such cases, a practical compromise is to not specify any default value when calling p_int().

Initial Values

In some cases you want to set parameter values during initialization that differ from the upstream defaults. You can do this by changing param_set$values during the learner’s construction. We have done this for regr.rpart_simple, where xval is initialized to 0. Note however that we still annotate the default of the xval parameter as 10, because this is the value rpart::rpart() uses when xval is not passed. Parameter values should only be initialized for good reasons, and this should be documented properly. When contributing a learner to mlr3extralearners this information should go into the Initial parameter values section.
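The difference between the annotated default and the initial value can be inspected directly on a ParamSet. The following is a small stand-alone sketch using only the xval parameter:

```r
library(paradox)

param_set = ps(
  xval = p_int(0L, default = 10L, tags = "train")
)
# initial value set during construction; the annotated default stays untouched
param_set$values$xval = 0L

default_xval = param_set$default$xval # 10: what rpart uses if xval is not passed
initial_xval = param_set$values$xval  # 0: what the learner passes to rpart
```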

Custom Checks

In paradox only integers, numerics, factors, and logicals have dedicated parameter types. For all other parameters, the ParamUty class has to be used, which is constructable via the p_uty() function. It has the argument custom_check, which allows specifying custom constraints. This function returns TRUE if the input is valid and a character(1) with the error message otherwise.
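A custom check could look like the following sketch. The parameter name na.action and the rule "must be a function or NULL" are assumptions made for the example, not part of the learner defined above.

```r
library(paradox)

# hypothetical check: the value must be a function or NULL
check_fun = function(x) {
  if (is.null(x) || is.function(x)) {
    TRUE
  } else {
    "Must be a function or NULL"
  }
}

param_set = ps(
  na.action = p_uty(custom_check = check_fun, tags = "train")
)

# $test() returns TRUE for feasible values and FALSE otherwise
ok  = param_set$test(list(na.action = na.omit))
bad = param_set$test(list(na.action = "omit"))
```

Assigning an infeasible value via param_set$values would raise an error that contains the message returned by the check.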

Custom Parameters

Sometimes you might want to add parameters that are not present in the upstream function. One example is the mtry.ratio parameter of lrn("surv.ranger"). Such parameters should usually not have default values (because we are not calling an upstream function with a default behaviour here) but instead have their value initialized during construction and have the "required" tag. This avoids having to deal with the case where this parameter has no value set. When contributing a learner to mlr3extralearners this information should go into the Custom parameters section.
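To illustrate the idea, the sketch below defines a custom ratio parameter and translates it into an upstream argument before the upstream function would be called. The helper ratio_to_mtry() and the rounding rule are hypothetical and only meant as an illustration; they are not the implementation used by any existing learner.

```r
library(paradox)

param_set = ps(
  mtry.ratio = p_dbl(0, 1, tags = c("train", "required"))
)
# initialized during construction, so a value is always available
param_set$values$mtry.ratio = 0.5

# hypothetical helper: replace the custom parameter by the upstream argument
ratio_to_mtry = function(pv, n_features) {
  pv$mtry = max(1L, as.integer(ceiling(pv$mtry.ratio * n_features)))
  pv$mtry.ratio = NULL
  pv
}

pv = ratio_to_mtry(param_set$get_values(tags = "train"), n_features = 10L)
```

Inside a real $.train() method, the number of features would be taken from the task instead of being passed as a constant.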

Train function

The train function takes a Task as input and must return a model, i.e. some R object that allows the learner to make predictions. We want to translate the following call of rpart::rpart() into code that can be used inside the .train() method.

First, we write something down that works completely without mlr3:

data = mtcars
model = rpart::rpart(mpg ~ ., data = mtcars, xval = 0)

We need to pass the formula notation mpg ~ ., the data and the hyperparameters. To get the hyperparameters, we call self$param_set$get_values(tags = "train") and thereby query all parameters that are used during "train". Because the learner has the property "weights", we insert the weights of the task if there are any. Then we obtain the formula from the task using task$formula() and access the training data using task$data(). Last, we call the upstream function rpart::rpart() with the formula and the data and pass all hyperparameters via the .args argument of the mlr3misc::invoke() function. The latter is simply an optimized version of do.call() that we use within the mlr3 ecosystem. The return value of this method will be available as the $model slot of the trained learner.

.train = function(task) {
  pv = self$param_set$get_values(tags = "train")
  if ("weights" %in% task$properties) {
    pv$weights = task$weights$weight
  }
  formula = task$formula()
  data = task$data()
  invoke(
    rpart::rpart,
    formula = formula,
    data = data,
    .args = pv
  )
}
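To make the role of .args more concrete, the following sketch compares invoke() with a roughly equivalent do.call() outside of any learner. The hyperparameter values are arbitrary and only chosen for the example.

```r
library(mlr3misc)

# hyperparameters as they might be returned by get_values(tags = "train")
pv = list(xval = 0L, maxdepth = 3L)

# named arguments are passed directly, the list in .args is appended
m1 = invoke(rpart::rpart, formula = mpg ~ ., data = mtcars, .args = pv)

# roughly equivalent call with base R
m2 = do.call(rpart::rpart, c(list(formula = mpg ~ ., data = mtcars), pv))

same_fit = isTRUE(all.equal(unname(predict(m1)), unname(predict(m2))))
```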

At this point, an explanation for the difference between the public method $train() and the private method $.train() is in order. The former essentially calls the latter and additionally performs some checks and tasks that are the same for every learner. The same holds for the predict method.

Predict function

The internal method $.predict() also operates on a Task as well as on the fitted model that was created by the previous $train() call and stored in self$model. The return value must contain the information required to produce an mlr3::Prediction object, which is returned when $predict() is called on a learner. We proceed analogously to what we did in the previous section. We start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:

# inputs:
task = tsk("mtcars")
self = list(model = rpart::rpart(task$formula(), data = task$data()))
data = mtcars

# response prediction
response = predict(self$model, newdata = data)

Next, we transition from data to a Task again and construct a list with the return type requested by the user. The requested type is stored in the $predict_type field of the learner. Because the regr.rpart_simple learner we are implementing has only one predict type, we do not have to take the value of self$predict_type into account. For, e.g., regression learners with predict types "response" and "se" or classification learners with predict types "response" and "prob", the predict method must return the prediction requested by the user.
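For a learner with more than one predict type, the body of $.predict() has to branch on the requested type. The stand-alone function below is a sketch of that branching for a classification tree; predict_by_type() is a hypothetical helper used here in place of the method, with predict_type standing in for self$predict_type.

```r
library(rpart)

# hypothetical stand-alone version of the branching inside $.predict()
predict_by_type = function(model, newdata, predict_type) {
  if (predict_type == "response") {
    response = predict(model, newdata = newdata, type = "class")
    list(response = unname(response))
  } else {
    prob = predict(model, newdata = newdata, type = "prob")
    # drop row names but keep the class labels as column names
    rownames(prob) = NULL
    list(prob = prob)
  }
}

model = rpart(Species ~ ., data = iris)
res_response = predict_by_type(model, iris, "response")
res_prob = predict_by_type(model, iris, "prob")
```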

The final $.predict() method is shown below. We could omit the pv line as there are no parameters with the "predict" tag, but we keep it here to be consistent:

.predict = function(task) {
  pv = self$param_set$get_values(tags = "predict")
  # ensure same column order in train and predict
  newdata = mlr3extralearners:::ordered_features(task, self)
  response = invoke(predict, self$model, newdata = newdata, .args = pv)
  list(response = unname(response))
}

Beware that you cannot rely on the column order of the data returned by task$data(), as the order of columns may differ from the order during $.train(). The newdata line ensures that the features are in the same order as during training.
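The underlying idea can be shown with plain data frames. This is only an illustration of reordering columns by name and not the implementation of ordered_features(); the feature names are chosen arbitrarily.

```r
# column order recorded at training time (hypothetical example)
train_features = c("wt", "hp", "cyl")

# data at prediction time arrives in a different column order
newdata = mtcars[1:3, c("cyl", "wt", "hp")]

# selecting by name restores the training order
newdata = newdata[, train_features, drop = FALSE]
names(newdata)
```

This matters most for upstream functions that receive a matrix and match features by position rather than by name.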

Optional Extractors

Specific learner implementations are free to implement additional getters to ease access to certain parts of the model in the inherited subclasses. Some of these methods are standardized in mlr3, cf. the Learner documentation. If it is one of the standardized extractors, some custom code might have to be written for the return value to comply with the mlr3 standards. Because we specified earlier that the learner has the property "importance", we implement the public $importance() method. It accesses the $model slot of the learner after training and returns one of its fields. Because the $importance() method is standardized, it must return the importance scores as a numeric vector that is sorted in decreasing order and named by feature.

importance = function() {
  if (is.null(self$model)) {
    stopf("No model stored")
  }
  # importance is only present if there is at least one split
  if (is.null(self$model$variable.importance)) {
    importance = set_names(numeric())
  } else {
    importance = sort(self$model$variable.importance, decreasing = TRUE)
  }
  return(importance)
}

Testing the learner

Once your learner is created, you should write tests to verify its correctness. While you should also do this when implementing a learner only for your own use, this is absolutely required when you want to contribute a new learner to mlr3extralearners. Besides manual tests, the mlr3 package also provides automatic tests to check that a learner satisfies some basic sanity checks and that it implements all the available parameters of the upstream function. The manual page for these functions can be accessed via ?mlr_test_helpers.

For a bare-bones check you can just try to run a simple $train() call.

task = tsk("mtcars")
learner = LearnerRegrRpartSimple$new()
learner$train(task)
p = learner$predict(task)

p$score(msr("regr.mse"))
## regr.mse 
## 9.127693
learner$importance()
##       cyl      disp        hp        wt      qsec        vs      carb      gear 
## 724.18935 721.08062 702.29023 573.72817 442.05737 395.01237  31.36333  15.68167

If it runs without erroring, that’s a very good start!

Autotest

To ensure that your learner is able to handle all kinds of different properties and feature types, we have written an “autotest” that checks the learner for different combinations of these. In addition to the autotest we also check some other learner properties using expect_learner(). See the sections run_autotest and expect_learner of ?mlr_test_helpers for an explanation. Because these functions are not in the mlr3 namespace, we have to source them from the ./testthat subfolder of the installed mlr3 package. Note that when implementing survival learners, you must also source these helper files from the mlr3proba package.

lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
# mlr3proba
# lapply(list.files(system.file("testthat", package = "mlr3proba"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)

We now perform the autotest. By setting expect_learner(..., check_man = FALSE), we disable a test that verifies that the correct manual page exists, which is only relevant when adding the learner to mlr3extralearners.

library(testthat)
test_that("autotest", {
  learner = LearnerRegrRpartSimple$new()

  # basic learner properties
  expect_learner(learner, check_man = FALSE)

  # you can skip tests using the `exclude` argument
  result = run_autotest(learner)
  expect_true(result, info = result$error)
})
## Test passed 🎉

Checking Parameters

Some learners have a high number of parameters and it is easy to miss out on some during the creation of a new learner or misspell a parameter. Therefore we have written a “Parameter Check” that can be used to check that the parameters are correctly implemented. See section run_paramtest of ?mlr_test_helpers for an explanation. This “Parameter Check” compares the parameters of the mlr3 ParamSet against all arguments available in the upstream function that is called during $train() and $predict(). When the $.train() method either calls multiple functions, or, e.g., the ... arguments are passed further to a control function (as in the case here), a list of functions can be passed to the parameter test.

The test comes with an exclude argument that should be used to exclude and explain why certain arguments of the upstream function are not within the ParamSet of the learner. This will likely be required for all learners as common arguments like x, target or data are handled by the mlr3 interface and are therefore not included within the ParamSet.

However, there might be more parameters that need to be excluded, for example:

  • Type dependent parameters, i.e. parameters that only apply for classification or regression learners.
  • Parameters that are actually deprecated by the upstream package and which were therefore not included in the mlr3 ParamSet.

All excluded parameters should have a comment justifying their exclusion.
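The comparison behind the parameter check can be approximated in base R. The following is a rough sketch of the idea and not the implementation of run_paramtest(); the parameter ids are copied by hand from the ParamSet defined at the beginning of this vignette.

```r
# all arguments of the upstream functions used during training
upstream_args = union(
  names(formals(rpart::rpart)),
  names(formals(rpart::rpart.control))
)

# ids of the parameters in the learner's ParamSet (copied by hand)
param_ids = c("cp", "maxcompete", "maxdepth", "maxsurrogate", "minbucket",
  "minsplit", "surrogatestyle", "usesurrogate", "xval")

# arguments that are deliberately not part of the ParamSet
exclude = c("formula", "data", "weights", "subset", "na.action", "method",
  "model", "x", "y", "parms", "control", "cost")

# upstream arguments that are neither implemented nor excluded
missing_ids = setdiff(upstream_args, c(param_ids, exclude, "..."))
# parameters that do not exist in the upstream functions
extra_ids = setdiff(param_ids, upstream_args)
```

Both vectors being empty corresponds to a passing check; a non-empty missing_ids points to a forgotten parameter, a non-empty extra_ids to a misspelled one.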

In our example, the final paramtest script looks like:

test_that("paramtest", {
  learner = LearnerRegrRpartSimple$new()

  exclude = c("formula", "data", "weights", "subset", "na.action", "method", "model",
    "x", "y", "parms", "control", "cost", "keep_model")
  result = run_paramtest(learner, list(rpart::rpart, rpart::rpart.control), exclude, tag = "train")
  expect_true(result, info = result$error)

  exclude = c(
    "object", # handled internally
    "newdata", # handled internally
    "type", # handled internally
    "na.action", # handled internally
    "model" # not implemented
  )
  result = run_paramtest(learner, rpart:::predict.rpart, exclude, tag = "predict")
  expect_true(result, info = result$error)
})
## Test passed 🎉

Contributing to mlr3extralearners

When adding a Learner to mlr3extralearners there are some additional requirements that have to be satisfied:

  • Add the line .extralrns_dict$add("<learner-id>", function() <LearnerName>$new()) to the bottom of the learner file to add the learner to the learners Dictionary.
  • Add the learner dependencies to Suggests (and not e.g. to Imports) in the DESCRIPTION file.
  • Document the learner properly. When using mlr3extralearners::create_learner() the documentation is generated automatically. Otherwise you can use and adapt the documentation of an already implemented learner.
  • Conform to the mlr3 style guide.
  • Make sure the tests pass in the CI (also check the CodeFactor results).
  • Update the NEWS.md file and describe that you added the learner.
  • Add yourself as a contributor.