if (FALSE) { # rlang::is_installed("tabnet")
  # Example: fit a tabnet regression model through parsnip and hand the fitted
  # object to butcher(). The enclosing `if (FALSE)` keeps the example from
  # running; the trailing comment on that line records the intended guard.

  # Load libraries (startup messages and warnings are silenced)
  suppressWarnings(suppressMessages(library(parsnip)))
  suppressWarnings(suppressMessages(library(rsample)))
  suppressWarnings(suppressMessages(library(tabnet)))

  # Load data: put 9/10 of mtcars in the training set.
  # NOTE(review): no set.seed() call is visible here, so the split is
  # presumably not reproducible between runs -- confirm if that matters.
  split <- initial_split(mtcars, prop = 9/10)
  car_train <- training(split)

  # Only fit in an interactive session that has the torch backend installed.
  # `&&` is the scalar, short-circuiting operator intended for `if`
  # conditions: torch_is_installed() is skipped when not interactive.
  if (interactive() && torch::torch_is_installed()) {
    # Seed torch's RNG before fitting
    torch::torch_manual_seed(1)

    # Create model and fit
    mtcar_fit <- tabnet::tabnet() |>
      set_mode("regression") |>
      set_engine("torch") |>
      fit(mpg ~ ., data = car_train)

    # NOTE(review): butcher() is not attached by the library() calls above;
    # this assumes the butcher package is already on the search path.
    out <- butcher(mtcar_fit, verbose = TRUE)
  }
}
# Run the code above in your browser using DataLab