if (FALSE) {
  #-------------------------------------------------------------------
  #                           Description
  #
  # * class 1 : samples from N((0,0), diag(c(4,1/4)))
  # * class 2 : samples from N((10,0), diag(c(1/4,4)))
  # * class 3 : samples from N((10,0), diag(c(1/4,4))) randomly rotated
  #
  # We draw 10 empirical measures (50 points in R^2 each) from each
  # class and compare the regular Wasserstein and PW distances.
  # Class 3 is an exact rotation (about the origin) of class 2, so a
  # rotation-invariant distance should not separate those two classes.
  # NOTE(review): assumes pwdist aligns measures by an orthogonal map
  # -- confirm against ?pwdist.
  #-------------------------------------------------------------------
  ## GENERATE DATA
  set.seed(10)
  n_measures <- 10               # empirical measures per class
  n_points   <- 50               # observations per empirical measure
  n_total    <- 3 * n_measures   # three classes in total

  # Random orthogonal matrix. qr.Q() may return a reflection (det = -1);
  # flip one column in that case so class 3 really is a *rotation* of
  # class 2. This does not consume extra random numbers.
  random_rot <- qr.Q(qr(matrix(runif(4), ncol = 2)))
  if (det(random_rot) < 0) {
    random_rot[, 2] <- -random_rot[, 2]
  }

  # Z %*% diag(s) with Z ~ N(0, I) has covariance diag(s^2), so the
  # standard deviations below give the covariances in the description.
  inputs <- vector("list", length = n_total)
  for (i in seq_len(n_measures)) {
    # class 1: mean (0, 0), covariance diag(4, 1/4)
    inputs[[i]] <- matrix(rnorm(n_points * 2), ncol = 2) %*% diag(c(2, 1/2))
  }
  for (j in seq_len(n_measures)) {
    # class 2: mean (10, 0), covariance diag(1/4, 4)
    base_draw <- matrix(rnorm(n_points * 2), ncol = 2) %*% diag(c(1/2, 2))
    base_draw[, 1] <- base_draw[, 1] + 10
    inputs[[n_measures + j]] <- base_draw
    # class 3: the very same sample, rotated about the origin
    inputs[[2 * n_measures + j]] <- base_draw %*% random_rot
  }

  ## COMPUTE
  # symmetric pairwise distance matrices with a zero diagonal
  dist_RW <- matrix(0, nrow = n_total, ncol = n_total)
  dist_PW <- matrix(0, nrow = n_total, ncol = n_total)
  for (i in seq_len(n_total - 1)) {
    for (j in (i + 1):n_total) {
      dist_RW[i, j] <- dist_RW[j, i] <- wasserstein(inputs[[i]], inputs[[j]])$distance
      dist_PW[i, j] <- dist_PW[j, i] <- pwdist(inputs[[i]], inputs[[j]])$distance
    }
  }

  ## VISUALIZE
  # save graphics settings so the side-by-side layout is undone afterwards
  opar <- par(no.readonly = TRUE)
  par(mfrow = c(1, 2), pty = "s")
  image(dist_RW, xaxt = "n", yaxt = "n", main = "Regular Wasserstein distance")
  image(dist_PW, xaxt = "n", yaxt = "n", main = "PW distance")
  par(opar)
}
# Run the code above in your browser using DataLab