## Gu and Wahba 4 univariate term example ##
# Generate a data sample for the response variable
# y and the covariates x0, x1 and x2; include a noise predictor x3
# Fix the random seed so that the simulated sample is reproducible
set.seed(123)
# Number of observations to simulate
N <- 400
# Noise-free mean of the response: the sum of three univariate components,
# one per covariate. All operations are elementwise, so the arguments may be
# vectors; the result has one value per (x0, x1, x2) triple.
f_x0x1x2 <- function(x0, x1, x2) {
  # Sinusoidal component in x0
  sine_term <- 2 * sin(pi * x0)
  # Exponential component in x1
  exp_term <- exp(2 * x1)
  # Polynomial component in x2 (sum of two products of powers of x2 and 1 - x2)
  poly_term <- 0.2 * x2^11 * (10 * (1 - x2))^6 +
    10 * (10 * x2)^3 * (1 - x2)^10
  sine_term + exp_term + poly_term
}
# Draw the four covariates, in this order, from Uniform(0, 1) (the runif
# defaults). x3 is not passed to the mean function below, so it acts as the
# noise predictor.
x0 <- runif(N)
x1 <- runif(N)
x2 <- runif(N)
x3 <- runif(N)
# Noise-free mean of y evaluated at the sampled covariates
f <- f_x0x1x2(x0, x1, x2)
# Response: the mean plus Normal noise with standard deviation 0.2
y <- rnorm(n = N, mean = f, sd = 0.2)
# Gather response and covariates; column names are taken from the variables
data <- data.frame(y, x0, x1, x2, x3)
# Fit a GeDSgam model with one f() term per covariate, including the noise
# predictor x3.
# NOTE(review): NGeDSgam is not defined in this script; presumably it comes
# from the GeDS package, which must be attached beforehand -- confirm.
Gmodgam <- NGeDSgam(y ~ f(x0) + f(x1) + f(x2) + f(x3), data = data)
# Check that the sum of the individual base-learner predictions equals the final
# model prediction. Each check prints the difference rounded to 12 decimals, so
# a vector of zeros means the decomposition holds.
# predN always holds the contribution of the base learner f(xN).
# NOTE(review): n is presumably the spline order of the fit used for prediction
# (2 = linear, 3 = quadratic, 4 = cubic) -- confirm against the GeDS docs.

# n = 2: the mean of the full prediction is added to the base-learner
# contributions before comparing.
# NOTE(review): presumably because the n = 2 contributions are centred --
# confirm against the GeDS docs.
pred_full <- predict(Gmodgam, n = 2, newdata = data)
pred0 <- predict(Gmodgam, n = 2, newdata = data, base_learner = "f(x0)")
pred1 <- predict(Gmodgam, n = 2, newdata = data, base_learner = "f(x1)")
pred2 <- predict(Gmodgam, n = 2, newdata = data, base_learner = "f(x2)")
pred3 <- predict(Gmodgam, n = 2, newdata = data, base_learner = "f(x3)")
round(pred_full - (mean(pred_full) + pred0 + pred1 + pred2 + pred3), 12)

# n = 3: the full prediction is compared with the plain sum of contributions
pred0 <- predict(Gmodgam, n = 3, newdata = data, base_learner = "f(x0)")
pred1 <- predict(Gmodgam, n = 3, newdata = data, base_learner = "f(x1)")
pred2 <- predict(Gmodgam, n = 3, newdata = data, base_learner = "f(x2)")
pred3 <- predict(Gmodgam, n = 3, newdata = data, base_learner = "f(x3)")
round(predict(Gmodgam, n = 3, newdata = data) - (pred0 + pred1 + pred2 + pred3), 12)

# n = 4: same comparison as for n = 3
pred0 <- predict(Gmodgam, n = 4, newdata = data, base_learner = "f(x0)")
pred1 <- predict(Gmodgam, n = 4, newdata = data, base_learner = "f(x1)")
pred2 <- predict(Gmodgam, n = 4, newdata = data, base_learner = "f(x2)")
pred3 <- predict(Gmodgam, n = 4, newdata = data, base_learner = "f(x3)")
round(predict(Gmodgam, n = 4, newdata = data) - (pred0 + pred1 + pred2 + pred3), 12)
# Run the code above in your browser using DataLab