Classification BART (Bayesian Additive Regression Trees) Learner
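Bayesian Additive Regression Trees (BART) are similar to gradient boosting algorithms in that the model is a sum of many small regression trees. Unlike boosting, the trees are fitted by Bayesian MCMC (backfitting) under a prior that keeps each tree shallow.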
The classification problem is solved by 0-1 encoding of the two-class targets and setting the
decision threshold to p = 0.5 during the prediction phase.
Calls dbarts::bart() from dbarts.
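The encoding and thresholding described above can be sketched in base R. This is a minimal illustration, not the learner's internal code. The helper names `encode_binary()` and `decode_binary()` are hypothetical, and sending a probability of exactly 0.5 to the negative class is an assumption.

```r
# Hypothetical helpers illustrating 0-1 encoding and the p = 0.5 decision rule;
# they are not part of mlr3extralearners or dbarts.

# Map a two-level factor to 0/1, with 1 for the positive class
encode_binary = function(y, positive) {
  as.integer(y == positive)
}

# Turn probabilities for the positive class back into class labels.
# Assumption: prob > threshold predicts the positive class, ties go to the other class.
decode_binary = function(prob, levels, positive, threshold = 0.5) {
  negative = setdiff(levels, positive)
  factor(ifelse(prob > threshold, positive, negative), levels = levels)
}

y = factor(c("M", "R", "R", "M"))
encode_binary(y, positive = "M")
#> [1] 1 0 0 1

prob = c(0.9, 0.2, 0.49, 0.51)
decode_binary(prob, levels(y), positive = "M")
#> [1] M R R M
#> Levels: M R
```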
Meta Information
Task type: “classif”
Predict Types: “response”, “prob”
Feature Types: “integer”, “numeric”, “factor”, “ordered”
Required Packages: mlr3, mlr3extralearners, dbarts
Parameters
Id | Type | Default | Levels | Range |
ntree | integer | 200 | - | \([1, \infty)\) |
k | numeric | 2 | - | \([0, \infty)\) |
power | numeric | 2 | - | \([0, \infty)\) |
base | numeric | 0.95 | - | \([0, 1]\) |
binaryOffset | numeric | 0 | - | \((-\infty, \infty)\) |
ndpost | integer | 1000 | - | \([1, \infty)\) |
nskip | integer | 100 | - | \([0, \infty)\) |
printevery | integer | 100 | - | \([0, \infty)\) |
keepevery | integer | 1 | - | \([1, \infty)\) |
keeptrainfits | logical | TRUE | TRUE, FALSE | - |
usequants | logical | FALSE | TRUE, FALSE | - |
numcut | integer | 100 | - | \([1, \infty)\) |
printcutoffs | integer | 0 | - | \((-\infty, \infty)\) |
verbose | logical | FALSE | TRUE, FALSE | - |
nthread | integer | 1 | - | \((-\infty, \infty)\) |
keepcall | logical | TRUE | TRUE, FALSE | - |
sampleronly | logical | FALSE | TRUE, FALSE | - |
seed | integer | NA | - | \((-\infty, \infty)\) |
proposalprobs | untyped | NULL | - | - |
splitprobs | untyped | NULL | - | - |
keepsampler | logical | - | TRUE, FALSE | - |
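Hyperparameters from the table can be passed to `lrn()` when the learner is constructed. This is a minimal sketch that assumes mlr3, mlr3extralearners and dbarts are installed. The values below are illustrative choices, not tuned recommendations.

```r
library(mlr3)
library(mlr3extralearners)

# Illustrative settings; any parameter from the table can be set the same way
learner = lrn("classif.bart",
  ntree  = 50L,   # fewer trees than the default of 200
  k      = 2,     # prior scale controlling shrinkage of the leaf values
  ndpost = 500L,  # number of posterior draws
  nskip  = 100L,  # burn-in iterations discarded before collecting draws
  seed   = 42L    # random seed for the sampler
)

# Also return class probabilities, not only labels
learner$predict_type = "prob"

# Show the explicitly set hyperparameters
print(learner$param_set$values)
```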
Parameter Changes
Parameter: keeptrees
Original: FALSE
New: TRUE
Reason: Required for prediction
Parameter: offset
The parameter is removed, because only dbarts::bart2 allows an offset during training, and therefore the offset parameter in dbarts:::predict.bart is irrelevant for dbarts::bart.
Parameter: nchain, combineChains, combinechains
The parameters are removed, as parallelization of multiple models is handled by the future package.
Parameter: sigest, sigdf, sigquant, keeptres
These parameters apply to regression only.
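Because `nchain` is not exposed, independent fits are parallelized at the mlr3 level instead. The following is a sketch that assumes the future package is installed. `plan(multisession, workers = 2)` runs the folds of a 3-fold cross-validation in background R sessions. The worker count, fold count and `ntree = 50` are arbitrary choices for illustration.

```r
library(mlr3)
library(mlr3extralearners)
library(future)

# Run resampling iterations in two background R sessions
plan(multisession, workers = 2)

task = tsk("sonar")
# verbose = FALSE is set on the assumption that it silences the dbarts sampler log
learner = lrn("classif.bart", ntree = 50L, verbose = FALSE)

# Each fold fits an independent BART model; future distributes the fits
rr = resample(task, learner, rsmp("cv", folds = 3))
print(rr$aggregate(msr("classif.ce")))

# Return to sequential execution
plan(sequential)
```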
References
Sparapani, Rodney, Spanbauer, Charles, McCulloch, Robert (2021). “Nonparametric machine learning and efficient computation with Bayesian additive regression trees: the BART R package.” Journal of Statistical Software, 97, 1–66.
Chipman, Hugh A, George, Edward I, McCulloch, Robert E (2010). “BART: Bayesian additive regression trees.” The Annals of Applied Statistics, 4(1), 266–298.
See also
as.data.table(mlr_learners) for a table of available Learners in the running session (depending on the loaded packages).
Chapter in the mlr3book: https://mlr3book.mlr-org.com/basics.html#learners
mlr3learners for a selection of recommended learners.
mlr3cluster for unsupervised clustering learners.
mlr3pipelines to combine learners with pre- and postprocessing steps.
mlr3tuning for tuning of hyperparameters, mlr3tuningspaces for established default tuning spaces.
Super classes
mlr3::Learner -> mlr3::LearnerClassif -> LearnerClassifBart
Examples
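# Load the packages (classif.bart is provided by mlr3extralearners)
library(mlr3)
library(mlr3extralearners)
# Define the Learner
learner = mlr3::lrn("classif.bart")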
print(learner)
#> <LearnerClassifBart:classif.bart>: Bayesian Additive Regression Trees
#> * Model: -
#> * Parameters: list()
#> * Packages: mlr3, mlr3extralearners, dbarts
#> * Predict Types: [response], prob
#> * Feature Types: integer, numeric, factor, ordered
#> * Properties: twoclass, weights
# Define a Task
task = mlr3::tsk("sonar")
# Create train and test set
ids = mlr3::partition(task)
# Train the learner on the training ids
learner$train(task, row_ids = ids$train)
#>
#> Running BART with binary y
#>
#> number of trees: 200
#> number of chains: 1, number of threads 1
#> tree thinning rate: 1
#> Prior:
#> k prior fixed to 2.000000
#> power and base for tree prior: 2.000000 0.950000
#> use quantiles for rule cut points: false
#> proposal probabilities: birth/death 0.50, swap 0.10, change 0.40; birth 0.50
#> data:
#> number of training observations: 139
#> number of test observations: 0
#> number of explanatory variables: 60
#>
#> Cutoff rules c in x<=c vs x>c
#> Number of cutoffs: (var: number of possible c):
#> (1: 100) (2: 100) (3: 100) (4: 100) (5: 100)
#> (6: 100) (7: 100) (8: 100) (9: 100) (10: 100)
#> (11: 100) (12: 100) (13: 100) (14: 100) (15: 100)
#> (16: 100) (17: 100) (18: 100) (19: 100) (20: 100)
#> (21: 100) (22: 100) (23: 100) (24: 100) (25: 100)
#> (26: 100) (27: 100) (28: 100) (29: 100) (30: 100)
#> (31: 100) (32: 100) (33: 100) (34: 100) (35: 100)
#> (36: 100) (37: 100) (38: 100) (39: 100) (40: 100)
#> (41: 100) (42: 100) (43: 100) (44: 100) (45: 100)
#> (46: 100) (47: 100) (48: 100) (49: 100) (50: 100)
#> (51: 100) (52: 100) (53: 100) (54: 100) (55: 100)
#> (56: 100) (57: 100) (58: 100) (59: 100) (60: 100)
#>
#> offsets:
#> reg : 0.00 0.00 0.00 0.00 0.00
#> Running mcmc loop:
#> iteration: 100 (of 1000)
#> iteration: 200 (of 1000)
#> iteration: 300 (of 1000)
#> iteration: 400 (of 1000)
#> iteration: 500 (of 1000)
#> iteration: 600 (of 1000)
#> iteration: 700 (of 1000)
#> iteration: 800 (of 1000)
#> iteration: 900 (of 1000)
#> iteration: 1000 (of 1000)
#> total seconds in loop: 0.521236
#>
#> Tree sizes, last iteration:
#> [1] 2 2 2 4 2 2 3 2 3 3 3 2 2 3 2 2 4 3
#> 2 2 2 2 2 3 2 2 2 2 2 2 2 2 1 2 4 2 3 2
#> 5 1 2 2 3 2 3 3 2 2 1 2 4 3 2 4 2 5 2 2
#> 2 2 1 1 2 4 3 2 2 2 2 2 2 3 2 2 2 2 2 1
#> 4 2 2 2 2 3 2 2 2 2 2 2 3 2 2 2 2 2 2 2
#> 1 2 2 3 3 2 2 3 5 2 2 4 2 2 3 1 5 2 3 2
#> 3 2 2 2 3 1 2 2 2 2 2 5 2 2 4 3 2 2 2 1
#> 3 2 2 5 3 3 3 1 2 2 2 3 3 3 3 3 2 4 2 2
#> 2 3 2 2 3 2 4 2 3 2 2 2 2 2 1 3 3 2 1 2
#> 2 2 2 2 2 3 3 2 2 2 3 3 4 2 1 3 1 4 2 2
#> 2 3
#>
#> Variable Usage, last iteration (var:count):
#> (1: 3) (2: 5) (3: 7) (4: 8) (5: 4)
#> (6: 8) (7: 3) (8: 10) (9: 5) (10: 5)
#> (11: 3) (12: 3) (13: 5) (14: 12) (15: 5)
#> (16: 4) (17: 5) (18: 3) (19: 1) (20: 5)
#> (21: 5) (22: 5) (23: 6) (24: 3) (25: 2)
#> (26: 5) (27: 8) (28: 3) (29: 4) (30: 7)
#> (31: 5) (32: 1) (33: 3) (34: 3) (35: 2)
#> (36: 4) (37: 4) (38: 2) (39: 4) (40: 4)
#> (41: 6) (42: 7) (43: 3) (44: 3) (45: 2)
#> (46: 6) (47: 4) (48: 2) (49: 4) (50: 7)
#> (51: 6) (52: 2) (53: 4) (54: 5) (55: 4)
#> (56: 4) (57: 1) (58: 5) (59: 8) (60: 6)
#>
#> DONE BART
#>
print(learner$model)
#>
#> Call:
#> dbarts::bart(x.train = x_train, y.train = y_train, keeptrees = TRUE)
#>
# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)
# Score the predictions
predictions$score()
#> classif.ce
#> 0.173913
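The example above scores hard labels only. Because the learner also supports the "prob" predict type, a sketch along the following lines would score the probabilities and move the decision threshold away from the default 0.5. The threshold of 0.3 and the choice of measures are illustrative assumptions.

```r
library(mlr3)
library(mlr3extralearners)

task = tsk("sonar")
ids = partition(task)

# Request class probabilities in addition to labels
learner = lrn("classif.bart", predict_type = "prob", verbose = FALSE)
learner$train(task, row_ids = ids$train)
predictions = learner$predict(task, row_ids = ids$test)

# Predicted class probabilities for the first test rows
head(predictions$prob)

# AUC uses the probabilities directly; classification error uses the predicted labels
predictions$score(msrs(c("classif.auc", "classif.ce")))

# Lower the threshold for the positive class to 0.3 and rescore
predictions$set_threshold(0.3)
predictions$score(msr("classif.ce"))
```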