ggparty (version 1.0.0)

geom_node_plot: Draw plots at nodes

Description

Additional component for ggparty() that draws, inside each selected node, a ggplot of that node's data.
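As a minimal sketch that reuses the airquality tree from the Examples section, the following draws a histogram of Ozone in every node, inner and terminal, by setting ids = "all". The object name p_all, the bin count, and the width/height values are illustrative choices, not recommendations from the package. Plots at inner nodes can crowd the split-variable labels, so terminal_space, width and height may need tuning.

```r
library(ggparty)

# Conditional inference tree for ozone, as in the Examples section
airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq)

# Histogram of each node's Ozone values, drawn in all nodes (inner and terminal).
# width and height are expansion factors; values below 1 (illustrative here)
# are meant to make the node plots smaller.
p_all <- ggparty(airct, terminal_space = 0.4) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_histogram(aes(x = Ozone), bins = 10),
                               theme_bw()),
                 ids = "all",
                 width = 0.8,
                 height = 0.8)

print(p_all)
```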

Usage

geom_node_plot(plot_call = "ggplot", gglist = NULL, width = 1,
  height = 1, size = 1, ids = "terminal", scales = "fixed",
  nudge_x = 0, nudge_y = 0, shared_axis_labels = FALSE,
  shared_legend = TRUE, predict = NULL, predict_gpar = NULL,
  legend_separator = FALSE)

Arguments

plot_call

Any function that generates a ggplot2 object.

gglist

List of additional gg components that are added to each node plot. Columns of a node's data can be mapped in aesthetics. Additionally, fitted_values and residuals can be mapped if they are present in the party object passed to ggparty().

width

Expansion factor for the viewport's width.

height

Expansion factor for the viewport's height.

size

Expansion factor for the viewport's size.

ids

IDs of the nodes to plot: a numeric vector of node IDs, or one of "terminal", "inner" or "all". Defaults to "terminal".

scales

See facet_wrap().

nudge_x, nudge_y

Horizontal (nudge_x) and vertical (nudge_y) offsets by which the node plots are shifted.

shared_axis_labels

If TRUE, only one pair of axis labels is plotted in the terminal space. Only recommended if ids is "terminal" or "all".

shared_legend

If TRUE, one shared legend is plotted at the bottom of the tree.

predict

Character string naming the variable for which predictions should be plotted.

predict_gpar

Named list of arguments passed to the geom_line() call that draws the predicted values.

legend_separator

If TRUE, a line is drawn between the legend and the tree.
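A sketch of shared_legend and legend_separator, again assuming the airquality tree from the Examples section. Points are coloured by Wind, and the colour scale is given fixed limits taken from the full data. That fixed range is a design choice made here, not something geom_node_plot() requires: without it, each node's scale would be trained on that node's data alone, so the single shared legend might not match every node plot.

```r
library(ggparty)

airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq)

# Use the full-data Wind range so every node plot maps the same Wind value
# to the same colour and therefore carries an identical legend.
wind_range <- range(airq$Wind)

p_legend <- ggparty(airct, terminal_space = 0.5) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_point(aes(x = Temp, y = Ozone, col = Wind)),
    scale_colour_gradient(limits = wind_range)),
    shared_legend = TRUE,     # one legend at the bottom of the tree
    legend_separator = TRUE)  # line between legend and tree

print(p_legend)
```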

See Also

ggparty()

Examples

# NOT RUN {
library(ggparty)

airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq)

ggparty(airct, horizontal = TRUE, terminal_space = 0.6) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(
    geom_density(aes(x = Ozone))),
    shared_axis_labels = TRUE)

#############################################################

## Plot with ggparty


## Demand for economics journals data
data("Journals", package = "AER")
Journals <- transform(Journals,
                      age = 2000 - foundingyear,
                      chars = charpp * pages)

## linear regression tree (OLS)
j_tree <- lmtree(log(subs) ~ log(price/citations) | price + citations +
                   age + chars + society, data = Journals, minsize = 10, verbose = TRUE)

pred_df <- get_predictions(j_tree, ids = "terminal", newdata = function(x) {
  data.frame(
    citations = 1,
    price = exp(seq(from = min(x$`log(price/citations)`),
                    to = max(x$`log(price/citations)`),
                    length.out = 100)))
})

ggparty(j_tree, terminal_space = 0.8) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist =
                   list(aes(x = `log(price/citations)`, y = `log(subs)`),
                        geom_point(),
                        geom_line(data = pred_df,
                                  aes(x = log(price/citations),
                                      y = prediction),
                                  col = "red")))
# }
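As an alternative to computing pred_df with get_predictions(), the predict argument asks geom_node_plot() to plot predictions itself. The sketch below assumes that predict must name a regressor column of the node data. For that reason, the transformed variables are first stored under the hypothetical column names lsubs and lppc, and the tree is refitted as j_tree2 (also a hypothetical name) with the same partitioning variables. predict_gpar colours the line, mirroring col = "red" in the example above. The exact grid of values on which ggparty computes the predictions is not specified here.

```r
library(ggparty)

data("Journals", package = "AER")
# Store the transformed response and regressor as plain columns
# (hypothetical names lsubs and lppc) so the regressor can be named in predict.
Journals <- transform(Journals,
                      age = 2000 - foundingyear,
                      chars = charpp * pages,
                      lsubs = log(subs),
                      lppc = log(price / citations))

j_tree2 <- lmtree(lsubs ~ lppc | price + citations + age + chars + society,
                  data = Journals, minsize = 10)

p_pred <- ggparty(j_tree2, terminal_space = 0.8) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist = list(geom_point(aes(x = lppc, y = lsubs))),
                 predict = "lppc",                  # regressor on the x-axis
                 predict_gpar = list(col = "red"))  # passed to geom_line()

print(p_pred)
```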
