TabPFN Regression Learner
mlr_learners_regr.tabpfn.Rd
Foundation model for tabular data. Uses reticulate to interface with the tabpfn Python package.
Installation
While the Python dependencies are handled via reticulate::py_require(), you can manually specify a virtual environment by calling reticulate::use_virtualenv() prior to calling the $train() function. In this virtual environment, the tabpfn package and its dependencies must be installed.
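A minimal sketch of the manual setup, assuming a virtual environment named "r-tabpfn" (a placeholder name) in which tabpfn has been installed:
# e.g. created once with reticulate::virtualenv_create("r-tabpfn", packages = "tabpfn")
reticulate::use_virtualenv("r-tabpfn", required = TRUE)
learner = mlr3::lrn("regr.tabpfn")
learner$train(mlr3::tsk("mtcars"))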
Saving a Learner
In order to save a LearnerRegrTabPFN for later usage, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly. After loading a marshaled LearnerRegrTabPFN back into R, you then need to call $unmarshal() to transform it into a usable state.
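A sketch of the full round trip; the file path is created here only for illustration:
learner = mlr3::lrn("regr.tabpfn")
learner$train(mlr3::tsk("mtcars"))
# marshal the model before serializing the learner
learner$marshal()
path = tempfile(fileext = ".rds")
saveRDS(learner, path)
# later: load the learner and unmarshal it before use
learner = readRDS(path)
learner$unmarshal()
predictions = learner$predict(mlr3::tsk("mtcars"))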
Initial parameter values
n_jobs is initialized to 1 to avoid threading conflicts with future.
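If no conflicting parallelization is in place, the value can be overridden at construction, for example:
learner = mlr3::lrn("regr.tabpfn", n_jobs = 2)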
Custom mlr3 parameters
output_type corresponds to the argument of the same name in the .predict() method of the TabPFNRegressor class, but only supports the options "mean", "median" and "mode". The point predictions are stored as $response of the prediction object.
categorical_features_indices uses R indexing instead of zero-based Python indexing (see the example after this list).
device must be a string. If set to "auto", the behavior is the same as in the original implementation. Otherwise, the string is passed as an argument to torch.device() to create a device.
inference_precision must be "auto" or "autocast". Passing a torch.dtype is currently not supported.
inference_config is currently not supported.
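A sketch illustrating these parameters; the column indices and the device string are placeholders, not recommendations:
learner = mlr3::lrn(
  "regr.tabpfn",
  output_type = "median",                  # point predictions stored as $response
  categorical_features_indices = c(2, 8),  # 1-based R indices (placeholder columns)
  device = "cpu"                           # passed to torch.device() because it is not "auto"
)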
Meta Information
Task type: “regr”
Predict Types: “response”, “quantiles”
Feature Types: “logical”, “integer”, “numeric”
Required Packages: mlr3
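The "quantiles" predict type is requested through the standard mlr3 mechanism for quantile regression; a hedged sketch, assuming a recent mlr3 version in which LearnerRegr exposes the quantiles and quantile_response fields:
learner = mlr3::lrn("regr.tabpfn")
learner$predict_type = "quantiles"
learner$quantiles = c(0.1, 0.5, 0.9)  # requested quantile levels (assumed field)
learner$quantile_response = 0.5       # quantile reported as $response (assumed field)
learner$train(mlr3::tsk("mtcars"))
learner$predict(mlr3::tsk("mtcars"))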
Parameters
Id | Type | Default | Levels | Range |
output_type | character | mean | mean, median, mode | - |
n_estimators | integer | 4 | - | \([1, \infty)\) |
categorical_features_indices | untyped | - | - | - |
softmax_temperature | numeric | 0.9 | - | \([0, \infty)\) |
balance_probabilities | logical | FALSE | TRUE, FALSE | - |
average_before_softmax | logical | FALSE | TRUE, FALSE | - |
model_path | untyped | "auto" | - | - |
device | untyped | "auto" | - | - |
ignore_pretraining_limits | logical | FALSE | TRUE, FALSE | - |
inference_precision | character | auto | auto, autocast | - |
fit_mode | character | fit_preprocessors | low_memory, fit_preprocessors, fit_with_cache | - |
memory_saving_mode | untyped | "auto" | - | - |
random_state | integer | 0 | - | \((-\infty, \infty)\) |
n_jobs | integer | - | - | \([1, \infty)\) |
References
Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Bin S, Schirrmeister, Tibor R, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.
Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner -> mlr3::LearnerRegr -> LearnerRegrTabPFN
Methods
Method marshal()
Marshal the learner's model.
Arguments
...
(any)
Additional arguments passed to marshal_model().
Method unmarshal()
Unmarshal the learner's model.
Arguments
...
(any)
Additional arguments passed to unmarshal_model().
Examples
# Define the Learner
learner = mlr3::lrn("regr.tabpfn")
print(learner)
#> <LearnerRegrTabPFN:regr.tabpfn>: TabPFN Regressor
#> * Model: -
#> * Parameters: n_jobs=1
#> * Packages: mlr3
#> * Predict Types: [response], quantiles
#> * Feature Types: logical, integer, numeric
#> * Properties: marshal, missings
# Define a Task
task = mlr3::tsk("mtcars")
# Create train and test set
ids = mlr3::partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> $fitted
#> TabPFNRegressor(n_jobs=1)
#>
#> attr(,"class")
#> [1] "tabpfn_model"
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> regr.mse
#> 9.461085