Foundation model for tabular data. Uses reticulate to interface with the tabpfn Python package.

Installation

This learner relies on reticulate to handle Python dependencies. You do not need to install any Python package manually in advance or specify a Python environment via reticulate::use_python(), reticulate::use_virtualenv(), reticulate::use_condaenv(), or reticulate::use_miniconda(). When $train() or $predict() is called, the required Python packages (tabpfn, torch, etc.) are installed automatically if they are not already available. Reticulate then configures and initializes an ephemeral environment satisfying those requirements, unless an existing environment (e.g., "r-reticulate") in reticulate's Order of Discovery already contains all the necessary packages.

You may also manually install tabpfn into a Python environment following the official installation guide and specify the environment via reticulate::use_*() before calling $train() or $predict(). Note that the GPU version of PyTorch cannot be loaded by reticulate::use_condaenv() from a conda environment. To use a conda environment, please install the CPU version of PyTorch.
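The manual route described above could look roughly like the following sketch. The helper name use_tabpfn_env and the environment name "tabpfn-env" are hypothetical; the sketch assumes a virtualenv of that name already exists and has tabpfn installed.

```r
# Minimal sketch: pin reticulate to an existing virtualenv before the learner
# is trained. `use_tabpfn_env` and "tabpfn-env" are illustrative names only.
use_tabpfn_env = function(envname = "tabpfn-env") {
  if (!requireNamespace("reticulate", quietly = TRUE)) {
    stop("Package 'reticulate' is required.")
  }
  # required = TRUE makes reticulate fail loudly if the environment is missing
  # instead of silently falling back to another Python installation.
  reticulate::use_virtualenv(envname, required = TRUE)
  # Returns TRUE if the tabpfn module can be found in the selected environment.
  reticulate::py_module_available("tabpfn")
}
```

Calling use_tabpfn_env() once at the start of a session, before any $train() or $predict() call, should be enough, since reticulate binds to a single Python installation per R session.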

Saving a Learner

To save a LearnerClassifTabPFN for later use, call the $marshal() method on the learner before writing it to disk; otherwise the object will not be saved correctly. After loading a marshaled LearnerClassifTabPFN into R again, call $unmarshal() to restore it to a usable state.
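A minimal sketch of that save and restore cycle is shown below. The helper names save_tabpfn_learner and load_tabpfn_learner are hypothetical, and the sketch assumes the learner has already been trained. Unmarshaling again after saving is a convenience added here so the in-session object stays usable; it is not something this page prescribes.

```r
# Sketch: marshal, write to disk, then restore the in-memory learner.
# Helper names are illustrative and not part of any package.
save_tabpfn_learner = function(learner, path) {
  learner$marshal()          # convert the model into a serializable form
  saveRDS(learner, path)     # write the marshaled learner to disk
  learner$unmarshal()        # make the in-session learner usable again
  invisible(path)
}

load_tabpfn_learner = function(path) {
  learner = readRDS(path)    # the learner comes back in marshaled form
  learner$unmarshal()        # restore it to a usable state
  learner
}
```

For example, save_tabpfn_learner(learner, "tabpfn.rds") in one session and learner = load_tabpfn_learner("tabpfn.rds") in a later one.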

Initial parameter values

  • n_jobs is initialized to 1 to avoid threading conflicts with future.

Custom mlr3 parameters

  • categorical_features_indices uses one-based R indexing instead of zero-based Python indexing.

  • device must be a string. If set to "auto", the behavior is the same as in the original Python implementation. Otherwise, the string is passed to torch.device() to create a device.

  • inference_precision must be "auto", "autocast", or a torch.dtype string, e.g., "torch.float32", "torch.float64", etc. Non-float dtypes are not supported.

  • inference_config is currently not supported.

  • random_state accepts either an integer or the special value "None" which corresponds to None in Python. Following the original Python implementation, the default random_state is 0.
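To illustrate the indexing difference and how these parameters might be set, here is a small sketch. The feature names and the helper configure_tabpfn are hypothetical. It also assumes that the indices refer to the order of the task's feature columns, which this page does not state explicitly.

```r
# Hypothetical feature layout, in the order the columns appear in a task.
feature_names = c("age", "sex_code", "income", "region_code")
categorical = c("sex_code", "region_code")

# One-based positions: the form the learner's parameter expects.
r_idx = match(categorical, feature_names)

# Zero-based positions: what the same columns would be called on the Python
# side. Shown only for comparison; these are not passed to the learner.
py_idx = r_idx - 1L

# Sketch of setting the custom parameters on an existing learner object.
configure_tabpfn = function(learner, r_idx) {
  learner$param_set$set_values(
    categorical_features_indices = r_idx,   # R indexing, as documented above
    device = "cpu",                         # forwarded to torch.device()
    inference_precision = "torch.float32",  # a float torch.dtype string
    random_state = 42L                      # any integer seed
  )
  invisible(learner)
}
```

Here r_idx refers to the second and fourth features, which correspond to positions 1 and 3 in Python.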

Dictionary

This Learner can be instantiated via lrn():

lrn("classif.tabpfn")
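A hedged end-to-end sketch follows. It assumes the learner is registered by the mlr3extralearners package (not stated on this page) and that reticulate can set up the Python dependencies on the first call to $train(). The function name run_tabpfn_demo is hypothetical, and the iris task is used only because all of its features are numeric.

```r
# Sketch: train on a split of the iris task and score the held-out rows.
run_tabpfn_demo = function() {
  # Assumption: loading this namespace registers "classif.tabpfn" with mlr3.
  if (!requireNamespace("mlr3extralearners", quietly = TRUE)) {
    stop("Package 'mlr3extralearners' is required for this sketch.")
  }
  task = mlr3::tsk("iris")
  learner = mlr3::lrn("classif.tabpfn", predict_type = "prob")

  # Hold out 20% of the rows for evaluation.
  split = mlr3::partition(task, ratio = 0.8)
  learner$train(task, row_ids = split$train)
  prediction = learner$predict(task, row_ids = split$test)

  # Classification accuracy on the held-out rows.
  prediction$score(mlr3::msr("classif.acc"))
}
```

The first call may take a while because the Python packages may need to be installed before training starts.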

Meta Information

  • Task type: “classif”

  • Predict Types: “response”, “prob”

  • Feature Types: “logical”, “integer”, “numeric”

  • Required Packages: mlr3, reticulate

Parameters

| Id | Type | Default | Levels | Range |
|---|---|---|---|---|
| n_estimators | integer | 4 | | \([1, \infty)\) |
| categorical_features_indices | untyped | - | | - |
| softmax_temperature | numeric | 0.9 | | \([0, \infty)\) |
| balance_probabilities | logical | FALSE | TRUE, FALSE | - |
| average_before_softmax | logical | FALSE | TRUE, FALSE | - |
| model_path | untyped | "auto" | | - |
| device | untyped | "auto" | | - |
| ignore_pretraining_limits | logical | FALSE | TRUE, FALSE | - |
| inference_precision | character | auto | auto, autocast, torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bfloat16 | - |
| fit_mode | character | fit_preprocessors | low_memory, fit_preprocessors, fit_with_cache | - |
| memory_saving_mode | untyped | "auto" | | - |
| random_state | integer | 0 | | \((-\infty, \infty)\) |
| n_jobs | integer | - | | \([1, \infty)\) |

References

Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Bin S, Schirrmeister, Tibor R, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.

Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.

Author

b-zhou

Super classes

mlr3::Learner -> mlr3::LearnerClassif -> LearnerClassifTabPFN

Active bindings

marshaled

(logical(1))
Whether the learner has been marshaled.

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerClassifTabPFN$new()


Method marshal()

Marshal the learner's model.

Usage

LearnerClassifTabPFN$marshal(...)

Arguments

...

(any)
Additional arguments passed to mlr3::marshal_model().


Method unmarshal()

Unmarshal the learner's model.

Usage

LearnerClassifTabPFN$unmarshal(...)

Arguments

...

(any)
Additional arguments passed to mlr3::unmarshal_model().


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerClassifTabPFN$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.