T4transport (version 0.1.5)

pwdist: Procrustes-Wasserstein Distance

Description

Given two empirical measures $$\mu = \sum_{m=1}^M \mu_m \delta_{X_m}\quad\textrm{and}\quad \nu = \sum_{n=1}^N \nu_n \delta_{Y_n}$$ in \(\mathbb{R}^P\), the Procrustes-Wasserstein (PW) distance is defined as $$ PW_2^2(\mu, \nu) = \min_{Q\in \mathcal{O}(P)} W_2^2(\mu, Q_\# \nu), $$ where \(\mathcal{O}(P)\) is the group of \((P\times P)\) orthogonal matrices and \(Q_\#\nu = \sum_{n=1}^N \nu_n \delta_{QY_n}\) is the pushforward of \(\nu\) under \(Q\).
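
To make the minimization over \(Q\) concrete, here is a minimal base-R sketch of the orthogonal Procrustes step for a fixed coupling. It is an illustration under stated assumptions, not necessarily the algorithm that pwdist uses internally: the helper name procrustes_step and the toy data are hypothetical, and the coupling is taken as known rather than optimized.

```r
# Hypothetical helper (not part of T4transport): for a FIXED coupling `plan`
# (M x N), find the orthogonal Q minimizing
#   sum_{m,n} plan[m,n] * ||X_m - Q Y_n||^2.
# The minimizer is Q = V U^T, where U D V^T is the SVD of t(Y) %*% t(plan) %*% X.
procrustes_step <- function(X, Y, plan) {
  A <- t(Y) %*% t(plan) %*% X   # (P x P) weighted cross-moment matrix
  s <- svd(A)
  s$v %*% t(s$u)
}

set.seed(1)
M    <- 20
X    <- matrix(rnorm(M * 2), ncol = 2)          # rows are observations
Q0   <- qr.Q(qr(matrix(rnorm(4), ncol = 2)))    # a random orthogonal matrix
Y    <- X %*% Q0                                # row n of Y is t(Q0) %*% X_n
plan <- diag(1 / M, M)                          # identity coupling, uniform weights

Q       <- procrustes_step(X, Y, plan)
aligned <- Y %*% t(Q)                           # row n is Q %*% Y_n

# The coupling is diagonal, so the transport cost only involves matched rows.
cost <- sum(diag(plan) * rowSums((X - aligned)^2))
```

In this toy setting the known correspondence lets the Procrustes step recover Q0 and drive the cost to numerical zero. In the PW problem the coupling is unknown as well, so it has to be optimized together with \(Q\).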

Usage

pwdist(X, Y, wx = NULL, wy = NULL, ...)

Value

a named list containing

distance

the computed PW distance value.

plan

an \((M\times N)\) nonnegative matrix for the optimal transport plan.

alignment

an optimal alignment matrix of size \((P\times P)\) in \(\mathcal{O}(P)\).
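
The documented shape of the returned list can be sanity-checked with base R. The sketch below uses a hypothetical helper, check_pw_output, and a hand-built list that only mimics the structure described above; it is not real pwdist output. The marginal constraints on plan are the standard definition of a transport plan and are assumed here rather than stated on this page.

```r
# Hypothetical helper (not part of T4transport): checks that a result list
# has a nonnegative plan with the given marginals and an orthogonal alignment.
check_pw_output <- function(res, wx, wy, tol = 1e-8) {
  P <- nrow(res$alignment)
  c(
    plan_nonnegative = all(res$plan >= -tol),
    row_marginals    = max(abs(rowSums(res$plan) - wx)) < tol,
    col_marginals    = max(abs(colSums(res$plan) - wy)) < tol,
    orthogonal       = max(abs(crossprod(res$alignment) - diag(P))) < tol
  )
}

# Hand-built stand-in with M = 2, N = 3, P = 2 (illustrative values only).
wx    <- c(0.5, 0.5)
wy    <- c(0.25, 0.25, 0.5)
theta <- pi / 6
fake  <- list(
  distance  = 0,
  plan      = rbind(c(0.25, 0.25, 0.00),
                    c(0.00, 0.00, 0.50)),
  alignment = rbind(c(cos(theta), -sin(theta)),
                    c(sin(theta),  cos(theta)))
)
checks <- check_pw_output(fake, wx, wy)
```

Each entry of checks is a logical flag, so all(checks) gives a single pass/fail answer.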

Arguments

X

an \((M\times P)\) matrix of row observations.

Y

an \((N\times P)\) matrix of row observations.

wx

a length-\(M\) marginal density (weight vector) that sums to \(1\). If NULL (default), uniform weights are used.

wy

a length-\(N\) marginal density (weight vector) that sums to \(1\). If NULL (default), uniform weights are used.

...

extra parameters including

maxiter

maximum number of iterations (default: 496).

abstol

stopping criterion for iterations (default: 1e-10).
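
As a hedged illustration of how these arguments fit together, the sketch below builds weight vectors that sum to 1 and passes maxiter and abstol through the dots. The data, the weights, and the particular control values are arbitrary choices for illustration. The call itself is guarded so the snippet still runs when T4transport is not installed.

```r
set.seed(2)
X <- matrix(rnorm(30 * 2), ncol = 2)   # M = 30 observations in R^2
Y <- matrix(rnorm(40 * 2), ncol = 2)   # N = 40 observations in R^2

# Non-uniform weights for X: any nonnegative vector, normalized to sum to 1.
wx <- runif(nrow(X))
wx <- wx / sum(wx)

# Explicit uniform weights for Y.
wy <- rep(1 / nrow(Y), nrow(Y))

# Guarded call: only executed when the package is available.
if (requireNamespace("T4transport", quietly = TRUE)) {
  res <- T4transport::pwdist(X, Y, wx = wx, wy = wy,
                             maxiter = 100, abstol = 1e-8)
  print(res$distance)
}
```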

References

adamo_2025_DepthLookProcrustesWassersteinT4transport

Examples

if (FALSE) {
#-------------------------------------------------------------------
#                           Description
#
# * class 1 : 50 samples from N((0,0),  I)
# * class 2 : 50 samples from N((10,0), I)
# * class 3 : the class 2 samples transformed by a random orthogonal matrix
#
#  We draw 10 empirical measures per class and compare
#  the regular Wasserstein and PW distances.
#-------------------------------------------------------------------
## GENERATE DATA
library(T4transport)  # provides wasserstein() and pwdist()
set.seed(10)

#  prepare empty lists
inputs = vector("list", length=30)

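#  generate: a random orthogonal matrix from the Q factor of a QR decomposition
random_rot = qr.Q(qr(matrix(runif(4), ncol=2)))
for (i in 1:10){
  inputs[[i]] = matrix(rnorm(50*2), ncol=2)    # class 1
}
for (j in 11:20){
  base_draw = matrix(rnorm(50*2), ncol=2)
  base_draw[,1] = base_draw[,1] + 10           # shift the mean to (10,0)

  inputs[[j]] = base_draw                      # class 2
  inputs[[j+10]] = base_draw%*%random_rot      # class 3: transformed copy
}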

## COMPUTE
#  empty arrays
dist_RW = array(0, c(30, 30))
dist_PW = array(0, c(30, 30))

#  compute pairwise distances
for (i in 1:29){
  for (j in (i+1):30){
    dist_RW[i,j] <- dist_RW[j,i] <- wasserstein(inputs[[i]], inputs[[j]])$distance
    dist_PW[i,j] <- dist_PW[j,i] <- pwdist(inputs[[i]], inputs[[j]])$distance
  }
}

## VISUALIZE
opar <- par(no.readonly=TRUE)
par(mfrow=c(1,2), pty="s")
image(dist_RW, xaxt="n", yaxt="n", main="Regular Wasserstein distance")
image(dist_PW, xaxt="n", yaxt="n", main="PW distance")
par(opar)
}