treeheatr (version 0.2.3)

draw_tree: Draws the conditional decision tree.

Description

Draws the conditional decision tree returned by partykit::ctree(), using the ggparty geoms geom_edge, geom_edge_label and geom_node_label.

Usage

draw_tree(
  dat,
  fit,
  term_dat,
  layout,
  target_cols = NULL,
  title = NULL,
  tree_space_top = 0.05,
  tree_space_bottom = 0.05,
  print_eval = FALSE,
  metrics = NULL,
  x_eval = 0,
  y_eval = 0.9,
  task = c("classification", "regression"),
  par_node_vars = list(
    label.padding = unit(0.15, "lines"),
    line_list = list(aes(label = splitvar)),
    line_gpar = list(list(size = 9)),
    ids = "inner"
  ),
  terminal_vars = list(label.padding = unit(0.25, "lines"), size = 3, col = "white"),
  edge_vars = list(color = "grey70", linewidth = 0.5),
  edge_text_vars = list(
    color = "grey30",
    size = 3,
    mapping = aes(label = paste(breaks_label, "*NA"))
  )
)

Value

A ggplot2 grob object of the decision tree.

Arguments

dat

Dataframe with samples from original dataset ordered according to the clustering within each leaf node.

fit

party object, e.g., as output from partykit::ctree()

term_dat

Dataframe for terminal nodes, must include these columns: id, x, y and y_hat.

layout

Dataframe of layout of all nodes, must include these columns: id, x, y and y_hat.

target_cols

Character vector of hex color values, one per target level. Defaults to viridis option B.

title

Character string for plot title.

tree_space_top

Numeric value passed to expand to set the top margin of the tree.

tree_space_bottom

Numeric value passed to expand to set the bottom margin of the tree.

print_eval

Logical. If TRUE, print evaluation of the tree performance.

metrics

A set of metric functions used to evaluate the decision tree. Defaults to common metrics for classification or regression problems. Can be defined with `yardstick::metric_set`.

x_eval

Numeric value indicating x position to print performance statistics.

y_eval

Numeric value indicating y position to print performance statistics.

task

Character string indicating the type of problem, either 'classification' (categorical outcome) or 'regression' (continuous outcome).

par_node_vars

Named list containing arguments to be passed to the `geom_node_label()` call for non-terminal nodes.

terminal_vars

Named list containing arguments to be passed to the `geom_node_label()` call for terminal nodes.

edge_vars

Named list containing arguments to be passed to the `geom_edge()` call for tree edges.

edge_text_vars

Named list containing arguments to be passed to the `geom_edge_label()` call for tree edge annotations.
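
The four `*_vars` arguments are ordinary named lists, so a list supplied by the caller takes the place of the default shown in Usage. The sketch below restyles the edges and the terminal-node labels. It assumes the treeheatr package and its bundled `penguins` data are available, and it restates every element of each list in case defaults are not merged back in. The colors, sizes and title are illustrative choices, not recommendations from the package.

```r
library(treeheatr)

# Fit the tree and collect the pieces that draw_tree() expects.
x <- compute_tree(penguins, target_lab = "species")

# Illustrative styling values (assumptions, not package defaults).
p <- draw_tree(
  x$dat, x$fit, x$term_dat, x$layout,
  title = "Penguin species",
  # Thicker, darker edges than the default grey70 / 0.5.
  edge_vars = list(color = "grey40", linewidth = 1),
  # Slightly larger terminal labels; all three default fields are restated.
  terminal_vars = list(
    label.padding = grid::unit(0.4, "lines"),
    size = 4,
    col = "white"
  )
)

p
```

Using `grid::unit()` with an explicit namespace keeps the call independent of whether grid is attached in the session.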

Examples

library(treeheatr)

# Fit the tree on the penguins data, then draw it from the returned pieces.
x <- compute_tree(penguins, target_lab = "species")
draw_tree(x$dat, x$fit, x$term_dat, x$layout)
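
To print performance statistics on the plot, `print_eval` can be combined with a custom metric set. This is a minimal sketch that assumes the yardstick package is installed and that `draw_tree()` hands the supplied metric set both the predicted class and the class probabilities. That behavior is not spelled out on this page, so the choice of `accuracy` and `roc_auc` here is an assumption.

```r
library(treeheatr)

# A custom metric set: one class metric and one class-probability metric.
my_metrics <- yardstick::metric_set(yardstick::accuracy, yardstick::roc_auc)

x <- compute_tree(penguins, target_lab = "species")

# print_eval = TRUE adds the evaluation text at (x_eval, y_eval).
p_eval <- draw_tree(
  x$dat, x$fit, x$term_dat, x$layout,
  task = "classification",
  print_eval = TRUE,
  metrics = my_metrics,
  x_eval = 0,
  y_eval = 0.9
)

p_eval
```

Leaving `metrics = NULL` falls back to the package's default metrics for the chosen `task`.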