Creating a new Learner
Many learners are already included in the mlr3 ecosystem, but many more have not yet been implemented. In this vignette we look at how to add a new Learner to mlr3. We assume a lot of prior knowledge here, so it is best to read the mlr3book first! If you want to contribute a new Learner to mlr3extralearners, we recommend first opening an issue to discuss the learner. As another side note, the mlr3extralearners::create_learner() function can be used to generate files containing templates for the learner and the test files required to create and test a Learner. We will not use it in this vignette for the sake of clarity, but recommend using it in practice to avoid writing a lot of boilerplate code.
We will first demonstrate what the final class looks like and then explain it line by line (and method by method). As a working example, we will implement a slightly stripped-down version of regr.rpart.
library(mlr3)
library(paradox)
library(R6)
library(mlr3misc)
LearnerRegrRpartSimple = R6Class("LearnerRegrRpartSimple",
  inherit = LearnerRegr,
  public = list(
    initialize = function() {
      param_set = ps(
        cp = p_dbl(0, 1, default = 0.01, tags = "train"),
        maxcompete = p_int(0L, default = 4L, tags = "train"),
        maxdepth = p_int(1L, 30L, default = 30L, tags = "train"),
        maxsurrogate = p_int(0L, default = 5L, tags = "train"),
        minbucket = p_int(1L, tags = "train"),
        minsplit = p_int(1L, default = 20L, tags = "train"),
        surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
        usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
        xval = p_int(0L, default = 10L, tags = "train")
      )
      # initial value; differs from the annotated upstream default of 10
      param_set$values$xval = 0L
      super$initialize(
        id = "regr.rpart_simple",
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        predict_types = "response",
        packages = "rpart",
        param_set = param_set,
        properties = c("weights", "missings", "importance"),
        label = "Regression Tree",
        man = "mlr3::mlr_learners_regr.rpart"
      )
    },
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      # importance is only present if there is at least one split
      if (is.null(self$model$variable.importance)) {
        importance = set_names(numeric())
      } else {
        importance = sort(self$model$variable.importance, decreasing = TRUE)
      }
      return(importance)
    }
  ),
  private = list(
    .train = function(task) {
      pv = self$param_set$get_values(tags = "train")
      if ("weights" %in% task$properties) {
        pv$weights = task$weights$weight
      }
      invoke(
        rpart::rpart,
        formula = task$formula(),
        data = task$data(),
        .args = pv
      )
    },
    .predict = function(task) {
      pv = self$param_set$get_values(tags = "predict")
      # ensure same column order in train and predict
      newdata = mlr3extralearners:::ordered_features(task, self)
      response = invoke(predict, self$model, newdata = newdata, .args = pv)
      list(response = unname(response))
    }
  )
)
We first specify the class name according to the mlr3 naming convention as "LearnerRegrRpartSimple". In the second line, we select the appropriate parent class from which we want to inherit. Because we are implementing a regression learner, we have to inherit from the LearnerRegr class. For classification we would inherit from LearnerClassif, for survival from mlr3proba::LearnerSurv, and for clustering from mlr3cluster::LearnerClust.
In the constructor, the constructor of the super class (in this case LearnerRegr) is called with meta-information about the learner which we are defining. We strongly recommend reading the relevant class documentation to understand its arguments (see e.g. the Learner documentation). These include:
- id: The ID of the new learner.
- packages: The upstream package name(s) of the implemented learner.
- param_set: A set of hyperparameters and their descriptions provided as a paradox::ParamSet.
- predict_types: Set of predict types the learner supports.
- feature_types: Set of feature types the learner is able to handle.
- properties: Set of properties of the learner.
- man: The roxygen identifier of the learner. This is used within the $help() method of the super class to open the help page of the learner.
- label: The label of the learner. This should briefly describe the learner (similar to the description’s title) and is for example used for printing.
To fill out some of this meta-information, one has to go through the manual pages of the upstream packages. In order to know e.g. which predict types or properties are available, reading the manual pages of the corresponding superclass(es) is indispensable. Furthermore, it is a good idea to look at already existing learner implementations for inspiration.
Three properties in mlr3 deserve special consideration:
- "validation" - the learner can make use of an internal validation to e.g. track the performance
- "internal_tuning" - the learner can internally optimize hyperparameters, e.g. via early stopping
- "marshal" - the learner cannot be serialized without loss of information (e.g. learners from mlr3torch have this property)
While these properties are rarely needed, it is still important to know how to implement them if they do apply. For the first two properties, see the documentation of Learner or, for a concrete implementation, the XGBoost learners. The documentation on marshaling can be found by running ?marshaling, and for an example of how to implement it see mlr3torch::LearnerTorch. In any case, if you suspect your learner to have any of these properties, it is best to first discuss the implementation with the mlr-org team.
As a final note, we want to mention that a Learner should be constructable even when the packages that it uses are not loaded, so the $initialize() method should not use any objects from the package.
Defining the Parameter Set of a Learner
The parameter set of a learner is the set of hyperparameters used in model training and predicting. It is given as a paradox::ParamSet. The set consists of a list of hyperparameters, where each has a specific class for the hyperparameter type. For an explanation of parameters and the ParamSet class we refer to the mlr3book, where we show how to create parameter sets to define search spaces for hyperparameter tuning. However, there are some aspects of parameters that are mostly relevant when creating new Learners, so we cover them here. These are:
- tags that organize the parameters, e.g. whether they are used during training or prediction.
- default values which give information about the value that is used by the upstream function when no value is set. Note that this is not necessarily the default of the function signature but what is used internally: e.g. the default of parameter a for a function is NULL, but internally the value for a is set to 1 if it is left as NULL by the caller.
- initial values are the initial $values of the parameter set that are set during construction.
- custom checks which allow specifying custom constraints for untyped parameters (ParamUty).
Tags
Each parameter has one or more tags that determine in which method they are used, i.e. parameters that are used during training must be tagged with "train" and those that are used during prediction with "predict". In our case there are only training parameters. Parameters with certain tags can be retrieved by calling the $get_values() method of a ParamSet with the tags argument.
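As a small illustration of tag-based retrieval, the sketch below builds a toy parameter set with one training and one prediction parameter. The parameter names and values are made up for this example and do not belong to any real learner.

```r
library(paradox)

# toy parameter set; `cp` and `type` are illustrative names only
param_set = ps(
  cp = p_dbl(0, 1, default = 0.01, tags = "train"),
  type = p_fct(c("vector", "matrix"), tags = "predict")
)
param_set$values = list(cp = 0.05, type = "vector")

# only the values of parameters carrying the requested tag are returned
train_values = param_set$get_values(tags = "train")
predict_values = param_set$get_values(tags = "predict")
```

Here train_values contains only cp and predict_values contains only type, which is the pattern used in the $.train() and $.predict() methods of the learner.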
Furthermore, there are other tags that serve specific purposes:
- The tag "threads" should be used (if applicable) to tag the parameter that determines the number of threads used for the learner’s internal parallelization. Once a learner is constructed, this parameter can be set using set_threads().
- The tag "required" should be used to tag parameters that must be provided for the algorithm to be executable.
Default Values
The default values of each parameter indicate which value the upstream function (in this case rpart::rpart()) uses if no parameter is specified. These have to be retrieved from the package’s documentation. If parameters have no default, they have to be tagged with the "required" tag. For such parameters, we recommend setting the parameter to a value in the $initialize() method, so that the learner can be executed after creation. Note that setting a parameter value in the $initialize() method does not change the parameter’s default value (as the default denotes what happens when no value is set).
Parameters expect the default values to be of their respective type, e.g. the default of ParamInt must be an integer. Some packages have R expressions as default values which cannot be properly expressed in paradox. In such cases, a practical compromise is to not specify any default value when calling p_int().
Initial Values
In some cases you want to change the parameter values during initialization. You can do this by changing the param_set$values during the learner’s construction. You can see we have done this for regr.rpart_simple, where the value of xval is initialized to 0. Note however that we still annotate the default of the xval parameter as 10. We recommend setting initial parameter values only for good reasons and documenting this properly when doing so. When contributing a learner to mlr3extralearners this information should go into the Initial parameter values section.
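The difference between an annotated default and an initial value can be checked directly on a parameter set. The following is a minimal sketch using a one-parameter set modelled on xval from the example above:

```r
library(paradox)

# annotate the upstream default (10) and set a different initial value (0)
param_set = ps(xval = p_int(0L, default = 10L, tags = "train"))
param_set$values$xval = 0L

# the default is documentation of the upstream behaviour ...
default_xval = param_set$default$xval
# ... while the value is what is actually passed on during training
initial_xval = param_set$values$xval
```

Setting the value leaves the annotated default untouched, so default_xval and initial_xval differ.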
Custom Checks
In paradox only integers, numerics, factors, and logicals have dedicated parameter types. For all other parameters, the ParamUty class has to be used, which is constructable via the p_uty() function. It has the argument custom_check, which allows specifying custom constraints. This function returns TRUE if the input is valid and a character(1) with the error message otherwise.
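A custom check could be sketched as follows. The parameter name fun and the helper check_is_function are hypothetical and only serve to illustrate the contract of returning TRUE or an error message.

```r
library(paradox)

# hypothetical check: the value must be a function
check_is_function = function(x) {
  if (is.function(x)) TRUE else "Must be a function"
}

param_set = ps(
  fun = p_uty(custom_check = check_is_function, tags = "train")
)

# $test() reports whether a list of values is feasible for the parameter set
valid_ok = param_set$test(list(fun = mean))
invalid_ok = param_set$test(list(fun = "not a function"))
```

With this check in place, assigning an invalid value to param_set$values$fun is rejected with the message returned by the check function.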
Custom Parameters
Sometimes you might want to add parameters that are not present in the upstream function. One example is the mtry.ratio parameter of lrn("surv.ranger"). Such parameters should usually not have default values (because we are not calling an upstream function with a default behaviour here) but instead have their value initialized during construction and have the "required" tag. This avoids having to deal with the case where this parameter has no value set. When contributing a learner to mlr3extralearners this information should go into the Custom parameters section.
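One possible way to handle such a parameter is sketched below. The parameter sample_ratio, the upstream argument sample_size and the helper translate_ratio are all hypothetical; they only illustrate the pattern of a required custom parameter with an initial value that is translated before the upstream call.

```r
library(paradox)

# hypothetical custom parameter: no default, but tagged "required"
param_set = ps(
  sample_ratio = p_dbl(0, 1, tags = c("train", "required"))
)
# initialize the value so the learner is usable right after construction
param_set$values$sample_ratio = 0.5

# hypothetical translation step that would run inside $.train():
# replace the custom parameter by the argument the upstream function expects
translate_ratio = function(pv, n_rows) {
  pv$sample_size = max(1L, as.integer(ceiling(pv$sample_ratio * n_rows)))
  pv$sample_ratio = NULL
  pv
}

pv = translate_ratio(param_set$get_values(tags = "train"), n_rows = 32L)
```

Because the value is always set, the translation step never has to deal with a missing sample_ratio.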
Train function
The train function takes a Task as input and must return a model, i.e. some R object that allows the learner to make predictions. We want to translate the following call of rpart::rpart() into code that can be used inside the .train() method.
First, we write something down that works completely without mlr3:
data = mtcars
model = rpart::rpart(mpg ~ ., data = mtcars, xval = 0)
We need to pass the formula notation mpg ~ ., the data and the hyperparameters. To get the hyperparameters, we call self$param_set$get_values(tags = "train") and thereby query all parameters that are used during "train". Because the learner has the property "weights", we insert the weights of the task if there are any. Then we obtain the formula from the task using task$formula() and access the training data using task$data(). Last, we call the upstream function rpart::rpart() with the data and pass all hyperparameters via argument .args using the mlr3misc::invoke() function. The latter is simply an optimized version of do.call() that we use within the mlr3 ecosystem. The return value of this method will be available as the $model slot of the trained learner.
.train = function(task) {
  pv = self$param_set$get_values(tags = "train")
  if ("weights" %in% task$properties) {
    pv$weights = task$weights$weight
  }
  formula = task$formula()
  data = task$data()
  invoke(
    rpart::rpart,
    formula = formula,
    data = data,
    .args = pv
  )
}
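To make the role of invoke() more concrete, the sketch below fits the same tree once via mlr3misc::invoke() and once via do.call(), outside of any learner. The hyperparameter values in pv are arbitrary and chosen only for illustration.

```r
library(mlr3misc)

# hyperparameters as they might be returned by get_values(tags = "train")
pv = list(cp = 0.05, xval = 0L)

# fixed arguments are passed by name, the parameter list goes into .args
model_invoke = invoke(rpart::rpart, formula = mpg ~ ., data = mtcars, .args = pv)

# equivalent call using base R
model_docall = do.call(rpart::rpart, c(list(formula = mpg ~ ., data = mtcars), pv))

same_predictions = all.equal(predict(model_invoke), predict(model_docall))
```

Both calls should yield the same fitted tree; invoke() mainly saves the manual concatenation of fixed arguments and the parameter list.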
At this point, an explanation for the difference between the public method $train() and the private method $.train() is in order. The former essentially calls the latter and additionally performs some checks and tasks that are the same for every learner. The same holds for the predict method.
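The division of labour between the public and the private method can be illustrated with a toy R6 hierarchy. This is not how mlr3 implements $train(); the classes ToyBase and ToyMeanLearner are invented for this sketch.

```r
library(R6)

# toy base class: the public method does the shared work ...
ToyBase = R6Class("ToyBase",
  public = list(
    model = NULL,
    train = function(data) {
      stopifnot(is.data.frame(data), nrow(data) > 0)  # shared input checks
      self$model = private$.train(data)               # learner-specific part
      invisible(self)
    }
  ),
  private = list(
    .train = function(data) stop("abstract method")
  )
)

# ... and the subclass only implements the private method
ToyMeanLearner = R6Class("ToyMeanLearner",
  inherit = ToyBase,
  private = list(
    .train = function(data) mean(data[[1]])
  )
)

toy = ToyMeanLearner$new()
toy$train(data.frame(y = c(1, 2, 3)))
```

The subclass author never touches the shared checks, which mirrors why a new Learner only needs to provide $.train() and $.predict().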
Predict function
The internal method $.predict() also operates on a Task as well as on the fitted model that has been created by the previous $train() call and has been stored in self$model. The return value must contain the required information to produce a mlr3::Prediction object that is returned when $predict() is called on a learner. We proceed analogously to what we did in the previous section. We start with a version without any mlr3 objects and continue to replace objects until we have reached the desired interface:
# inputs:
task = tsk("mtcars")
self = list(model = rpart::rpart(task$formula(), data = task$data()))
data = mtcars
# response prediction
response = predict(self$model, newdata = data)
Next, we transition from data to a Task again and construct a list with the return type requested by the user. This is stored in the $predict_type slot of a learner class. Because the regr.rpart_simple learner we are implementing has only one predict type, we do not have to take the value of self$predict_type into account. For, e.g., regression learners with predict types "response" and "se" or classification learners with predict types "response" and "prob", the predict method must return the prediction requested by the user.
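For a learner with more than one predict type, the body of $.predict() has to branch on the requested type. The following standalone sketch shows this for a classification tree; the helper predict_by_type is hypothetical and takes the predict type as an argument instead of reading self$predict_type.

```r
model = rpart::rpart(Species ~ ., data = iris)

# hypothetical helper mirroring the branching inside a $.predict() method
predict_by_type = function(model, newdata, predict_type) {
  if (predict_type == "response") {
    response = predict(model, newdata = newdata, type = "class")
    list(response = unname(response))
  } else {
    # one column of probabilities per class, column names are the class levels
    prob = predict(model, newdata = newdata, type = "prob")
    rownames(prob) = NULL
    list(prob = prob)
  }
}

res_response = predict_by_type(model, iris, "response")
res_prob = predict_by_type(model, iris, "prob")
```

In a real learner the returned list would then be converted into a Prediction object by the public $predict() method.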
The final $.predict() method is below. We could omit the pv line, as there are no parameters with the "predict" tag, but we keep it here to be consistent:
.predict = function(task) {
  pv = self$param_set$get_values(tags = "predict")
  # ensure same column order in train and predict
  newdata = mlr3extralearners:::ordered_features(task, self)
  response = invoke(predict, self$model, newdata = newdata, .args = pv)
  list(response = unname(response))
}
Beware that you cannot rely on the column order of the data returned by task$data(), as the order of columns may be different from the order of the columns during $.train(). The newdata line ensures that the features are in the same order as during training.
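The idea behind such a reordering step can be sketched in base R. This is not the implementation of the internal ordered_features() helper; the function reorder_features and the column selection below are assumptions made for illustration.

```r
# feature order as seen during training (illustrative)
train_features = c("wt", "hp", "cyl")

# data at prediction time: different column order plus the target column
newdata = mtcars[, c("cyl", "mpg", "wt", "hp")]

# hypothetical helper: select the features in the order used for training
reorder_features = function(data, feature_names) {
  data[, feature_names, drop = FALSE]
}

ordered = reorder_features(newdata, train_features)
```

This matters most for upstream functions that take a matrix and match columns by position rather than by name.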
Optional Extractors
Specific learner implementations are free to implement additional getters to ease the access of certain parts of the model in the inherited subclasses. Some of these methods are standardized in mlr3, cf. the Learner documentation. If it is one of the standardized extractors, some custom code might have to be written for the return value to comply with the mlr3 standards (see the Learner documentation). Because we specified earlier that the learner has the property "importance", we implement the public $importance() method. It accesses the $model slot of the learner after training and returns one of its fields. Because the $importance() method is standardized, it must return the importance scores as a sorted numeric vector, the names being the features.
importance = function() {
  if (is.null(self$model)) {
    stopf("No model stored")
  }
  # importance is only present if there is at least one split
  if (is.null(self$model$variable.importance)) {
    importance = set_names(numeric())
  } else {
    importance = sort(self$model$variable.importance, decreasing = TRUE)
  }
  return(importance)
}
Testing the learner
Once your learner is created, you should write tests to verify its correctness. While you should also do this when implementing a learner only for your own use, this is absolutely required when you want to contribute a new learner to mlr3extralearners. Besides manual tests, the mlr3 package also provides automatic tests to check that a learner satisfies some basic sanity checks and that it implements all the available parameters of the upstream function. The manual page for these functions can be accessed via ?mlr_test_helpers.
For a bare-bones check you can simply run $train() and $predict() on a small task.
task = tsk("mtcars")
learner = LearnerRegrRpartSimple$new()
learner$train(task)
p = learner$predict(task)
p$score(msr("regr.mse"))
## regr.mse
## 9.127693
learner$importance()
## cyl disp hp wt qsec vs carb gear
## 724.18935 721.08062 702.29023 573.72817 442.05737 395.01237 31.36333 15.68167
If it runs without erroring, that’s a very good start!
Autotest
To ensure that your learner is able to handle all kinds of different properties and feature types, we have written an “autotest” that checks the learner for different combinations of such. In addition to the autotest we also check some other learner properties using expect_learner(). See sections run_autotest and expect_learner of ?mlr_test_helpers for an explanation. Because these functions are not in the mlr3 namespace, we have to source them from the ./testthat subfolder of the installed mlr3 package. Note that when implementing survival learners, you must also source these helper files from the mlr3proba package.
lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
# mlr3proba
# lapply(list.files(system.file("testthat", package = "mlr3proba"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
We now perform the autotest. By setting expect_learner(..., check_man = FALSE), we disable a test that verifies that the correct manual page exists, which is only relevant when adding the learner to mlr3extralearners.
library(testthat)
test_that("autotest", {
  learner = LearnerRegrRpartSimple$new()
  # basic learner properties; skip the manual page check
  expect_learner(learner, check_man = FALSE)
  # you can skip tests using the `exclude` argument
  result = run_autotest(learner)
  expect_true(result, info = result$error)
})
## Test passed 🎉
Checking Parameters
Some learners have a high number of parameters and it is easy to miss some of them or to misspell a parameter during the creation of a new learner. Therefore we have written a “Parameter Check” that can be used to check that the parameters are correctly implemented. See section run_paramtest of ?mlr_test_helpers for an explanation. This “Parameter Check” compares the parameters of the mlr3 ParamSet against all arguments available in the upstream function that is called during $train() and $predict(). When the $.train() method either calls multiple functions, or, e.g., the ... arguments are passed further to a control function (as in the case here), a list of functions can be passed to the parameter test.
The test comes with an exclude argument that should be used to exclude and explain why certain arguments of the upstream function are not within the ParamSet of the learner. This will likely be required for all learners as common arguments like x, target or data are handled by the mlr3 interface and are therefore not included within the ParamSet.
However, there might be more parameters that need to be excluded, for example:
- Type dependent parameters, i.e. parameters that only apply for classification or regression learners.
- Parameters that are actually deprecated by the upstream package and which were therefore not included in the mlr3 ParamSet.
All excluded parameters should have a comment justifying their exclusion.
In our example, the final paramtest script looks like:
test_that("paramtest", {
  learner = LearnerRegrRpartSimple$new()
  exclude = c("formula", "data", "weights", "subset", "na.action", "method", "model",
    "x", "y", "parms", "control", "cost", "keep_model")
  result = run_paramtest(learner, list(rpart::rpart, rpart::rpart.control), exclude, tag = "train")
  expect_true(result, info = result$error)
  exclude = c(
    "object", # handled internally
    "newdata", # handled internally
    "type", # handled internally
    "na.action", # handled internally
    "model" # not implemented
  )
  result = run_paramtest(learner, rpart:::predict.rpart, exclude, tag = "predict")
  expect_true(result, info = result$error)
})
## Test passed 🎉
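The core idea of the parameter check can be reproduced by hand with formals(). The sketch below is not how run_paramtest() is implemented; it only compares argument names of the two upstream functions with the parameter ids of the example learner, which are copied here as a plain character vector.

```r
# arguments of the upstream functions used during training
upstream_args = union(
  names(formals(rpart::rpart)),
  names(formals(rpart::rpart.control))
)

# parameter ids of the example learner (copied from its ParamSet)
param_ids = c("cp", "maxcompete", "maxdepth", "maxsurrogate", "minbucket",
  "minsplit", "surrogatestyle", "usesurrogate", "xval")

# arguments that are deliberately not part of the ParamSet
exclude = c("formula", "data", "weights", "subset", "na.action", "method",
  "model", "x", "y", "parms", "control", "cost")

# upstream arguments that are neither implemented nor excluded
missing_args = setdiff(upstream_args, c(param_ids, exclude, "..."))
# ParamSet entries that do not exist upstream (e.g. misspelled ids)
extra_params = setdiff(param_ids, upstream_args)
```

Both missing_args and extra_params being empty corresponds to a passing parameter test for the training step.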
Contributing to mlr3extralearners
When adding a Learner to mlr3extralearners there are some additional requirements that have to be satisfied:
- Adding the line .extralrns_dict$add("<learner-id>", function() <LearnerName>$new()) to the bottom of the learner file to add the learner to the learners Dictionary.
- Add the learner dependencies to Suggests (and not e.g. to Imports) in the DESCRIPTION file.
- The learner must be properly documented. When using mlr3extralearners::create_learner() this is automatically generated. Otherwise you can use and adapt the documentation of an already implemented learner.
- Conforming to the mlr3 style guide.
- The tests must pass in the CI (also check the CodeFactor results).
- Update the NEWS.md file and describe that you added the learner.
- Add yourself as a contributor