boostr (version 1.0.0)

kFoldCV: Generic k-fold Cross Validation wrapper

Description

A general abstraction of the k-fold cross validation procedure.

Usage

kFoldCV(proc, k, data, params, .rngSeed = 1234, .chunkSize = 1L, .doSEQ = FALSE)

Arguments

proc
the procedure to be k-fold cross validated. proc needs to accept data and newdata in its signature, and must return a numeric vector.
k
the number of folds.
data
a matrix or data.frame from which the folds will be created.
params
a list or data.frame. If params is a list, every combination of the entries in its cells will be used as parameters to be cross validated. If params is a data.frame, each row of arguments will be cross-validated.
.rngSeed
the seed set before randomly generating fold indices.
.chunkSize
the number of parameter combinations to be processed at once (see the help for iter in the iterators package).
.doSEQ
logical flag indicating whether cross validation should be run sequentially (TRUE) or with %dopar% (FALSE, the default).

Value

A vector whose length is equal to nrow(params) if params is a data.frame, or to the number of combinations of the elements of params if it is a list. The i-th component is the k-fold cross-validated value of proc evaluated with the parameters from the i-th combination of params.
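
The difference between the two forms of params can be sketched in base R. This is only an illustration: it assumes a list is expanded in the same way as expand.grid (which the Examples section uses to label results), and the names paramList, paramGrid and paramRows are hypothetical.

```r
# A list yields every combination of its entries
# (assumed here to match the ordering of expand.grid).
paramList <- list(k = c(1, 3, 5), cost = c(10, 50))
paramGrid <- expand.grid(paramList, stringsAsFactors = FALSE)
nrow(paramGrid)  # 3 * 2 = 6 combinations

# A data.frame is used row by row, so only the listed pairs are evaluated.
paramRows <- data.frame(k = c(1, 5), cost = c(10, 50))
nrow(paramRows)  # 2 parameter sets
```

With the list form the returned vector would have length 6; with the data.frame form it would have length 2.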

Details

This function leverages foreach and iter to perform k-fold cross validation in a distributed fashion (provided a parallel backend is registered).
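
For orientation, the procedure being abstracted can be sketched sequentially in base R. This is a minimal sketch and not the package's implementation: simpleKFold and mseOfMean are hypothetical helpers, proc is assumed to return a single number per fold, and averaging the fold results is an assumption.

```r
# Hypothetical helper: a sequential sketch of k-fold cross validation.
simpleKFold <- function(proc, k, data, ..., seed = 1234) {
  set.seed(seed)
  # randomly assign each row to one of k folds of near-equal size
  folds <- sample(rep(seq_len(k), length.out = nrow(data)))
  foldValues <- vapply(seq_len(k), function(i) {
    held <- folds == i
    # train on the other k - 1 folds, evaluate on the held-out fold
    proc(data = data[!held, , drop = FALSE],
         newdata = data[held, , drop = FALSE], ...)
  }, numeric(1))
  mean(foldValues)  # assumption: fold results are combined by averaging
}

# toy procedure: mean squared error when predicting the training mean
mseOfMean <- function(data, newdata) {
  mean((newdata$mpg - mean(data$mpg))^2)
}

cvMse <- simpleKFold(mseOfMean, k = 5, data = mtcars)
```

kFoldCV adds to this picture an outer loop over parameter combinations and the option of running both loops through foreach.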

Because the heart of this function is a pair of nested foreach loops, one should be wary of "over-parallelization": if the routine inside proc is already natively parallel, then invoking kFoldCV around proc distributes an already distributed computation. This may not yield the speed gains you would expect.

One workaround, assuming proc is parallelized using foreach, is to create a wrapper around proc that calls registerDoSEQ. For example,

proC <- function(...) {registerDoSEQ(); proc(...)}

Alternatively, you could run kFoldCV sequentially by setting .doSEQ to TRUE.
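
The wrapper idea can be tried on its own, outside kFoldCV. The sketch below assumes the foreach package is installed; innerProc and seqProc are hypothetical names standing in for a foreach-based proc and its sequential wrapper.

```r
library(foreach)

# Hypothetical stand-in for a procedure that is parallelized with foreach.
innerProc <- function(x) {
  foreach(i = 1:3, .combine = c) %dopar% (x * i)
}

# Wrapper that registers the sequential backend before calling the procedure,
# so the inner %dopar% loop runs in the current process.
seqProc <- function(...) {
  registerDoSEQ()
  innerProc(...)
}

result <- seqProc(2)
getDoParName()  # name of the backend now registered
```

Note that registerDoSEQ replaces whichever backend is registered in the process where the wrapper runs, so a parallel backend registered there earlier would need to be registered again afterwards.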

For a procedure proc <- function(data, newdata, arg1, ..., argN){...}, cross-validating a single N-tuple of arguments c(arg1, ..., argN) may be very quick. In that case, the time it takes to send proc, the data, and the appropriate combinations of params to a worker may exceed the actual computation time. One should then consider raising .chunkSize from 1 to n, where n is any integer large enough to justify passing the data to a distant node.
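
The effect of chunking on the number of tasks can be illustrated with base R. This sketch only mimics the grouping; it does not reproduce how kFoldCV or iter form chunks internally, and paramGrid, chunkSize, chunkId and chunks are hypothetical names.

```r
# 7 * 2 = 14 parameter combinations
paramGrid <- expand.grid(k = 1:7, cost = c(1, 10))
chunkSize <- 4L

# chunk id for each row: 1 1 1 1 2 2 2 2 ...
chunkId <- ceiling(seq_len(nrow(paramGrid)) / chunkSize)
chunks <- split(paramGrid, chunkId)

length(chunks)                    # 4 tasks instead of 14
vapply(chunks, nrow, integer(1))  # rows per chunk: 4 4 4 2
```

Each task then carries several parameter combinations, so the fixed cost of shipping the data to a worker is paid 4 times rather than 14.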

Examples

# simple example with k-NN where we can build our own wrapper
library(class)
data(iris)
.iris <- iris[, 5:1] # put response as first column

# make a wrapper for class::knn
f <- function(data, newdata, k) {
  preds <- knn(train = data[, -1],
               test = newdata[, -1],
               cl = data[, 1],
               k = k)
  mean(preds==newdata[, 1])
}

params <- list(k=c(1,3,5,7))

accuracy <- kFoldCV(f, 10, .iris, params, .rngSeed=407)

data.frame(expand.grid(params), accuracy=accuracy)

# look at a more complicated example:
# cross validate an svm with different kernels and different models
library(e1071)
g <- function(data, newdata, kernel, cost, gamma, formula) {
  # kernel and formula are integer codes; switch() picks the matching entry
  kern <- switch(kernel, "linear", "radial", stop("invalid kernel"))
  form <- switch(formula,
                 as.formula(Species ~ .),
                 as.formula(Species ~ Petal.Length + Petal.Width),
                 as.formula(Petal.Length ~ .),
                 stop('invalid formula'))

  svmWrapper <- function(data, newdata, kernel, cost, gamma, form) {
    svmObj <- svm(formula=form, data=data, kernel=kernel,
                  cost=cost, gamma=gamma)
    predict(svmObj, newdata)
  }
  preds <- svmWrapper(data, newdata, kernel=kern, cost=cost,
                      gamma=gamma, form=form)

  if (formula != 3) {
    mean(preds == newdata[["Species"]])
  } else {
    mean((preds - newdata[["Petal.Length"]])^2)
  }
}

params <- list(kernel=1:2, cost=c(10,50), gamma=0.01, formula=1)
accuracy <- kFoldCV(g, 10, iris, params)
data.frame(expand.grid(params), metric=accuracy)