# Attach the parttree package used throughout these examples.
library(parttree)
# Ask data.table to use two threads for the rest of the session.
data.table::setDTthreads(2)
#
## Trees fitted with rpart
library(rpart)
rp <- rpart(Kyphosis ~ Start + Age, data = kyphosis)

# A parttree object is a data frame carrying additional attributes.
# The outer parentheses print the object as it is assigned.
(rp_pt <- parttree(rp))
# Look at the "parttree" attribute on its own.
attr(rp_pt, "parttree")

# Default plot.
plot(rp_pt)

# Dropping the (recursive) partition borders helps to emphasise overall fit.
plot(rp_pt, border = NA)
# Further customisation: extra arguments are passed on to (tiny)plot.
plot(
  rp_pt,
  border  = NA,                                      # hide partition borders
  pch     = 16,                                      # solid point symbol
  alpha   = 0.6,                                     # semi-transparent points
  grid    = TRUE,                                    # draw a background grid
  palette = "classic",                               # alternative colour palette
  xlab    = "Topmost vertebra operated on",          # x-axis title
  ylab    = "Patient age (months)",                  # y-axis title
  main    = "Tree predictions: Kyphosis recurrence"  # plot title
)
#
## Conditional inference trees from partykit
library(partykit)
ct <- ctree(Species ~ Petal.Length + Petal.Width, data = iris)
ct_pt <- parttree(ct)
plot(
  ct_pt,
  pch     = 19,
  palette = "okabe",
  main    = "ctree predictions: iris species"
)

## An rpart tree converted to a party object works as well
## (`rp` is the rpart fit created earlier in this script).
rp2 <- as.party(rp)
parttree(rp2)
#
## Various front-end frameworks are also supported, e.g.
# tidymodels (via parsnip)
# install.packages("parsnip")
library(parsnip)
# Specify an rpart classification tree, fit it on two iris predictors,
# convert the fit to a parttree object, and plot it.
decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Species ~ Petal.Length + Petal.Width, data = iris) |>
  parttree() |>
  plot(main = "This time brought to you via parsnip...")
# mlr3 (NB: use `keep_model = TRUE` for mlr3 learners)
# install.packages("mlr3")
library(mlr3)
task_iris <- TaskClassif$new("iris", iris, target = "Species")
# Restrict the task to the two petal features. `$select()` modifies the task
# in place, whereas `$formula(rhs = ...)` only constructs and returns a
# formula and leaves the task's feature set untouched.
task_iris$select(c("Petal.Length", "Petal.Width"))
# keep_model = TRUE retains the fitted rpart model inside the learner, which
# parttree() needs in order to extract the partitions. NB!
fit_iris <- lrn("classif.rpart", keep_model = TRUE)
fit_iris$train(task_iris)
plot(parttree(fit_iris), main = "... and now mlr3")
# Run the code above in your browser using DataLab