The nest_cv function applies V-fold cross-validation to every nested data.frame or data.table stored in a list column of a data.table. It uses the rsample package's vfold_cv function to create the splits, so that predictive models can be trained and evaluated separately on each nested dataset.
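Because the splitting itself is delegated to rsample::vfold_cv, it can help to see what that function returns for a single, non-nested data frame. The sketch below uses only rsample and the built-in iris data; the fold count and the seed are arbitrary illustrative choices.

```r
library(rsample)

set.seed(42)  # fold assignment is random; fix the seed for reproducibility

# 5-fold cross-validation on one ordinary (non-nested) data frame
folds <- vfold_cv(iris, v = 5)
folds  # a tibble with one row per fold: `splits` (rsplit objects) and `id`

# Each split partitions the 150 rows into an analysis (training) set
# and an assessment (validation) set
first_split <- folds$splits[[1]]
nrow(analysis(first_split))    # 120
nrow(assessment(first_split))  # 30
```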
nest_cv(
  nest_dt,
  v = 10,
  repeats = 1,
  strata = NULL,
  breaks = 4,
  pool = 0.1,
  ...
)
Returns a data.table containing the cross-validation splits for each nested dataset. It includes:
- The original non-nested columns from nest_dt.
- splits: The cross-validation split objects returned by rsample::vfold_cv.
- train: The training data for each split.
- validate: The validation data for each split.
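The train and validate columns hold one data frame per split, so downstream modelling typically walks them in parallel. As a sketch of that pattern, the code below builds a stand-in table with the same three columns directly from rsample. It does not call nest_cv, and it assumes that train and validate correspond to rsample::analysis and rsample::assessment of each split; the object name cv_dt, the mtcars data, the lm() formula and the rmse column are illustrative choices only.

```r
library(data.table)
library(rsample)

set.seed(1)

# Hypothetical stand-in for a result with splits, train and validate columns.
# Built from rsample directly; this is NOT output produced by nest_cv().
cv_dt <- as.data.table(vfold_cv(mtcars, v = 4))
cv_dt[, train := list(lapply(splits, analysis))]        # list column of data frames
cv_dt[, validate := list(lapply(splits, assessment))]   # list column of data frames

# Fit on each training set and score on the matching validation set
cv_dt[, rmse := mapply(function(tr, va) {
  fit <- lm(mpg ~ wt + hp, data = tr)
  sqrt(mean((va$mpg - predict(fit, newdata = va))^2))
}, train, validate)]

cv_dt[, .(id, rmse)]
```

Keeping the data frames in list columns means one row still describes one split, which is why a parallel map over train and validate is the natural way to consume them.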
nest_dt: A data.frame or data.table containing at least one nested data.frame or data.table column. Multi-level nested structures are supported.
v: The number of partitions of the data set.
repeats: The number of times to repeat the V-fold partitioning.
strata: A variable in each nested dataset (single character or name) used to conduct stratified sampling. When not NULL, each resample is created within the stratification variable. Numeric strata are binned into quartiles by default (see breaks).
breaks: A single number giving the number of bins desired to stratify a numeric stratification variable.
pool: A proportion of data used to determine whether a particular group is too small and should be pooled into another group. We do not recommend decreasing this argument below its default of 0.1 because of the dangers of stratifying groups that are too small.
...: Passed on to rsample::vfold_cv, whose dots are reserved for future extensions and must currently be empty.
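The strata, breaks and repeats arguments share their names and defaults with rsample::vfold_cv, so presumably they behave as they do there when applied to each nested table (an assumption; the sketch calls rsample directly rather than nest_cv). Stratifying iris on Species keeps every assessment fold balanced, a numeric stratification variable is binned first, and repeats multiplies the number of splits.

```r
library(rsample)

set.seed(2024)

# strata on a factor: folds are assigned within each species,
# so each assessment set holds 10 rows of every species
strat_folds <- vfold_cv(iris, v = 5, strata = Species)
table(assessment(strat_folds$splits[[1]])$Species)

# strata on a numeric column: Sepal.Length is first cut into quantile bins
# (`breaks` = 4 here, the default) and folds are assigned within each bin
num_folds <- vfold_cv(iris, v = 5, strata = Sepal.Length, breaks = 4)

# repeats: 2 rounds of 5-fold CV give 10 splits, labelled by
# `id` (repeat) and `id2` (fold within the repeat)
rep_folds <- vfold_cv(iris, v = 5, repeats = 2)
nrow(rep_folds)  # 10
```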
The function performs the following steps:
1. Checks that the input nest_dt is non-empty and contains at least one nested column of data.frames or data.tables.
2. Identifies the nested and non-nested columns within nest_dt.
3. Applies rsample::vfold_cv to each nested data frame in the identified nested column(s), creating the cross-validation splits.
4. Expands the cross-validation splits and associates them with the non-nested columns.
5. Extracts the training and validation data for each split and adds them to the output data table.
If the strata parameter is provided, stratified sampling is performed during cross-validation. Additional arguments can be passed to rsample::vfold_cv via the ... argument.
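The steps above can be sketched with data.table and rsample. cv_nested_sketch() below is a hypothetical helper written for illustration, not the code behind nest_cv: it handles one level of nesting, supports only v and repeats, converts each nested table to a plain data.frame before splitting, and adds a source_col column (also an assumption of the sketch) recording which nested column each split came from.

```r
library(data.table)
library(rsample)

# Hypothetical helper that mirrors the documented steps; it is NOT nest_cv()
cv_nested_sketch <- function(nest_dt, v = 10, repeats = 1) {
  nest_dt <- as.data.table(nest_dt)

  # Step 1: require a non-empty table with at least one nested data column
  if (nrow(nest_dt) == 0L) stop("`nest_dt` must not be empty")
  is_nested <- vapply(nest_dt, function(column) {
    is.list(column) && all(vapply(column, is.data.frame, logical(1)))
  }, logical(1))
  if (!any(is_nested)) stop("`nest_dt` needs a nested data.frame column")

  # Step 2: separate nested columns from non-nested (key) columns
  nest_cols <- names(nest_dt)[is_nested]
  keep_cols <- names(nest_dt)[!is_nested]

  pieces <- list()
  for (nest_col in nest_cols) {
    for (i in seq_len(nrow(nest_dt))) {
      # Step 3: V-fold splits for one nested table (as a plain data.frame)
      folds <- as.data.table(
        vfold_cv(as.data.frame(nest_dt[[nest_col]][[i]]), v = v, repeats = repeats)
      )
      # Step 4: repeat this row's non-nested values once per split
      out <- cbind(nest_dt[rep(i, nrow(folds)), ..keep_cols], folds)
      out[, source_col := nest_col]  # hypothetical bookkeeping column
      # Step 5: training and validation data per split, stored as list columns
      out[, train := list(lapply(splits, analysis))]
      out[, validate := list(lapply(splits, assessment))]
      pieces[[length(pieces) + 1L]] <- out
    }
  }
  rbindlist(pieces)
}

# Usage: nest iris by species (one 50-row data.frame per row), then split
iris_nested <- as.data.table(iris)[, .(data = list(as.data.frame(.SD))), by = Species]

set.seed(7)
cv_out <- cv_nested_sketch(iris_nested, v = 5)
cv_out[, .(Species, id,
           n_train = vapply(train, nrow, integer(1)),
           n_validate = vapply(validate, nrow, integer(1)))]
```

Each of the 3 species rows expands into 5 split rows, so the result has 15 rows, each carrying its Species value alongside the split and its training and validation data.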
See also:
- rsample::vfold_cv(): Underlying cross-validation function
- rsample::training(): Extract training set
- rsample::testing(): Extract test set
# Examples: cross-validation for a nested data.table
# Set up test data
dt_nest <- w2l_nest(
  data = iris,   # Input dataset
  cols2l = 1:2   # Nest first 2 columns
)

# Example 1: Basic 2-fold cross-validation
nest_cv(
  nest_dt = dt_nest,  # Input nested data.table
  v = 2               # Number of folds (2-fold CV)
)

# Example 2: Repeated 2-fold cross-validation
nest_cv(
  nest_dt = dt_nest,  # Input nested data.table
  v = 2,              # Number of folds (2-fold CV)
  repeats = 2         # Number of repetitions
)