# NOT RUN {
library(cuml)
library(MASS)
library(magrittr)
library(purrr)
# Fix the RNG state so the sampled points (and therefore the predictions
# printed below) are reproducible from run to run.
set.seed(seed = 0L)
# Means of the three Gaussian clusters; each center is a 2-element
# coordinate vector.
centers <- list(
  c(3, 3),
  c(-3, -3),
  c(-3, 3)
)
# Sample `cluster_sz` points around every center in the global `centers`
# list and stack them into one matrix.
#
# Each cluster is drawn with MASS::mvrnorm() using an identity covariance
# whose size matches the center's length, so centers of any dimension work.
#
# @param cluster_sz Number of points to draw per center.
# @return A numeric matrix with `cluster_sz * length(centers)` rows and one
#   column per coordinate. Rows are grouped by center in list order (all
#   points of the first center, then the second, ...).
gen_pts <- function(cluster_sz) {
  pts <- lapply(centers, function(mu) {
    # diag(length(mu)) equals matrix(c(1, 0, 0, 1), nrow = 2) for 2-D
    # centers, but also stays conformable for other dimensions.
    mvrnorm(cluster_sz, mu = mu, Sigma = diag(length(mu)))
  })
  # Splice the per-cluster samples into a single rbind() call; the result is
  # already a matrix, so no further coercion is needed.
  do.call(rbind, pts)
}
# Build class labels laid out like the rows of gen_pts(): `cluster_sz`
# copies of 1, then `cluster_sz` copies of 2, and so on, one run per element
# of the global `centers` list.
#
# @param cluster_sz Number of labels per center.
# @return A factor of length `cluster_sz * length(centers)`; its levels are
#   the center indices as strings ("1", "2", ...) when cluster_sz > 0.
gen_labels <- function(cluster_sz) {
  # rep(each =) always yields a flat integer vector. sapply() returned a
  # matrix, a vector, or (for cluster_sz == 0) a list depending on the input.
  factor(rep(seq_along(centers), each = cluster_sz))
}
# Training set: `sample_cluster_sz` points per center, each labelled with the
# index of the center it was drawn from.
sample_cluster_sz <- 1000
sample_pts <- cbind(
  gen_pts(sample_cluster_sz) %>% as.data.frame(),
  label = gen_labels(sample_cluster_sz)
)
# Fit a KNN classifier using every non-label column as a feature.
model <- cuml_knn(
  label ~ ., sample_pts,
  algo = "ivfflat", metric = "euclidean"
)
# Fresh points drawn from the same clusters, used as unseen inputs.
test_cluster_sz <- 10
test_pts <- gen_pts(test_cluster_sz) %>% as.data.frame()
predictions <- predict(model, test_pts)
# Print as many rows as there are test points instead of a hard-coded 30, so
# the output stays complete if test_cluster_sz or `centers` changes.
# NOTE(review): assumes `predictions` has one row per test point.
print(predictions, n = nrow(test_pts))
# }
# Run the code above in your browser using DataLab