# kfold-helpers

##### Helper functions for K-fold cross-validation

These functions can be used to generate indexes for use with K-fold cross-validation.

##### Usage

`kfold_split_random(K = 10, N = NULL)`

`kfold_split_balanced(K = 10, x = NULL)`

`kfold_split_stratified(K = 10, x = NULL)`

##### Arguments

- K
The number of folds to use.

- N
The number of observations in the data.

- x
A discrete variable of length `N`. Will be coerced to `factor`. For `kfold_split_balanced`, `x` should be a binary variable. For `kfold_split_stratified`, `x` should be a grouping variable with at least `K` levels.

##### Details

`kfold_split_random` splits the data into `K` groups of equal size (or roughly equal size).
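The random split can be pictured with a short base R sketch. This is only an illustration of the idea, not the package's source: `sketch_split_random` is a hypothetical name, and the approach (recycle the fold labels `1:K` to length `N`, then shuffle) is an assumption about one simple way to get near-equal fold sizes.

```r
# Hypothetical helper (not part of loo): assign N observations to K folds
# so that fold sizes differ by at most one.
sketch_split_random <- function(K, N) {
  stopifnot(K >= 2, K <= N)
  # Recycle the labels 1..K until there are N of them ...
  folds <- rep(seq_len(K), length.out = N)
  # ... then put them in a random order
  folds[sample.int(N)]
}

set.seed(1)
ids <- sketch_split_random(K = 5, N = 22)
table(ids)  # fold sizes; 22 is not a multiple of 5, so sizes are 4 or 5
```

When `N` is a multiple of `K` every fold has exactly `N / K` members; otherwise some folds receive one extra observation.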

For a binary variable `x` that has many more `0`s than `1`s (or vice-versa), `kfold_split_balanced` first splits the data by value of `x`, applies `kfold_split_random` within each of the two groups, and then recombines the indexes returned from the two calls to `kfold_split_random`. This helps ensure that the observations in the less common category of `x` are more evenly represented across the folds.
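The split-then-recombine procedure described above might look roughly like the following in base R. `sketch_split_balanced` is a hypothetical name used for illustration, and the within-group split is a simple shuffle of recycled fold labels, which is an assumption rather than the package's actual code.

```r
# Hypothetical helper (not part of loo): split each level of a binary
# variable into K folds separately, then recombine into one index vector.
sketch_split_balanced <- function(K, x) {
  x <- as.factor(x)
  stopifnot(nlevels(x) == 2)
  ids <- integer(length(x))
  for (lev in levels(x)) {
    in_group <- which(x == lev)
    n_group <- length(in_group)
    # Near-equal random split within this level only
    folds <- rep(seq_len(K), length.out = n_group)
    ids[in_group] <- folds[sample.int(n_group)]
  }
  ids
}

set.seed(2)
x <- rep(c(0, 1), times = c(10, 190))  # 10 rare cases, 190 common cases
ids <- sketch_split_balanced(K = 5, x = x)
table(ids[x == 0])  # the 10 rare cases are spread 2 per fold
table(ids[x == 1])  # the 190 common cases are spread 38 per fold
```

A single random split over all 200 observations could easily leave one fold with no rare cases at all; splitting within each level rules that out whenever the rare level has at least `K` observations.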

For a grouping variable `x`, `kfold_split_stratified` places all observations in `x` from the same group/level together in the same fold. The selection of which groups/levels go into which fold (relevant when there are more groups than folds) is randomized.
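To make the grouping behaviour concrete, here is a minimal base R sketch. `sketch_split_grouped` is a hypothetical name, and assigning one randomly ordered fold label per level is an assumption about how such a split could be done, not a description of the package internals.

```r
# Hypothetical helper (not part of loo): give every level of x a single
# fold label, so all observations in a group land in the same fold.
sketch_split_grouped <- function(K, x) {
  x <- as.factor(x)
  n_lev <- nlevels(x)
  stopifnot(n_lev >= K)
  # One fold label per level, recycled over 1..K and randomly ordered
  fold_of_level <- rep(seq_len(K), length.out = n_lev)
  fold_of_level <- fold_of_level[sample.int(n_lev)]
  # Look up each observation's fold through its level's integer code
  fold_of_level[as.integer(x)]
}

set.seed(3)
grp <- gl(n = 6, k = 3)  # 6 groups with 3 observations each
ids <- sketch_split_grouped(K = 4, x = grp)
tab <- table(grp, ids)
tab  # each row (group) has its 3 observations in exactly one column (fold)
```

With 6 groups and 4 folds, two folds hold two groups and two folds hold one; which groups share a fold changes with the random seed.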

##### Value

An integer vector of length `N` where each element is an index in `1:K`.
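As an illustration of how such an index vector is typically consumed, the sketch below loops over the folds, holds out the observations whose index equals `k`, and predicts them from a model fit to the rest. Everything here is an assumption made for the example: the fold vector is built with base R as a stand-in for the helpers, and the `lm` fit on the built-in `mtcars` data is just a convenient placeholder model.

```r
set.seed(123)
N <- nrow(mtcars)  # 32 observations
K <- 4

# Stand-in for a fold index vector: length N, values in 1:K
ids <- rep(seq_len(K), length.out = N)[sample.int(N)]

pred <- numeric(N)
for (k in seq_len(K)) {
  held_out <- which(ids == k)
  # Fit on all folds except k, then predict the held-out fold
  fit <- lm(mpg ~ wt, data = mtcars[-held_out, ])
  pred[held_out] <- predict(fit, newdata = mtcars[held_out, ])
}

# Cross-validated root mean squared error
rmse <- sqrt(mean((mtcars$mpg - pred)^2))
rmse
```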

##### Examples

```
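library(loo)

# Random split: 20 observations into 5 folds
kfold_split_random(K = 5, N = 20)

# Balanced split: binary x with roughly 5% zeros
x <- sample(c(0, 1), size = 200, replace = TRUE, prob = c(0.05, 0.95))
table(x)
ids <- kfold_split_balanced(K = 5, x = x)
table(ids[x == 0])  # fold counts for the rare category
table(ids[x == 1])  # fold counts for the common category

# Grouping variable: 50 groups (states) with 15 observations each
grp <- gl(n = 50, k = 15, labels = state.name)
length(grp)
head(table(grp))

# K = 10: all observations from a state share one fold
ids_10 <- kfold_split_stratified(K = 10, x = grp)
(tab_10 <- table(grp, ids_10))
print(colSums(tab_10))
all.equal(sum(colSums(tab_10)), length(grp))

# K = 9: 50 groups do not divide evenly into 9 folds
ids_9 <- kfold_split_stratified(K = 9, x = grp)
tab_9 <- table(grp, ids_9)
print(colSums(tab_9))
all.equal(sum(colSums(tab_9)), length(grp))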
```

*Documentation reproduced from package loo, version 2.0.0, License: GPL (>= 3)*