# \donttest{
# Example: penalized logistic regression with brulee on the `cells` data.
# A model is fit first from a predictor matrix, then from a recipe; the
# recipe-based model is evaluated on a held-out set with ROC AUC.
#
# Only run when the torch backend and the optional packages are available.
# `&&` (not `&`) because `if` needs a single TRUE/FALSE; it also
# short-circuits, so the package check is skipped when torch is unavailable.
if (torch::torch_is_installed() &&
    rlang::is_installed(c("recipes", "yardstick", "modeldata"))) {
  library(recipes)
  library(yardstick)

  ## ---------------------------------------------------------------------------
  # Few epochs keep the example fast; increase `epochs` to get better results.
  data(cells, package = "modeldata")
  # Drop `case`; the train/test split is made by random sampling below.
  cells$case <- NULL
  set.seed(122)
  # 1000 random rows for training; the remaining rows form the test set.
  # seq_len() gives the same vector as 1:nrow() here, so the draw is unchanged.
  in_train <- sample(seq_len(nrow(cells)), 1000)
  cells_train <- cells[in_train, ]
  cells_test <- cells[-in_train, ]

  # Using matrices: two predictors as a numeric matrix plus the outcome vector.
  # Everything here runs inside `if { }`, where only the block's final value
  # is auto-printed, so intermediate results are printed explicitly.
  set.seed(1)
  fit_mat <- brulee_logistic_reg(
    x = as.matrix(cells_train[, c("fiber_width_ch_1", "width_ch_1")]),
    y = cells_train$class,
    penalty = 0.10,
    epochs = 3
  )
  print(fit_mat)

  # Using a recipe: Yeo-Johnson transform all numeric predictors (to reduce
  # skewness), normalize them, then reduce them to 10 principal components.
  cells_rec <-
    recipe(class ~ ., data = cells_train) %>%
    step_YeoJohnson(all_numeric_predictors()) %>%
    step_normalize(all_numeric_predictors()) %>%
    step_pca(all_numeric_predictors(), num_comp = 10)

  set.seed(2)
  fit <- brulee_logistic_reg(
    cells_rec,
    data = cells_train,
    penalty = 0.01,
    epochs = 5
  )
  print(fit)
  # A ggplot is only drawn when printed; inside `if { }` that must be explicit.
  print(autoplot(fit))

  # Test-set ROC AUC computed from the predicted probability column `.pred_PS`.
  predict(fit, cells_test, type = "prob") %>%
    bind_cols(cells_test) %>%
    roc_auc(class, .pred_PS) %>%
    print()
}
# }
# Run the code above in your browser using DataLab