if (FALSE) {
  #-------------------------------------------------------------------
  # Description
  #
  # * class 1 : iris dataset (columns 1-4) with Gaussian perturbations
  # * class 2 : class 1 multiplied by a random orthogonal matrix in R^4
  # * class 3 : samples from the standard normal N(0, I_4) in R^4
  #
  # We draw 10 empirical measures from each class and compare the
  # regular Wasserstein and Gromov-Wasserstein (GW) distances. It is
  # expected that the GW distance between class 1 and class 2 is
  # negligible, while the regular Wasserstein distance is large.
  # For simplicity, the cardinality of every measure is limited to 20.
  #
  # NOTE: `wasserstein()` and `gwdist()` are not defined in this
  # snippet; they must already be available (e.g. through a prior
  # `library()` call) before running it.
  #-------------------------------------------------------------------

  ## GENERATE DATA
  set.seed(10)

  # problem sizes, named once instead of being repeated as literals
  n_per_class <- 10L                 # empirical measures per class
  n_points    <- 20L                 # support points per measure
  n_dims      <- 4L                  # ambient dimension (iris columns 1-4)
  n_total     <- 3L * n_per_class    # total number of measures

  # prepare empty list holding one (n_points x n_dims) matrix per measure
  inputs <- vector("list", length = n_total)

  # generate class 1 and 2
  iris_rows <- sample(seq_len(nrow(iris)), n_points)
  iris_mat  <- as.matrix(iris[iris_rows, seq_len(n_dims)])
  for (i in seq_len(n_per_class)) {
    # class 1: perturb the sampled iris rows with standard normal noise
    noise <- matrix(rnorm(n_points * n_dims), ncol = n_dims)
    inputs[[i]] <- iris_mat + noise

    # class 2: the Q factor of a random square matrix is orthogonal, so
    # right-multiplying by it preserves all pairwise Euclidean distances
    q_mat <- qr.Q(qr(matrix(runif(n_dims * n_dims), ncol = n_dims)))
    inputs[[i + n_per_class]] <- inputs[[i]] %*% q_mat
  }

  # generate class 3
  for (j in seq(2L * n_per_class + 1L, n_total)) {
    inputs[[j]] <- matrix(rnorm(n_points * n_dims), ncol = n_dims)
  }

  ## COMPUTE
  # pairwise-distance objects depend on a single measure only, so compute
  # them once per measure rather than once per (i, j) pair
  input_dists <- lapply(inputs, stats::dist)

  # empty arrays
  dist_RW <- array(0, c(n_total, n_total))
  dist_GW <- array(0, c(n_total, n_total))

  # compute pairwise distances; both are symmetric, so fill the upper
  # triangle and mirror each value into the lower triangle
  for (i in seq_len(n_total - 1L)) {
    for (j in seq(i + 1L, n_total)) {
      dist_RW[i, j] <- dist_RW[j, i] <-
        wasserstein(inputs[[i]], inputs[[j]])$distance
      dist_GW[i, j] <- dist_GW[j, i] <-
        gwdist(input_dists[[i]], input_dists[[j]])$distance
    }
  }

  ## VISUALIZE
  # save the graphical parameters so they can be restored afterwards
  opar <- par(no.readonly = TRUE)
  par(mfrow = c(1, 2), pty = "s")
  image(dist_RW, xaxt = "n", yaxt = "n", main = "Regular Wasserstein distance")
  image(dist_GW, xaxt = "n", yaxt = "n", main = "Gromov-Wasserstein distance")
  par(opar)
}
# Run the code above in your browser using DataLab