# \donttest{
# Example: fit a (penalized) linear regression with brulee on the Ames
# housing data, first from a predictor matrix and then from a recipe.
# Everything is wrapped in an `if` so the example is skipped when torch or
# the suggested packages (recipes, yardstick, modeldata) are unavailable.
#
# NOTE: inside an `if` block only the block's final value is auto-printed,
# so model fits and plots are printed explicitly below.
if (torch::torch_is_installed() &&
    rlang::is_installed(c("recipes", "yardstick", "modeldata"))) {
  ## ---------------------------------------------------------------------------
  library(recipes)
  library(yardstick)
  library(ggplot2)

  data(ames, package = "modeldata")

  # Model the sale price on the log10 scale.
  ames$Sale_Price <- log10(ames$Sale_Price)

  # Random 2000-row training set; all remaining rows form the test set.
  set.seed(122)
  in_train <- sample(seq_len(nrow(ames)), 2000)
  ames_train <- ames[in_train, ]
  ames_test <- ames[-in_train, ]

  # Using matrices
  set.seed(1)
  print(
    brulee_linear_reg(
      x = as.matrix(ames_train[, c("Longitude", "Latitude")]),
      y = ames_train$Sale_Price,
      penalty = 0.10, epochs = 1, batch_size = 64
    )
  )

  # Using recipe
  ames_rec <-
    recipe(Sale_Price ~ Bldg_Type + Neighborhood + Year_Built + Gr_Liv_Area +
             Full_Bath + Year_Sold + Lot_Area + Central_Air + Longitude +
             Latitude,
           data = ames_train) %>%
    # Transform some highly skewed predictors
    step_BoxCox(Lot_Area, Gr_Liv_Area) %>%
    # Lump rarely occurring neighborhoods (< 5% of rows) into "other"
    step_other(Neighborhood, threshold = 0.05) %>%
    # Encode categorical predictors as one-hot indicator columns.
    step_dummy(all_nominal_predictors(), one_hot = TRUE) %>%
    # Add an interaction effect:
    step_interact(~ starts_with("Central_Air"):Year_Built) %>%
    # Drop zero-variance columns, then center and scale numeric predictors.
    step_zv(all_predictors()) %>%
    step_normalize(all_numeric_predictors())

  set.seed(2)
  fit <- brulee_linear_reg(ames_rec, data = ames_train,
                           epochs = 5, batch_size = 32)
  print(fit)
  print(autoplot(fit))

  # Predict the held-out rows once and reuse them for the plot and metric.
  test_results <- predict(fit, ames_test) %>%
    bind_cols(ames_test)

  # Observed vs. predicted; points on the green line are perfect predictions.
  print(
    ggplot(test_results, aes(x = .pred, y = Sale_Price)) +
      geom_abline(col = "green") +
      geom_point(alpha = .3) +
      lims(x = c(4, 6), y = c(4, 6)) +
      coord_fixed(ratio = 1)
  )

  # Test-set root mean squared error (on the log10 scale).
  print(rmse(test_results, Sale_Price, .pred))
}
# }
# Run the code above in your browser using DataLab.