TabPFN Classification Learner
mlr_learners_classif.tabpfn.Rd
Foundation model for tabular data.
Uses reticulate to interface with the tabpfn Python package.
Installation
While the Python dependencies are handled via reticulate::py_require(), you can manually specify a virtual environment by calling reticulate::use_virtualenv() prior to calling the $train() function. In this virtual environment, the tabpfn package and its dependencies must be installed.
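For example, a minimal sketch of selecting a prepared virtual environment before training (the environment name "r-tabpfn" is an assumption, not a requirement of the learner):

# assumed: a virtual environment "r-tabpfn" with the tabpfn Python package installed
reticulate::use_virtualenv("r-tabpfn", required = TRUE)
learner = mlr3::lrn("classif.tabpfn")
learner$train(mlr3::tsk("sonar"))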
Saving a Learner
In order to save a LearnerClassifTabPFN for later usage, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly. After loading a marshaled LearnerClassifTabPFN back into R, you then need to call $unmarshal() to transform it into a usable state.
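A minimal sketch of this workflow (the file name "tabpfn_learner.rds" is illustrative):

learner = mlr3::lrn("classif.tabpfn")
learner$train(mlr3::tsk("sonar"))
# marshal the wrapped Python model before serializing the learner
learner$marshal()
saveRDS(learner, "tabpfn_learner.rds")
# later, e.g. in a new R session
learner = readRDS("tabpfn_learner.rds")
learner$unmarshal()
predictions = learner$predict(mlr3::tsk("sonar"))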
Initial parameter values
n_jobs is initialized to 1 to avoid threading conflicts with the future package.
Custom mlr3 parameters
categorical_features_indices uses R indexing (1-based) instead of zero-based Python indexing.
device must be a string. If set to "auto", the behavior is the same as in the original implementation. Otherwise, the string is passed as an argument to torch.device() to create a device.
inference_precision must be "auto" or "autocast". Passing a torch.dtype is currently not supported.
inference_config is currently not supported.
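For illustration, a sketch of setting these parameters when constructing the learner (the chosen values are arbitrary; parameter ids follow the table below):

learner = mlr3::lrn(
  "classif.tabpfn",
  categorical_features_indices = c(1, 3),  # 1-based positions, as in R
  device = "cpu",                          # string passed to torch.device()
  inference_precision = "autocast"
)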
Meta Information
Task type: “classif”
Predict Types: “response”, “prob”
Feature Types: “logical”, “integer”, “numeric”
Required Packages: mlr3
Parameters
Id | Type | Default | Levels | Range |
n_estimators | integer | 4 | - | \([1, \infty)\) |
categorical_features_indices | untyped | - | - | - |
softmax_temperature | numeric | 0.9 | - | \([0, \infty)\) |
balance_probabilities | logical | FALSE | TRUE, FALSE | - |
average_before_softmax | logical | FALSE | TRUE, FALSE | - |
model_path | untyped | "auto" | - | - |
device | untyped | "auto" | - | - |
ignore_pretraining_limits | logical | FALSE | TRUE, FALSE | - |
inference_precision | character | auto | auto, autocast | - |
fit_mode | character | fit_preprocessors | low_memory, fit_preprocessors, fit_with_cache | - |
memory_saving_mode | untyped | "auto" | - | - |
random_state | integer | 0 | - | \((-\infty, \infty)\) |
n_jobs | integer | - | - | \([1, \infty)\) |
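Parameters can also be inspected and updated on an existing learner via its paradox::ParamSet, for example (values are arbitrary):

learner = mlr3::lrn("classif.tabpfn")
learner$param_set$set_values(
  n_estimators = 8,
  softmax_temperature = 1.0,
  n_jobs = 2  # overrides the initial value of 1
)
learner$param_set$values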
References
Hollmann, Noah, Müller, Samuel, Purucker, Lennart, Krishnakumar, Arjun, Körfer, Max, Hoo, Shi Bin, Schirrmeister, Robin Tibor, Hutter, Frank (2025). “Accurate predictions on small data with a tabular foundation model.” Nature. doi:10.1038/s41586-024-08328-6, https://www.nature.com/articles/s41586-024-08328-6.
Hollmann, Noah, Müller, Samuel, Eggensperger, Katharina, Hutter, Frank (2023). “TabPFN: A transformer that solves small tabular classification problems in a second.” In International Conference on Learning Representations 2023.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner
-> mlr3::LearnerClassif
-> LearnerClassifTabPFN
Methods
Method marshal()
Marshal the learner's model.
Arguments
...
(any)
Additional arguments passed to marshal_model().
Method unmarshal()
Unmarshal the learner's model.
Arguments
...
(any)
Additional arguments passed to unmarshal_model().
Examples
# Define the Learner
learner = mlr3::lrn("classif.tabpfn")
print(learner)
#> <LearnerClassifTabPFN:classif.tabpfn>: TabPFN Classifier
#> * Model: -
#> * Parameters: n_jobs=1
#> * Packages: mlr3
#> * Predict Types: [response], prob
#> * Feature Types: logical, integer, numeric
#> * Properties: marshal, missings, multiclass, twoclass
# Define a Task
task = mlr3::tsk("sonar")
# Create train and test set
ids = mlr3::partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
print(learner$model)
#> $fitted
#> TabPFNClassifier(n_jobs=1)
#>
#> attr(,"class")
#> [1] "tabpfn_model"
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> classif.ce
#> 0.1884058