if (FALSE) {
  #-------------------------------------------------------------------
  # Free-Support PW Barycenter of Multiple Gaussians
  #
  # * class 1 : samples from N((0,0), diag(c(4,1/4)))
  # * class 2 : samples from N((10,0), diag(c(1/4,4)))
  # * class 3 : the class-2 samples mapped through one random rotation
  #             that is shared by every class-3 measure
  #
  # We draw 10 empirical measures (50 points each) per class and
  # compare their barycenters under the regular and
  # Procrustes-Wasserstein (PW) geometries.
  #-------------------------------------------------------------------
  ## GENERATE DATA
  set.seed(10)
  n_measures <- 10L  # empirical measures per class
  n_points   <- 50L  # points per empirical measure

  # qr.Q() of a random 2x2 matrix can be a reflection (det = -1) rather
  # than a rotation; negating one column makes det = +1 so the map is a
  # proper rotation. It is drawn before any samples so the draws for
  # classes 1 and 2 come from the same RNG positions as before.
  random_rot <- qr.Q(qr(matrix(runif(4), ncol = 2)))
  if (det(random_rot) < 0) {
    random_rot[, 2] <- -random_rot[, 2]
  }

  # class 1: centred at the origin, wide along x, narrow along y
  input_1 <- lapply(seq_len(n_measures), function(i) {
    cbind(rnorm(n_points, sd = 2), rnorm(n_points, sd = 0.5))
  })
  # class 2: narrow along x, wide along y, shifted to mean (10, 0)
  input_2 <- lapply(seq_len(n_measures), function(i) {
    cbind(rnorm(n_points, mean = 10, sd = 0.5), rnorm(n_points, sd = 2))
  })
  # class 3: the very same class-2 point clouds, rotated. Rows are
  # points, so each point x is mapped to x %*% random_rot.
  input_3 <- lapply(input_2, function(x) x %*% random_rot)

  ## COMPUTE
  n_support <- 50  # number of support points of each barycenter
  # regular Wasserstein barycenters
  regular_1 <- rbaryGD(input_1, num_support = n_support)
  regular_2 <- rbaryGD(input_2, num_support = n_support)
  regular_3 <- rbaryGD(input_3, num_support = n_support)
  # Procrustes-Wasserstein barycenters
  pw_1 <- pwbary(input_1, num_support = n_support)
  pw_2 <- pwbary(input_2, num_support = n_support)
  pw_3 <- pwbary(input_3, num_support = n_support)

  ## VISUALIZE
  # Shared axis limits derived from every point set that is drawn, so no
  # panel clips points regardless of where the random rotation puts
  # class 3. Assumes each `$support` is a 2-column matrix of points, as
  # the plot() calls below already require.
  plotted <- rbind(
    input_1[[1]], input_2[[1]], input_3[[1]],
    regular_1$support, regular_2$support, regular_3$support,
    pw_1$support, pw_2$support, pw_3$support
  )
  lim_x <- extendrange(plotted[, 1])
  lim_y <- extendrange(plotted[, 2])

  # Draw one panel: the first point set opens the plot, the other two
  # are overlaid with their own colours.
  draw_panel <- function(sets, cols, main) {
    plot(sets[[1]], pch = 19, cex = 0.5, col = cols[1],
         main = main, xlab = "", ylab = "", xlim = lim_x, ylim = lim_y)
    points(sets[[2]], pch = 19, cex = 0.5, col = cols[2])
    points(sets[[3]], pch = 19, cex = 0.5, col = cols[3])
  }

  opar <- par(no.readonly = TRUE)
  par(mfrow = c(3, 1))
  # prototypical measure of each class
  draw_panel(list(input_1[[1]], input_2[[1]], input_3[[1]]),
             c("gray80", "gray50", "gray10"),
             "3 types of measures")
  # regular barycenters
  draw_panel(list(regular_1$support, regular_2$support, regular_3$support),
             c("blue", "cyan", "red"),
             "Regular Wasserstein barycenters")
  # PW barycenters
  draw_panel(list(pw_1$support, pw_2$support, pw_3$support),
             c("blue", "cyan", "red"),
             "Procrustes-Wasserstein barycenters")
  par(opar)
}
# Run the code above in your browser using DataLab.