fastrerandomize (version 0.3)

fast_distance: JAX-accelerated distance calculations

Description

Compute pairwise distances between the rows of one matrix A or two matrices A and B, using JAX-backed, JIT-compiled kernels. Supports common metrics: Euclidean, squared Euclidean, Manhattan, Chebyshev, Cosine, Minkowski (with optional feature weights), and Mahalanobis (with full or diagonal inverse covariance).

The function automatically batches computations to avoid excessive device memory use.
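
The batching idea can be sketched in base R. This is only an illustration of processing A in row blocks so that no intermediate exceeds a block-by-m matrix; it is not the package's JAX kernel, and the helper name `batched_euclidean` and its `batch_size` argument are hypothetical.

```r
# Base R sketch of row-batched Euclidean distances (not the package's JAX code).
# `batched_euclidean` and `batch_size` are hypothetical names for illustration.
batched_euclidean <- function(A, B = A, batch_size = 16L) {
  stopifnot(is.matrix(A), is.matrix(B), ncol(A) == ncol(B), nrow(A) > 0)
  n <- nrow(A)
  D <- matrix(0, n, nrow(B))
  b_sq <- rowSums(B^2)                      # squared norms of B, reused by every batch
  for (start in seq(1L, n, by = batch_size)) {
    idx <- start:min(start + batch_size - 1L, n)
    Ai <- A[idx, , drop = FALSE]
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for one block of rows
    sq <- outer(rowSums(Ai^2), b_sq, "+") - 2 * tcrossprod(Ai, B)
    D[idx, ] <- sqrt(pmax(sq, 0))           # clamp tiny negative round-off before sqrt
  }
  D
}

set.seed(1)
X <- matrix(rnorm(50 * 8), 50, 8)
D_batched <- batched_euclidean(X, batch_size = 16L)
```

Smaller blocks lower peak memory at the cost of more loop iterations; row_batch_size presumably exposes the same trade-off on the device.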

Usage

fast_distance(
  A,
  B = NULL,
  metric = c("euclidean", "sqeuclidean", "manhattan", "chebyshev", "cosine", "minkowski",
    "mahalanobis"),
  p = 2,
  weights = NULL,
  cov_inv = NULL,
  approximate_inv = TRUE,
  squared = FALSE,
  row_batch_size = NULL,
  as_dist = FALSE,
  return_type = "R",
  verbose = FALSE,
  conda_env = "fastrerandomize_env",
  conda_env_required = TRUE
)

Value

An \(n \times m\) distance matrix in the format specified by return_type. If as_dist = TRUE and B = NULL (symmetric case), returns a dist object.
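
For readers unfamiliar with the two return shapes, the following base R snippet shows how a full symmetric matrix relates to a compact dist object. It uses only stats::dist() and stats::as.dist() and assumes that as_dist = TRUE yields the same kind of object; it does not call fast_distance().

```r
# Base R illustration of the two return shapes (no JAX involved).
set.seed(1)
X <- matrix(rnorm(5 * 3), 5, 3)
D_full <- as.matrix(dist(X))     # 5 x 5 symmetric matrix with a zero diagonal
D_dist <- as.dist(D_full)        # compact 'dist' object: lower triangle only
dim(D_full)
length(D_dist)                   # n * (n - 1) / 2 stored entries
```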

Arguments

A

A numeric matrix with rows as observations and columns as features.

B

Optional numeric matrix with the same number of columns as A. If NULL, distances are computed within A (i.e., \(n \times n\)).

metric

Character; one of "euclidean", "sqeuclidean", "manhattan", "chebyshev", "cosine", "minkowski", "mahalanobis". Default is "euclidean".

p

Numeric order for Minkowski distance (must be \(>0\)). Default is 2.

weights

Optional numeric vector of length ncol(A) with nonnegative feature weights. Used for "minkowski" and "manhattan" (the latter is equivalent to Minkowski with p = 1).

cov_inv

Optional inverse covariance matrix of dimension ncol(A) x ncol(A) for Mahalanobis (ignored if approximate_inv = TRUE). If not supplied and approximate_inv = FALSE, the covariance is estimated from rbind(A, B) and inverted in JAX.

approximate_inv

Logical; if TRUE and metric = "mahalanobis", uses a diagonal inverse (reciprocal variances) for speed and robustness. Default TRUE.

squared

Logical; if TRUE, return squared distances when supported ("euclidean" and "mahalanobis"). Ignored for other metrics. Default FALSE.

row_batch_size

Optional integer; number of rows of A to process per batch. If NULL, a safe size is chosen automatically.

as_dist

Logical; if TRUE and B is NULL, return a base dist object (for symmetric metrics). Default FALSE.

return_type

Either "R" (convert to base R matrix/dist) or "jax" (return a JAX array). Default "R".

verbose

Logical; print batching progress. Default FALSE.

conda_env

Character; conda environment name used by reticulate. Default "fastrerandomize_env".

conda_env_required

Logical; whether the specified conda environment must be used. Default TRUE.
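
To make the interaction of p and weights concrete, here is a minimal base R reference for a weighted Minkowski distance. It assumes the weights enter as \(\left(\sum_k w_k |a_k - b_k|^p\right)^{1/p}\); the package may apply them differently, so treat this as a sketch for reasoning about the arguments rather than a description of the implementation. The helper `weighted_minkowski` is hypothetical.

```r
# Base R reference for a weighted Minkowski distance between rows of A and B.
# ASSUMPTION: weights enter as sum_k w_k * |a_k - b_k|^p; the package may differ.
# `weighted_minkowski` is a hypothetical helper, not part of fastrerandomize.
weighted_minkowski <- function(A, B, p = 2, w = rep(1, ncol(A))) {
  stopifnot(ncol(A) == ncol(B), length(w) == ncol(A), all(w >= 0), p > 0)
  D <- matrix(0, nrow(A), nrow(B))
  for (i in seq_len(nrow(A))) {
    # |b_j - a_i| for every row j of B, arranged as an m x d matrix
    diffs <- abs(sweep(B, 2, A[i, ], "-"))
    D[i, ] <- as.vector((diffs^p) %*% w)^(1 / p)
  }
  D
}

set.seed(1)
A <- matrix(rnorm(6 * 4), 6, 4)
B <- matrix(rnorm(5 * 4), 5, 4)
w <- runif(ncol(A))
D_m3 <- weighted_minkowski(A, B, p = 3, w = w)
D_l1 <- weighted_minkowski(A, B, p = 1, w = w)   # weighted Manhattan as the p = 1 case
```

With unit weights and p = 2 this reduces to the ordinary Euclidean distance, which gives a quick sanity check against stats::dist().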

Details

- **Mahalanobis**: with approximate_inv = TRUE, only the diagonal of the pooled covariance is used (variance stabilizer); otherwise a full inverse covariance is used.
- **Weighted distances**: supply weights (length ncol(A)) for "minkowski" and "manhattan" (the latter uses p = 1).
- **Precision and devices**: computations run in float32 and are JIT-compiled with JAX; where applicable, GPU/Metal/CPU device selection follows your existing backend.
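
The difference between the diagonal and full Mahalanobis variants can be sketched in base R. The helper `mahalanobis_ref` below is hypothetical and estimates the covariance from a single matrix X with stats::cov(); it runs in double precision, so results from the float32 JAX kernels may differ slightly.

```r
# Base R reference for the two Mahalanobis variants (not the package's JAX code).
# `mahalanobis_ref` is a hypothetical helper; the covariance is estimated from X itself.
mahalanobis_ref <- function(X, diagonal = TRUE) {
  S <- stats::cov(X)
  # Diagonal variant: keep only reciprocal variances; full variant: invert S
  S_inv <- if (diagonal) diag(1 / diag(S), nrow = ncol(X)) else solve(S)
  n <- nrow(X)
  D <- matrix(0, n, n)
  for (i in seq_len(n)) {
    diffs <- sweep(X, 2, X[i, ], "-")             # x_j - x_i for all j
    D[i, ] <- sqrt(rowSums((diffs %*% S_inv) * diffs))
  }
  D
}

set.seed(1)
X <- matrix(rnorm(50 * 8), 50, 8)
D_diag <- mahalanobis_ref(X, diagonal = TRUE)
D_full <- mahalanobis_ref(X, diagonal = FALSE)
```

The diagonal variant is equivalent to Euclidean distance on columns scaled by their standard deviations, which is why it stays well defined even when the full covariance is close to singular.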

Examples

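if (FALSE) {
# Not run by default: needs the conda environment used by reticulate
# (default "fastrerandomize_env")
library(fastrerandomize)

# Simple Euclidean within-matrix distances (returns an n x n matrix)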
X <- matrix(rnorm(50 * 8), 50, 8)
D <- fast_distance(X, metric = "euclidean")

# Cosine distance between two sets
A <- matrix(rnorm(100 * 16), 100, 16)
B <- matrix(rnorm(120 * 16), 120, 16)
Dcos <- fast_distance(A, B, metric = "cosine")

# Minkowski with p = 3 and feature weights
w <- runif(ncol(A))
Dm3 <- fast_distance(A, B, metric = "minkowski", p = 3, weights = w)

# Mahalanobis (diagonal approx, fast & robust)
Dmah_diag <- fast_distance(X, metric = "mahalanobis", approximate_inv = TRUE)

# Mahalanobis with full inverse (computed internally)
Dmah_full <- fast_distance(X, metric = "mahalanobis", approximate_inv = FALSE)

# Return a base R 'dist' object
D_dist <- fast_distance(X, metric = "euclidean", as_dist = TRUE)
}
