
FDOTT (version 0.1.0)

ot_cost_sgn: Compute optimal transport costs for signed measures

Description

Compute the optimal transport (OT) cost between signed measures that have the same total mass.

Usage

ot_cost_sgn(mu, nu, costm, mode = c("all", "diag"))

Value

The OT cost between the vectors in mu and nu.

For mode = "all", the full \(K_1 \times K_2\) matrix of pairwise costs is returned; if mu or nu is a single vector, this matrix is returned as a vector. Setting nu = NULL means nu = mu; in that case only the lower triangular part is computed and then reflected.

If mode = "diag", then only the diagonal is returned (requiring \(K_1 = K_2\)).
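The following base-R sketch illustrates the three output shapes described above. It is not FDOTT's implementation. The function names pair_cost and pairwise_costs are hypothetical, and inputs are assumed to be row-wise matrices. pair_cost is a stand-in that assumes points 1, ..., N with cost |i - j| instead of calling transport::transport.

```r
# Hypothetical stand-in for the pairwise OT cost (assumption: points 1..N,
# cost |i - j|); it is symmetric in its two arguments
pair_cost <- function(x, y) sum(abs(cumsum(x - y)))

# Sketch of the output shapes for mode = "all", mode = "diag" and nu = NULL
# (hypothetical helper, not FDOTT's code; mu and nu are row-wise matrices)
pairwise_costs <- function(mu, nu = NULL, mode = c("all", "diag")) {
  mode <- match.arg(mode)
  if (mode == "diag") {
    if (is.null(nu)) nu <- mu
    stopifnot(nrow(mu) == nrow(nu))  # "diag" requires K1 = K2
    return(vapply(seq_len(nrow(mu)),
                  function(k) pair_cost(mu[k, ], nu[k, ]), numeric(1)))
  }
  if (is.null(nu)) {
    # nu = NULL: evaluate only the lower triangle (including the diagonal)
    K <- nrow(mu)
    out <- matrix(0, K, K)
    for (i in seq_len(K)) for (j in seq_len(i)) {
      out[i, j] <- pair_cost(mu[i, ], mu[j, ])
    }
    # reflect: out[i, j] <- out[j, i] for i < j
    out[upper.tri(out)] <- t(out)[upper.tri(out)]
    return(out)
  }
  # mode = "all": the full K1 x K2 matrix
  outer(seq_len(nrow(mu)), seq_len(nrow(nu)),
        Vectorize(function(i, j) pair_cost(mu[i, ], nu[j, ])))
}

mu <- rbind(c(1, -1, 0), c(0, 1, -1), c(-1, 0, 1))
nu <- rbind(c(0, 0, 0), c(1, 0, -1))
pairwise_costs(mu, nu)                         # 3 x 2 matrix
pairwise_costs(mu[1:2, ], nu, mode = "diag")   # vector of length 2
pairwise_costs(mu)                             # symmetric 3 x 3 matrix
```

With nu = NULL only \(K(K+1)/2\) of the \(K^2\) entries are evaluated. The reflection is valid here only because pair_cost gives the same value when its arguments are swapped.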

Arguments

mu

matrix (row-wise) or list containing \(K_1\) vectors of length \(N\).

nu

matrix (row-wise) or list containing \(K_2\) vectors of length \(N\) or NULL.
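The matrix form stores one vector per row, and the list form stores the same vectors as list elements. The sketch below shows how the two forms correspond. as_rowwise is a hypothetical helper that makes no claim about how FDOTT converts its inputs internally.

```r
# Hypothetical helper (not part of FDOTT): turn a list of equal-length
# vectors, or a single vector, into a K x N matrix with one vector per row
as_rowwise <- function(x) {
  if (is.list(x)) x <- do.call(rbind, x)
  if (is.null(dim(x))) x <- matrix(x, nrow = 1)  # a single vector is one row
  x
}

as_rowwise(list(c(1, -1), c(2, -2)))  # same as rbind(c(1, -1), c(2, -2))
as_rowwise(c(1, -1))                  # 1 x 2 matrix
```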

costm

cost matrix \(c \in \mathbb{R}^{N \times N}\).

mode

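controls which of the pairwise OT costs are computed: "all" computes the cost for every pair of a vector in mu and a vector in nu, while "diag" only computes it for pairs with the same index.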

Details

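The extended OT functional for vectors \(\mu,\,\nu \in \mathbb{R}^N\) with \(\sum_{i=1}^N \mu_i = \sum_{i=1}^N \nu_i\) is defined as $$ \mathrm{OT}^{\pm}_c(\mu, \nu) := \mathrm{OT}_c(\mu^+ + \nu^-, \, \nu^+ + \mu^-)\,, $$ where \(\mu^+ = \max(0, \mu)\) and \(\mu^- = -\min(0, \mu)\) (taken entrywise) denote the positive and negative parts of \(\mu\), and \(\mathrm{OT}_c\) is the standard OT functional. Both arguments of \(\mathrm{OT}_c\) are non-negative and have the same total mass, since \(\sum_{i=1}^N (\mu^+_i + \nu^-_i) - \sum_{i=1}^N (\nu^+_i + \mu^-_i) = \sum_{i=1}^N \mu_i - \sum_{i=1}^N \nu_i = 0\). The standard OT cost is computed with the function transport::transport. The costs can be computed in parallel by selecting a parallel backend with future::plan.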
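The base-R sketch below makes the reduction concrete. It splits two signed vectors into positive and negative parts and forms the two arguments passed to \(\mathrm{OT}_c\). Instead of transport::transport, it assumes the points are the grid 1, ..., N with cost \(c_{ij} = |i - j|\). In that one-dimensional special case, the standard OT cost between equal-mass measures \(a, b\) is \(\sum_k |\sum_{i \le k} (a_i - b_i)|\). This closed form is a swapped-in technique for illustration, not what ot_cost_sgn does. The helper names pos_part, neg_part and ot_sgn_1d are hypothetical.

```r
# Positive and negative parts, so that x = pos_part(x) - neg_part(x)
pos_part <- function(x) pmax(0, x)
neg_part <- function(x) -pmin(0, x)

# Hypothetical helper (not part of FDOTT): extended OT cost for signed vectors
# on the grid 1..N with cost |i - j| (assumption), using the 1D closed form
# sum(abs(cumsum(a - b))) for the standard OT cost between equal-mass a and b
ot_sgn_1d <- function(mu, nu) {
  stopifnot(length(mu) == length(nu),
            isTRUE(all.equal(sum(mu), sum(nu))))  # equal total mass
  a <- pos_part(mu) + neg_part(nu)  # first argument of OT_c
  b <- pos_part(nu) + neg_part(mu)  # second argument of OT_c
  sum(abs(cumsum(a - b)))           # a - b equals mu - nu
}

mu <- c(0.5, -0.2, -0.3, 0)
nu <- c(-0.1, 0.1, 0, 0)
ot_sgn_1d(mu, nu)
```

Because \(a - b = \mu - \nu\), the result only depends on the difference of the two signed measures in this special case.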

See Also

transport::transport

Examples


# enable parallel computation
if (requireNamespace("future")) {
    future::plan(future::multisession)
}

# generate random signed measures with total mass 0 (row-wise)
rsum0 <- \(K, N) {
    x <- runif(K * N) |> matrix(K, N)
    x <- sweep(x, 1, rowSums(x) / N, "-")
    x[, 1] <- x[, 1] - rowSums(x)
    x
}

K1 <- 3
K2 <- 2
N <- 4
costm <- cost_matrix_lp(1:N)

set.seed(123)
mu <- rsum0(K1, N)
nu <- rsum0(K2, N)

print(ot_cost_sgn(mu[2, ], nu[2, ], costm))

# mode = "diag" requires K1 = K2
print(ot_cost_sgn(mu[1:2, ], nu, costm, mode = "diag"))

print(ot_cost_sgn(mu, nu, costm))

# nu = NULL only works properly if costm is a semi-metric
print(ot_cost_sgn(mu, NULL, costm))
# but it requires fewer computations than
print(ot_cost_sgn(mu, mu, costm))
# reset to sequential processing so that open connections are closed
if (requireNamespace("future") && !inherits(future::plan(), "sequential")) {
    future::plan(future::sequential)
}
