
repsim (version 0.1.0)

cka: Centered Kernel Alignment

Description

Compute pairwise CKA similarities between multiple representations using a chosen kernel and estimator.
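As a rough illustration of the quantity being computed, the sketch below evaluates CKA for a single pair of data matrices as \(\mathrm{HSIC}(K, L) / \sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}\), where \(K\) and \(L\) are the Gram matrices of the two representations. It assumes an RBF kernel whose bandwidth is the median pairwise Euclidean distance, and the biased HSIC estimator \(\mathrm{tr}(KHLH)/(n-1)^2\) with centering matrix \(H\). Both choices, and the helper names rbf_gram, hsic_biased and cka_pair, are hypothetical and may not match what repsim does internally.

```r
# Hypothetical sketch of CKA for one pair of representations (not repsim code).
# Assumptions: RBF kernel with a median-distance bandwidth, biased HSIC.
rbf_gram <- function(X) {
  D <- as.matrix(dist(X))              # pairwise Euclidean distances
  sigma <- median(D[upper.tri(D)])     # median heuristic (assumed bandwidth)
  exp(-D^2 / (2 * sigma^2))
}

hsic_biased <- function(K, L) {
  n <- nrow(K)
  H <- diag(n) - matrix(1 / n, n, n)   # centering matrix
  # tr(K H L H) / (n - 1)^2, computed as an elementwise sum of centered Grams
  sum((H %*% K %*% H) * (H %*% L %*% H)) / (n - 1)^2
}

cka_pair <- function(X, Y) {
  K <- rbf_gram(X)
  L <- rbf_gram(Y)
  hsic_biased(K, L) / sqrt(hsic_biased(K, K) * hsic_biased(L, L))
}

set.seed(1)
X <- matrix(rnorm(30 * 4), 30, 4)
Q <- qr.Q(qr(matrix(rnorm(16), 4, 4)))    # random 4 x 4 orthogonal matrix
cka_pair(X, 2 * X %*% Q)                   # rotation + rescaling: equals 1
cka_pair(X, matrix(rnorm(30 * 2), 30, 2))  # unrelated data: below 1
```

Because rotating and uniformly rescaling the data leaves the median-scaled RBF Gram matrix unchanged, the first call returns 1 up to floating-point error; the two matrices may also have different numbers of columns, as in the second call.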

Usage

cka(mats, kernel_type = NULL, estimator = NULL)

Value

An \(M \times M\) matrix of CKA values.
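One possible downstream use of the returned matrix is to treat \(1 - \mathrm{CKA}\) as a dissimilarity and cluster the representations with base R's hclust(). In the sketch below, S is a hand-made 4 x 4 stand-in for cka() output; its values and the choice of average linkage are illustrative assumptions, not recommendations from repsim.

```r
# Toy stand-in for an M x M CKA matrix: reps 1-2 and 3-4 are mutually similar.
S <- matrix(0.2, 4, 4)
S[1:2, 1:2] <- 0.9
S[3:4, 3:4] <- 0.9
diag(S) <- 1

# Convert similarity to dissimilarity and cluster hierarchically.
hc <- hclust(as.dist(1 - S), method = "average")
groups <- cutree(hc, k = 2)   # cut the tree into two clusters
groups
```

Here cutree() places representations 1 and 2 in one cluster and 3 and 4 in the other.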

Arguments

mats

A list of length \(M\) whose \(k\)-th element is a data matrix of size \((n_\mathrm{samples},\, p_k)\). All matrices must have the same number of rows, and row \(i\) of every matrix must correspond to the same sample.

kernel_type

Character scalar naming the kernel. If NULL, defaults to "rbf". See repsim_kernels for the list of available kernels.

estimator

Character scalar naming the HSIC estimator. If NULL, defaults to "gretton". See repsim_hsic for the list of available estimators.
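The estimator argument changes how HSIC is estimated from the Gram matrices. The sketch below contrasts two standard estimators: the biased estimator \(\mathrm{tr}(KHLH)/(n-1)^2\), and the unbiased estimator of Song et al. (2012), which zeroes the diagonals of \(K\) and \(L\) before combining three trace terms. That "gretton" and "song" correspond to these two estimators is an assumption based on their names; "lange" is not sketched, and hsic_gretton and hsic_song are hypothetical helpers rather than repsim functions.

```r
# Hypothetical sketches of two HSIC estimators on n x n Gram matrices K, L.
# Assumed to correspond to estimator = "gretton" and "song"; not repsim code.
hsic_gretton <- function(K, L) {
  n <- nrow(K)
  H <- diag(n) - matrix(1 / n, n, n)                  # centering matrix
  sum((H %*% K %*% H) * (H %*% L %*% H)) / (n - 1)^2  # tr(KHLH) / (n-1)^2
}

hsic_song <- function(K, L) {
  n <- nrow(K)
  stopifnot(n > 3)                     # the unbiased estimator needs n >= 4
  Kt <- K; diag(Kt) <- 0               # zero the diagonals
  Lt <- L; diag(Lt) <- 0
  term1 <- sum(Kt * Lt)                # tr(Kt Lt) for symmetric inputs
  term2 <- sum(Kt) * sum(Lt) / ((n - 1) * (n - 2))
  term3 <- 2 / (n - 2) * sum(colSums(Kt) * rowSums(Lt))  # 1' Kt Lt 1
  (term1 + term2 - term3) / (n * (n - 3))
}

set.seed(2)
X <- matrix(rnorm(25 * 3), 25, 3)
Y <- X + matrix(rnorm(25 * 3, sd = 0.5), 25, 3)
K <- tcrossprod(X)   # linear-kernel Gram matrices, for simplicity
L <- tcrossprod(Y)
c(gretton = hsic_gretton(K, L), song = hsic_song(K, L))
```

Either function can be plugged into the CKA ratio. Unlike the biased version, the unbiased estimate can be slightly negative for unrelated data, so CKA values marginally below zero are not necessarily a bug.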

References

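Cristianini N, Shawe-Taylor J, Elisseeff A, Kandola J (2001). "On Kernel-Target Alignment." Advances in Neural Information Processing Systems 14.

Cortes C, Mohri M, Rostamizadeh A (2012). "Algorithms for Learning Kernels Based on Centered Alignment." Journal of Machine Learning Research 13.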

Examples

# \donttest{
# --------------------------------------------------
# Use the "iris" and "USArrests" datasets:
#   1. standardize columns so differences in measurement scale do not dominate
#   2. subsample 50 rows of iris to match the 50 rows of USArrests
#   3. add white noise to create 10 perturbed representations of each dataset
# --------------------------------------------------
# prepare the two prototype datasets (50 samples each)
set.seed(1)
X = as.matrix(scale(as.matrix(iris[sample(1:150, 50, replace=FALSE),1:4])))
Y = as.matrix(scale(as.matrix(USArrests)))
n = nrow(X)
p_X = ncol(X)
p_Y = ncol(Y)

# generate 10 of each by perturbation
mats = vector("list", length=20L)
for (i in 1:10){
  mats[[i]] = X + matrix(rnorm(n*p_X, sd=1), nrow=n)
}
for (j in 11:20){
  mats[[j]] = Y + matrix(rnorm(n*p_Y, sd=1), nrow=n)
}

# compute CKA with the default rbf kernel and three different HSIC estimators
cka1 = cka(mats, estimator="gretton")
cka2 = cka(mats, estimator="song")
cka3 = cka(mats, estimator="lange")

# visualize the three CKA matrices side by side
opar <- par(no.readonly=TRUE)
labs <- paste0("rep ",1:20)
par(mfrow=c(1,3), pty="s")

results <- list("CKA (Gretton)"=cka1, "CKA (Song)"=cka2, "CKA (Lange)"=cka3)
for (title in names(results)){
  # reverse the column order so that rep 1 appears at the top of the y-axis
  image(results[[title]][,20:1], axes=FALSE, main=title)
  axis(1, at = seq(0, 1, length.out=20), labels = labs, las = 2)
  axis(2, at = seq(0, 1, length.out=20), labels = labs[20:1], las = 2)
}
par(opar)
# }
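In the example, representations 1 to 10 are perturbations of the iris prototype and 11 to 20 are perturbations of USArrests, so a compact numeric summary of each heatmap is the mean similarity within and between these two groups. The block_means() function below is a hypothetical helper written for this illustration and is shown on a toy matrix; after running the example, it could be applied as block_means(cka1, rep(1:2, each = 10)), where one would expect the within-group mean to exceed the between-group mean.

```r
# Hypothetical helper: mean similarity within and between groups of
# representations, excluding the diagonal (self-similarity) entries.
block_means <- function(S, groups) {
  same <- outer(groups, groups, "==")   # TRUE where both reps share a group
  off_diag <- row(S) != col(S)
  c(within  = mean(S[same & off_diag]),
    between = mean(S[!same]))
}

# Toy stand-in for a CKA matrix with two groups of two representations.
S <- matrix(0.2, 4, 4)
S[1:2, 1:2] <- 0.9
S[3:4, 3:4] <- 0.9
diag(S) <- 1
block_means(S, groups = c(1, 1, 2, 2))
```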
