# Examples run only when both the torch backend and the modeldata package
# (which supplies the example datasets) are available, so they are skipped
# rather than erroring on machines missing either dependency.
if (torch::torch_is_installed() &&
    requireNamespace("modeldata", quietly = TRUE)) {
  # Regression using the formula specification: Sale_Price is predicted
  # from every other column of the ames data.
  data("ames", package = "modeldata")
  fit <- tabnet_fit(Sale_Price ~ ., data = ames, epochs = 1)

  # Classification using the data-frame specification: predictors and
  # outcome are passed as separate arguments.
  data("attrition", package = "modeldata")
  # Use a logical mask rather than -which(): when no name matches,
  # -which() returns -integer(0) and selects zero columns, silently
  # dropping every predictor. drop = FALSE keeps a data frame.
  attrition_x <- attrition[, names(attrition) != "Attrition", drop = FALSE]
  fit <- tabnet_fit(attrition_x, attrition$Attrition, epochs = 1)
}
# Run the code above in your browser using DataLab.