if (FALSE) {
  # Example: compare BART predictions on new data with the predictions obtained
  # by applying the projection weights from get_projection_weights() to the
  # training responses. The whole example is guarded by `if (FALSE)` so it is
  # never executed when the file is sourced; run the lines interactively.

  # JVM settings for bartMachine's Java backend.
  # NOTE(review): these presumably have to be set before the package (and
  # therefore the JVM) is loaded -- keep this line first.
  options(java.parameters = c(
    "-Xmx20g",
    "--add-modules=jdk.incubator.vector",
    "-XX:+UseZGC"
  ))
  pacman::p_load(bartMachine, tidyverse)

  seed <- 1984
  set.seed(seed)

  # Training data: y is x plus a small amount of Gaussian noise.
  n <- 100
  x <- rnorm(n, 0, 1)
  sigma <- 0.1
  y <- x + rnorm(n, 0, sigma)

  # Fit the BART model.
  num_trees <- 200
  num_iterations_after_burn_in <- 1000
  bart_mod <- bartMachine(
    data.frame(x = x), y,
    flush_indices_to_save_RAM = FALSE,
    num_trees = num_trees,
    num_iterations_after_burn_in = num_iterations_after_burn_in,
    seed = seed
  )
  bart_mod

  # Fresh test data generated the same way as the training data.
  n_star <- 100
  x_star <- rnorm(n_star)
  y_star <- as.numeric(x_star + rnorm(n_star, 0, sigma))

  # Ordinary BART predictions on the test data.
  yhat_star_bart <- predict(bart_mod, data.frame(x = x_star))

  # Projection weights for the test rows with default settings; multiplying
  # them by the training responses gives the projection-based predictions.
  Hstar <- get_projection_weights(bart_mod, data.frame(x = x_star))
  rowSums(Hstar) # inspect the per-row weight totals
  yhat_star_projection <- as.numeric(Hstar %*% y)

  # Projection predictions against BART predictions, with the identity line
  # for reference. The data-frame column is named `yhat_star_bart` so that
  # aes() maps to the column itself rather than falling back to the variable
  # of the same name in the calling environment.
  ggplot(data.frame(
    yhat_star_bart = yhat_star_bart,
    yhat_star_projection = yhat_star_projection,
    y_star = y_star
  )) +
    geom_point(
      aes(x = yhat_star_bart, y = yhat_star_projection),
      colour = "green"
    ) +
    geom_abline(slope = 1, intercept = 0)

  # Same comparison with `regression_kludge = TRUE`.
  # NOTE(review): see the get_projection_weights() documentation for what the
  # kludge does to the weights; it is not visible from this example.
  Hstar <- get_projection_weights(
    bart_mod,
    data.frame(x = x_star),
    regression_kludge = TRUE
  )
  rowSums(Hstar) # inspect the per-row weight totals
  yhat_star_projection <- as.numeric(Hstar %*% y)

  ggplot(data.frame(
    yhat_star_bart = yhat_star_bart,
    yhat_star_projection = yhat_star_projection,
    y_star = y_star
  )) +
    geom_point(
      aes(x = yhat_star_bart, y = yhat_star_projection),
      colour = "green"
    ) +
    geom_abline(slope = 1, intercept = 0)
}
# Run the code above in your browser using DataLab