
ggmlR (version 0.6.1)

dp_train: Data-parallel training across multiple GPUs

Description

Runs synchronous data-parallel training:

  1. make_model() is called n_gpu times to create one independent model replica per GPU (each with its own parameters).

  2. Each iteration: the current data item is forwarded through every replica in parallel; gradients are computed via backward().

  3. Gradients are averaged across all replicas (element-wise mean).

  4. One optimizer step is taken on replica 0; updated weights are then broadcast to replicas 1 … N-1 so all replicas stay in sync.
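The four steps above can be sketched in plain base R. The following is an illustration of the synchronisation scheme only, not the package internals: it uses plain matrices and plain gradient descent in place of `ag_param` tensors and Adam.

```r
# Toy replication of steps 1-4 (not the ggmlR internals).
n_gpu <- 3L

# step 1: one independent weight matrix per "replica"
replicas <- lapply(seq_len(n_gpu), function(i) matrix(0.5, 2, 2))

x  <- matrix(c(1, 2), 2, 1)   # one "data item"
lr <- 0.1

for (iter in 1:5) {
  # step 2: forward/backward per replica; here the loss is 0.5 * ||W x||^2,
  # whose gradient with respect to W is (W x) %*% t(x)
  grads <- lapply(replicas, function(W) (W %*% x) %*% t(x))

  # step 3: element-wise mean of the gradients across replicas
  g_mean <- Reduce(`+`, grads) / n_gpu

  # step 4: one optimizer step on replica 0, then broadcast to the rest
  replicas[[1]] <- replicas[[1]] - lr * g_mean
  for (i in 2:n_gpu) replicas[[i]] <- replicas[[1]]
}
```

After each pass all replicas hold identical weights, which is exactly the invariant the broadcast in step 4 maintains.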

Usage

dp_train(
  make_model,
  data,
  loss_fn = NULL,
  forward_fn = NULL,
  target_fn = NULL,
  n_gpu = NULL,
  n_iter = 10L,
  lr = 0.001,
  max_norm = Inf,
  verbose = 10L
)

Value

A list with:

params

Named list of final parameters (from replica 0).

loss_history

Numeric vector of per-iteration mean loss.

model

Replica 0 model object.

Arguments

make_model

A zero-argument function that returns a model object with at least $forward(x) and $parameters() methods. Called n_gpu times; each call must produce independent parameters.

data

A list of training samples. Each element is passed directly to forward_fn (or to model$forward() if forward_fn is NULL).

loss_fn

A function (logits, target) -> scalar ag_tensor. If NULL, forward_fn must return the loss directly.

forward_fn

Optional function (model, sample) -> logits. If NULL, the sample is passed directly as model$forward(sample).

target_fn

Optional function (sample) -> target. Used when loss_fn is not NULL to extract the target from a sample. If NULL, sample itself is used as the target.
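When samples bundle an input with a label, forward_fn and target_fn are typically used together to split each sample. A base-R illustration of that contract with a stand-in model and plain MSE (nothing below comes from the package source):

```r
# A structured sample carrying its own input and target.
sample <- list(x = matrix(1:2, 2, 1), y = matrix(c(0, 1), 2, 1))

model <- list(forward = function(x) 2 * x)           # stand-in for a replica

forward_fn <- function(model, s) model$forward(s$x)  # extract the input
target_fn  <- function(s) s$y                        # extract the target
loss_fn    <- function(out, tgt) mean((out - tgt)^2) # plain MSE stand-in

out  <- forward_fn(model, sample)
loss <- loss_fn(out, target_fn(sample))
```

With target_fn left NULL, the whole sample would be handed to loss_fn as the target, which only makes sense when each sample *is* its own target (as in the autoencoder-style example below).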

n_gpu

Number of GPU replicas (default: all available Vulkan devices, minimum 1).

n_iter

Number of training iterations (passes over data).

lr

Learning rate for the Adam optimizer (default 1e-3).

max_norm

Gradient-clipping threshold (default Inf, i.e. no clipping).
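A common formulation of such a threshold is clipping by global L2 norm: all gradients are rescaled jointly when their combined norm exceeds max_norm. This is an assumption about the rule, not code taken from the package source:

```r
# Global-norm clipping sketch (assumed behaviour of max_norm, not ggmlR code):
# rescale every gradient by max_norm / total when total exceeds max_norm.
clip_by_global_norm <- function(grads, max_norm) {
  total <- sqrt(sum(vapply(grads, function(g) sum(g^2), numeric(1))))
  if (is.finite(max_norm) && total > max_norm) {
    grads <- lapply(grads, function(g) g * (max_norm / total))
  }
  grads
}

g <- clip_by_global_norm(list(W = matrix(3, 1, 1), b = 4), max_norm = 1)
# combined norm was sqrt(9 + 16) = 5, so each gradient is scaled by 1/5
```

With max_norm = Inf the `if` branch never fires, matching the documented "no clipping" default.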

verbose

Print the loss every verbose iterations; set to FALSE to suppress output.

Details

Because all replicas live in the same R process and ag_param uses environment (reference) semantics, no IPC or NCCL is required — weight synchronisation is a simple in-place copy.
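The reference semantics this relies on can be shown in base R: writing into an environment inside a function mutates it for every holder of that environment, so "broadcasting" weights is just an assignment. (This uses plain `new.env()` stand-ins; `ag_param` itself is not shown.)

```r
# Why an in-place copy suffices: environments are references in R.
sync_weights <- function(dst, src) dst$value <- src$value  # in-place copy

rep0 <- new.env(); rep0$value <- matrix(1, 2, 2)  # "updated" replica 0
rep1 <- new.env(); rep1$value <- matrix(0, 2, 2)  # stale replica 1

sync_weights(rep1, rep0)  # mutates rep1 even though we are inside a function
rep1$value[1, 1]          # now 1: replica 1 sees replica 0's weights
```

Had the replicas been ordinary lists, the function would have modified a copy and the caller's replica would have stayed stale; the environment-backed parameters are what make the cheap synchronisation possible.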

Examples

make_model <- function() {
  W <- ag_param(matrix(rnorm(4), 2, 2))
  list(
    forward    = function(x) ag_matmul(W, x),
    parameters = function() list(W = W)
  )
}
data <- lapply(1:8, function(i) matrix(rnorm(2), 2, 1))
result <- dp_train(
  make_model = make_model,
  data       = data,
  loss_fn    = function(out, tgt) ag_mse_loss(out, tgt),
  target_fn  = function(s) s,
  n_gpu      = 1L,
  n_iter     = 10L,
  lr         = 1e-3,
  verbose    = FALSE
)
