# \donttest{
make_model <- function() {
  # Build a minimal single-layer linear model for the dp_train() example.
  #
  # Returns a list with two closures that share one trainable parameter:
  #   forward(x)   - computes ag_matmul(weights, x)
  #   parameters() - returns list(W = weights), the parameters to optimise
  #
  # The parameter is initialised from rnorm(4) reshaped to 2 x 2. A fresh
  # parameter is created on every call, so each caller gets its own replica.
  weights <- ag_param(matrix(rnorm(4), nrow = 2, ncol = 2))

  forward <- function(x) ag_matmul(weights, x)
  parameters <- function() list(W = weights)

  list(forward = forward, parameters = parameters)
}
# Training set: 8 independent random 2 x 1 column vectors.
data <- replicate(8, matrix(rnorm(2), 2, 1), simplify = FALSE)

# Train with dp_train() on one GPU for 10 iterations. target_fn returns each
# sample unchanged. NOTE(review): this presumably makes the example learn an
# identity map, assuming target_fn receives the raw sample. Confirm against
# the dp_train() documentation.
result <- dp_train(
  make_model = make_model,
  data = data,
  target_fn = function(s) s,
  loss_fn = function(out, tgt) ag_mse_loss(out, tgt),
  lr = 1e-3,
  n_iter = 10L,
  n_gpu = 1L,
  verbose = FALSE
)
# }
# Run the code above in your browser using DataLab.