Foundation model for tabular data. Uses reticulate to interface with the tabpfn Python package.
Installation
This learner relies on reticulate to handle Python dependencies. It is not necessary to install any Python package manually in advance or to specify a Python environment via reticulate::use_python(), reticulate::use_virtualenv(), reticulate::use_condaenv(), or reticulate::use_miniconda(). When $train() or $predict() is called, the required Python packages (tabpfn, torch, etc.) are installed automatically if they are not already available. Reticulate then configures and initializes an ephemeral environment satisfying those requirements, unless an existing environment (e.g., "r-reticulate") in reticulate's Order of Discovery already contains all the necessary packages.
You may also install tabpfn manually into a Python environment following the official installation guide and specify that environment via reticulate::use_*() before calling $train() or $predict().
Note that the GPU version of PyTorch cannot be loaded via reticulate::use_condaenv() from a conda environment. To use a conda environment, please install the CPU version of PyTorch.
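For example, a minimal sketch of the manual route, assuming tabpfn is already installed in a virtual environment named "r-tabpfn" and that the learner is registered under the key "regr.tabpfn" (both names are illustrative; also load the package that provides LearnerRegrTabPFN so the key is available):
  library(mlr3)
  library(reticulate)

  # Point reticulate at an environment that already contains tabpfn and torch
  # (the environment name here is hypothetical).
  use_virtualenv("r-tabpfn", required = TRUE)

  task = tsk("mtcars")
  learner = lrn("regr.tabpfn")  # assumed learner key
  learner$train(task)           # triggers Python initialization
  prediction = learner$predict(task)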
Saving a Learner
In order to save a LearnerRegrTabPFN for later usage, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly. After loading a marshaled LearnerRegrTabPFN back into R, you then need to call $unmarshal() to transform it into a usable state.
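A minimal sketch of the save/restore round trip, assuming a trained learner object named learner and using base R serialization:
  # Marshal the Python-backed model before writing it to disk.
  learner$marshal()
  saveRDS(learner, "tabpfn_regr.rds")

  # Later, possibly in a fresh R session:
  learner = readRDS("tabpfn_regr.rds")
  learner$unmarshal()  # restore a usable state
  prediction = learner$predict(task)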
Initial parameter values
n_jobs is initialized to 1 to avoid threading conflicts with the future package.
Custom mlr3 parameters
output_type corresponds to the same argument of the .predict() method of the TabPFNRegressor class, but only supports the options "mean", "median", and "mode". The point predictions are stored as $response of the prediction object.
categorical_features_indices uses R indexing (1-based) instead of zero-based Python indexing.
device must be a string. If set to "auto", the behavior is the same as in the original implementation. Otherwise, the string is passed as an argument to torch.device() to create a device.
inference_precision must be "auto", "autocast", or a torch.dtype string, e.g., "torch.float32", "torch.float64", etc. Non-float dtypes are not supported.
inference_config is currently not supported.
random_state accepts either an integer or the special value "None", which corresponds to None in Python. Following the original Python implementation, the default random_state is 0.
Meta Information
Task type: “regr”
Predict Types: “response”, “quantiles”
Feature Types: “logical”, “integer”, “numeric”
Required Packages: mlr3, reticulate
Parameters
Id | Type | Default | Levels | Range |
output_type | character | mean | mean, median, mode | - |
n_estimators | integer | 4 | - | \([1, \infty)\) |
categorical_features_indices | untyped | - | - | - |
softmax_temperature | numeric | 0.9 | - | \([0, \infty)\) |
average_before_softmax | logical | FALSE | TRUE, FALSE | - |
model_path | untyped | "auto" | - | - |
device | untyped | "auto" | - | - |
ignore_pretraining_limits | logical | FALSE | TRUE, FALSE | - |
inference_precision | character | auto | auto, autocast, torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bfloat16 | - |
fit_mode | character | fit_preprocessors | low_memory, fit_preprocessors, fit_with_cache | - |
memory_saving_mode | untyped | "auto" | - | - |
random_state | integer | 0 | - | \((-\infty, \infty)\) |
n_jobs | integer | - | - | \([1, \infty)\) |
References
Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Bin S, Schirrmeister, Tibor R, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.
Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner
-> mlr3::LearnerRegr
-> LearnerRegrTabPFN
Methods
Inherited methods
mlr3::Learner$base_learner()
mlr3::Learner$configure()
mlr3::Learner$encapsulate()
mlr3::Learner$format()
mlr3::Learner$help()
mlr3::Learner$predict()
mlr3::Learner$predict_newdata()
mlr3::Learner$print()
mlr3::Learner$reset()
mlr3::Learner$selected_features()
mlr3::Learner$train()
mlr3::LearnerRegr$predict_newdata_fast()
Method marshal()
Marshal the learner's model.
Arguments
...
(any)
Additional arguments passed to mlr3::marshal_model().
Method unmarshal()
Unmarshal the learner's model.
Arguments
...
(any)
Additional arguments passed to mlr3::unmarshal_model().