# \donttest{
if (torch::torch_is_installed()) {
  # Standard architecture object for train_nn(): default multilayer perceptron
  std_arch <- nn_arch(nn_name = "mlp_model")

  # GRU architecture spec: hidden recurrent layers take named constructor
  # arguments, while the final output layer is a linear layer built from the
  # positional (.in, .out) sizes.
  gru_arch <- nn_arch(
    nn_name = "GRU",
    nn_layer = "torch::nn_gru",
    # Per-layer formula: `.in`/`.out` are the layer sizes and `.is_output`
    # flags the final layer (which uses the plain linear constructor below)
    layer_arg_fn = ~ if (.is_output) {
      list(.in, .out)
    } else {
      list(input_size = .in, hidden_size = .out, batch_first = TRUE)
    },
    out_nn_layer = "torch::nn_linear",
    # nn_gru modules return list(output, hidden state); keep only the output
    forward_extract = ~ .[[1]],
    # Select the last time step (dim 2 is the sequence axis; torch for R
    # uses 1-based indexing) before feeding the output layer
    before_output_transform = ~ .[, .$size(2), ],
    # Insert a sequence dimension so 2-D tabular input suits the GRU
    input_transform = ~ .$unsqueeze(2)
  )

  # Custom layer architecture (module is resolved from the calling environment)
  custom_linear <- torch::nn_module(
    "CustomLinear",
    initialize = function(in_features, out_features, bias = TRUE) {
      self$layer <- torch::nn_linear(in_features, out_features, bias = bias)
    },
    forward = function(x) self$layer(x)
  )
  custom_arch <- nn_arch(
    nn_name = "CustomMLP",
    nn_layer = ~ custom_linear
  )

  # Fit a small GRU on iris: predict Sepal.Length from the other three
  # measurement columns (iris[, 1:4] drops the Species factor)
  model <- train_nn(
    Sepal.Length ~ .,
    data = iris[, 1:4],
    hidden_neurons = c(64, 32),
    activations = "relu",
    epochs = 50,
    architecture = gru_arch
  )
}
# }
# Run the code above in your browser using DataLab