
loo (version 2.0.0)

kfold-helpers: Helper functions for K-fold cross-validation

Description

These functions can be used to generate indexes for use with K-fold cross-validation.

Usage

kfold_split_random(K = 10, N = NULL)

kfold_split_balanced(K = 10, x = NULL)

kfold_split_stratified(K = 10, x = NULL)

Arguments

K

The number of folds to use.

N

The number of observations in the data.

x

A discrete variable of length N; it is coerced to a factor. For kfold_split_balanced, x should be a binary variable. For kfold_split_stratified, x should be a grouping variable with at least K levels.

Value

An integer vector of length N where each element is an index in 1:K.
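The returned vector is meant to drive a K-fold loop: for each k in 1:K, the observations with index k are held out and the rest are used for fitting. The sketch below shows that pattern in base R. It makes some assumptions for illustration only: the built-in mtcars data, a linear model mpg ~ wt, and a hand-rolled random assignment fold_ids that stands in for what a call such as kfold_split_random(K = 5, N = nrow(mtcars)) would return.

```r
# Stand-in for a fold assignment (assumption): any integer vector in 1:K works.
set.seed(1)
K <- 5
N <- nrow(mtcars)
fold_ids <- sample(rep_len(seq_len(K), N))

# For each fold k, fit on the other K - 1 folds and predict the held-out fold.
sq_err <- numeric(N)
for (k in seq_len(K)) {
  test <- fold_ids == k
  fit <- lm(mpg ~ wt, data = mtcars[!test, ])        # illustrative model only
  pred <- predict(fit, newdata = mtcars[test, ])
  sq_err[test] <- (mtcars$mpg[test] - pred)^2
}

# Every observation is held out exactly once, so sq_err is fully filled in.
cv_mse <- mean(sq_err)
cv_mse
```

Because each element of the index vector names exactly one fold, every observation contributes exactly one out-of-sample prediction.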

Details

kfold_split_random splits the data at random into K groups of equal size (or roughly equal size when N is not divisible by K).
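One simple way to get a random split with near-equal fold sizes is to repeat the labels 1:K out to length N and shuffle them. This is a minimal base-R sketch of the idea, not loo's source code; the function name split_random_sketch is hypothetical.

```r
# Hypothetical sketch (not loo's implementation): random near-equal folds.
split_random_sketch <- function(K = 10, N) {
  stopifnot(K >= 2, N >= K)
  # rep_len() cycles 1:K out to length N, so each label appears
  # floor(N / K) or ceiling(N / K) times; sample() then shuffles them.
  sample(rep_len(seq_len(K), N))
}

set.seed(123)
ids <- split_random_sketch(K = 5, N = 23)
table(ids)
```

With N = 23 and K = 5, three folds receive 5 observations and two receive 4, so fold sizes never differ by more than one.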

For a binary variable x that has many more 0s than 1s (or vice versa), kfold_split_balanced first splits the data by the value of x, runs kfold_split_random within each of the two groups, and then recombines the indexes returned by the two calls. This helps ensure that observations in the less common category of x are spread more evenly across the folds.
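The split-then-recombine strategy can be sketched in base R as follows. This is an illustration under stated assumptions, not loo's code: split_balanced_sketch is a hypothetical name, and the within-group step uses the same shuffle-of-repeated-labels idea as a random split. The example builds x deterministically with 12 ones and 188 zeros so the per-fold counts are easy to check.

```r
# Hypothetical sketch (not loo's implementation) of the balanced strategy.
split_balanced_sketch <- function(K = 10, x) {
  x <- as.factor(x)
  stopifnot(nlevels(x) == 2)
  ids <- integer(length(x))
  for (lev in levels(x)) {
    in_lev <- which(x == lev)
    # Random, near-equal fold assignment within this group only.
    ids[in_lev] <- sample(rep_len(seq_len(K), length(in_lev)))
  }
  ids  # recombined: one fold index per original observation
}

set.seed(2)
x <- sample(rep(c(0, 1), times = c(188, 12)))  # rare category: 1
ids <- split_balanced_sketch(K = 4, x = x)
table(ids[x == 1])  # the 12 rare observations are split 3 per fold
table(ids[x == 0])
```

A purely random split of the same data could easily leave one fold with no 1s at all; splitting within each value of x rules that out.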

For a grouping variable x, kfold_split_stratified places all observations that share the same group/level of x together in the same fold. Which groups/levels go into which fold (relevant when there are more groups than folds) is chosen at random.
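Group-wise assignment can be sketched by randomly dealing the levels of x (rather than individual observations) to folds and then looking up each observation's fold through its level. The function split_grouped_sketch is hypothetical and is not loo's implementation; the grouping variable mirrors the one in the Examples section.

```r
# Hypothetical sketch (not loo's implementation): keep each group in one fold.
split_grouped_sketch <- function(K = 10, x) {
  x <- as.factor(x)
  n_groups <- nlevels(x)
  stopifnot(n_groups >= K)
  # Random, near-equal assignment of levels (not observations) to folds.
  fold_of_level <- sample(rep_len(seq_len(K), n_groups))
  # Each observation inherits the fold of its level via the factor code.
  fold_of_level[as.integer(x)]
}

set.seed(42)
grp <- gl(n = 50, k = 15, labels = state.name)
ids <- split_grouped_sketch(K = 9, x = grp)
table(ids)
```

This balances the number of groups per fold (here 5 or 6 of the 50 states), so fold sizes in observations are only balanced when the groups themselves have similar sizes.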

Examples

library(loo)
kfold_split_random(K = 5, N = 20)

x <- sample(c(0, 1), size = 200, replace = TRUE, prob = c(0.05, 0.95))
table(x)
ids <- kfold_split_balanced(K = 5, x = x)
table(ids[x == 0])
table(ids[x == 1])

grp <- gl(n = 50, k = 15, labels = state.name)
length(grp)
head(table(grp))

ids_10 <- kfold_split_stratified(K = 10, x = grp)
(tab_10 <- table(grp, ids_10))
print(colSums(tab_10))
all.equal(sum(colSums(tab_10)), length(grp))

ids_9 <- kfold_split_stratified(K = 9, x = grp)
tab_9 <- table(grp, ids_9)
print(colSums(tab_9))
all.equal(sum(colSums(tab_9)), length(grp))

