# \donttest{
# Dimensions shared by every example below.
d_model <- 64L # embedding width: number of rows in each input matrix
n_heads <- 8L  # number of attention heads

# Build a d_model x n_col matrix of standard-normal draws and wrap it
# with ag_tensor().
rand_tokens <- function(n_col) {
  ag_tensor(matrix(rnorm(d_model * n_col), nrow = d_model, ncol = n_col))
}

mha <- ag_multihead_attention(d_model, n_heads)

# Self-attention ----
x <- rand_tokens(10L)  # [d_model = 64, seq_len = 10]
out <- mha$forward(x)  # [64, 10]

# Cross-attention ----
# Queries come from `q` (10 positions); `kv` (15 positions) is passed
# for both remaining arguments. This assumes the signature is
# forward(query, key, value); confirm against the package docs.
q <- rand_tokens(10L)
kv <- rand_tokens(15L)
out <- mha$forward(q, kv, kv)

# Causal (GPT-style) ----
out <- mha$forward(x, causal_mask = TRUE)
# }
Run the code above in your browser using DataLab