Foundation model for tabular data. Uses reticulate to interface with the tabpfn Python package.

Installation

While the Python dependencies are handled via reticulate::py_require(), you can manually specify a virtual environment by calling reticulate::use_virtualenv() prior to calling the $train() function. In this virtual environment, the tabpfn package and its dependencies must be installed.
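The manual route described above could look roughly like the following sketch. The environment name "r-tabpfn" and the helper setup_tabpfn_env() are hypothetical names chosen for illustration; only the reticulate functions themselves are real. Note that reticulate binds to a single Python installation per R session, so the environment has to be selected before Python is initialized.

```r
library(reticulate)

# Hypothetical environment name; any name or path to a virtualenv works.
envname = "r-tabpfn"

# Hypothetical helper: create the environment with tabpfn installed
# (only if it does not exist yet) and bind reticulate to it.
setup_tabpfn_env = function(envname) {
  if (!virtualenv_exists(envname)) {
    # Installs the tabpfn Python package and its dependencies via pip.
    virtualenv_create(envname, packages = "tabpfn")
  }
  # required = TRUE makes reticulate fail loudly instead of silently
  # falling back to another Python installation.
  use_virtualenv(envname, required = TRUE)
  invisible(envname)
}

# Call once per session, before learner$train():
# setup_tabpfn_env(envname)
```

The actual call is left commented out because creating the environment downloads packages.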

Saving a Learner

To save a LearnerClassifTabPFN for later use, call the $marshal() method on the Learner before writing it to disk; otherwise the object will not be saved correctly. After loading a marshaled LearnerClassifTabPFN into R again, call $unmarshal() to restore it to a usable state.
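As a minimal sketch of this workflow, the two helpers below wrap the marshal/save and load/unmarshal steps using base R's saveRDS() and readRDS(). The helper names and the file path in the usage comments are hypothetical; choosing RDS as the on-disk format is an assumption, not something this page prescribes.

```r
# Hypothetical helper: marshal a trained learner, write it to disk,
# then unmarshal it again so it stays usable in the current session.
save_tabpfn_learner = function(learner, path) {
  learner$marshal()      # marshals the model in place
  saveRDS(learner, path)
  learner$unmarshal()    # restore the in-session learner
  invisible(path)
}

# Hypothetical helper: read a marshaled learner and make it usable again.
load_tabpfn_learner = function(path) {
  learner = readRDS(path)
  learner$unmarshal()
  learner
}

# Usage, assuming `learner` is a trained LearnerClassifTabPFN and the
# packages providing it are loaded:
# save_tabpfn_learner(learner, "tabpfn_learner.rds")
# restored = load_tabpfn_learner("tabpfn_learner.rds")
```

Because $marshal() changes the learner in place, the save helper unmarshals it again afterwards; without that step the learner in the current session would stay marshaled until $unmarshal() is called.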

Initial parameter values

  • n_jobs is initialized to 1 to avoid threading conflicts with future.

Custom mlr3 parameters

  • categorical_features_indices uses one-based R indexing instead of zero-based Python indexing.

  • device must be a string. If set to "auto", the behavior is the same as in the original tabpfn package. Otherwise, the string is passed to torch.device() to create the device.

  • inference_precision must be "auto" or "autocast". Passing torch.dtype is currently not supported.

  • inference_config is currently not supported.
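The parameters above might be set as in the following sketch. It assumes the learner is provided by the mlr3extralearners package (this page does not name the package explicitly) and uses the parameter spelling from the Parameters table. The column positions 1 and 3 are made-up values for illustration.

```r
library(mlr3)
library(mlr3extralearners)  # assumption: the package that provides classif.tabpfn

learner = lrn("classif.tabpfn")

learner$param_set$set_values(
  # One-based positions, as in R: the first and third feature are treated
  # as categorical (illustrative values, not taken from a real task).
  categorical_features_indices = c(1L, 3L),
  # Any string other than "auto" is handed to torch.device(); "cpu" is a
  # valid torch device string.
  device = "cpu",
  # Only "auto" and "autocast" are accepted here.
  inference_precision = "auto"
)

print(learner$param_set$values)
```

Setting the values does not start Python; the configuration only takes effect when $train() is called.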

Dictionary

This Learner can be instantiated via lrn():

lrn("classif.tabpfn")

Meta Information

  • Task type: “classif”

  • Predict Types: “response”, “prob”

  • Feature Types: “logical”, “integer”, “numeric”

  • Required Packages: mlr3
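Since only logical, integer, and numeric features are supported, it can help to check a task before training. The sketch below compares the feature types of the penguins task that ships with mlr3 against the list above; the task is chosen only because it contains factor columns, and the variable names are illustrative.

```r
library(mlr3)

# Feature types listed above as supported by this learner.
supported = c("logical", "integer", "numeric")

# Example task with a mix of numeric, integer and factor features.
task = tsk("penguins")

# task$feature_types is a table with the columns `id` and `type`.
types = task$feature_types
unsupported = types$id[!types$type %in% supported]

# Features that would have to be converted or dropped before training.
print(unsupported)
```

If the result is non-empty, those columns would need to be encoded as numeric values or removed from the task before calling $train().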

Parameters

Id                            Type       Default            Levels                                         Range
n_estimators                  integer    4                                                                 \([1, \infty)\)
categorical_features_indices  untyped    -                                                                 -
softmax_temperature           numeric    0.9                                                               \([0, \infty)\)
balance_probabilities         logical    FALSE              TRUE, FALSE                                    -
average_before_softmax        logical    FALSE              TRUE, FALSE                                    -
model_path                    untyped    "auto"                                                            -
device                        untyped    "auto"                                                            -
ignore_pretraining_limits     logical    FALSE              TRUE, FALSE                                    -
inference_precision           character  auto               auto, autocast                                 -
fit_mode                      character  fit_preprocessors  low_memory, fit_preprocessors, fit_with_cache  -
memory_saving_mode            untyped    "auto"                                                            -
random_state                  integer    0                                                                 \((-\infty, \infty)\)
n_jobs                        integer    -                                                                 \([1, \infty)\)

References

Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Bin S, Schirrmeister, Tibor R, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.

Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.

Author

b-zhou

Super classes

mlr3::Learner -> mlr3::LearnerClassif -> LearnerClassifTabPFN

Active bindings

marshaled

(logical(1))
Whether the learner has been marshaled.

Methods


Method new()

Creates a new instance of this R6 class.

Usage

LearnerClassifTabPFN$new()


Method marshal()

Marshal the learner's model.

Usage

LearnerClassifTabPFN$marshal(...)

Arguments

...

(any)
Additional arguments passed to marshal_model().


Method unmarshal()

Unmarshal the learner's model.

Usage

LearnerClassifTabPFN$unmarshal(...)

Arguments

...

(any)
Additional arguments passed to unmarshal_model().


Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerClassifTabPFN$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Define the Learner
learner = mlr3::lrn("classif.tabpfn")
print(learner)
#> <LearnerClassifTabPFN:classif.tabpfn>: TabPFN Classifier
#> * Model: -
#> * Parameters: n_jobs=1
#> * Packages: mlr3
#> * Predict Types:  [response], prob
#> * Feature Types: logical, integer, numeric
#> * Properties: marshal, missings, multiclass, twoclass

# Define a Task
task = mlr3::tsk("sonar")

# Create train and test set
ids = mlr3::partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

print(learner$model)
#> $fitted
#> TabPFNClassifier(n_jobs=1)
#> 
#> attr(,"class")
#> [1] "tabpfn_model"


# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()
#> classif.ce 
#>  0.1884058