
SMMAL (version 0.0.5)

cross_validation: Assign Cross-Validation Folds for Labelled and Unlabelled Data

Description

Creates cross-validation fold assignments separately for labelled and unlabelled observations (random sampling stratified by labelling status), so that each fold receives an approximately equal number of observations from each group.

Usage

cross_validation(N, nfold, A, Y)

Value

A list containing:

R

Binary vector of length N: 1 for labelled observations (both A and Y non-missing) and 0 for unlabelled observations (A or Y missing).

foldid

Integer vector of length N. Fold assignments (from 1 to nfold) for use in cross-validation.
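The R and foldid components are typically consumed together in a cross-fitting loop. The sketch below is an assumption about usage, not code taken from SMMAL. It builds hypothetical stand-in vectors in place of a real cross_validation() result so that it runs without the package. For each fold k, it treats observations with foldid != k as the fitting set (restricted to R == 1 where a model needs both A and Y) and observations with foldid == k as the held-out set.

```r
# Hypothetical stand-ins for cross_validation() output, so the sketch runs on its own
set.seed(1)
N <- 20
nfold <- 4
R <- rep(c(1L, 0L), length.out = N)                 # alternate labelled / unlabelled
foldid <- rep_len(seq_len(nfold), N)[sample.int(N)] # shuffled, balanced fold labels

held_out_all <- integer(0)
for (k in seq_len(nfold)) {
  train <- which(foldid != k)               # observations used for fitting in round k
  held_out <- which(foldid == k)            # observations predicted in round k
  labelled_train <- train[R[train] == 1]    # fitting subset for models that need A and Y
  # ... fit on labelled_train (or on train for steps that use unlabelled data),
  #     then predict on held_out ...
  held_out_all <- c(held_out_all, held_out)
}
```

Across the loop, every observation is held out exactly once, which is what makes the out-of-fold predictions usable for cross-fitting.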

Arguments

N

Integer. Total number of observations in the dataset; should equal length(A) and length(Y).

nfold

Integer. Number of folds to assign for cross-validation.

A

Numeric vector. Treatment assignment indicator (may contain NA for unlabelled samples).

Y

Numeric vector. Outcome variable (may contain NA for unlabelled samples).

Details

The function first splits the observations into a labelled group (both the treatment A and the outcome Y observed) and an unlabelled group (A or Y missing). Within each group, observations are randomly assigned to folds so that fold sizes are approximately balanced. Because the assignment is done separately for the two groups, every fold contains a comparable share of labelled and unlabelled observations, which supports semi-supervised cross-fitting workflows that use both kinds of data.
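A minimal base-R sketch of the procedure described above, assuming that folds are dealt out by cycling the labels 1..nfold over each group and then shuffling them. The name assign_folds_sketch is hypothetical, and SMMAL's actual implementation may differ (for example in how it randomizes, or in how it handles a group with fewer than nfold observations, where some folds receive none of that group).

```r
# Hypothetical re-implementation for illustration; not SMMAL's source code
assign_folds_sketch <- function(N, nfold, A, Y) {
  stopifnot(length(A) == N, length(Y) == N, nfold >= 1)
  # Labelled = both treatment and outcome observed
  R <- as.integer(!is.na(A) & !is.na(Y))
  foldid <- integer(N)
  for (g in c(1L, 0L)) {
    idx <- which(R == g)
    # Cycle fold labels 1..nfold over the group, then shuffle them
    folds <- rep_len(seq_len(nfold), length(idx))
    foldid[idx] <- folds[sample.int(length(folds))]
  }
  list(R = R, foldid = foldid)
}

set.seed(123)
A <- sample(c(0, 1, NA), size = 100, replace = TRUE)
Y <- sample(c(0, 1, NA), size = 100, replace = TRUE)
res <- assign_folds_sketch(N = 100, nfold = 5, A = A, Y = Y)
table(res$R, res$foldid)  # fold sizes within each group
```

Because the labels are cycled before shuffling, fold sizes within each group differ by at most one.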

Examples

library(SMMAL)
set.seed(123)
N <- 100
A <- sample(c(0, 1, NA), size = N, replace = TRUE, prob = c(0.45, 0.45, 0.10))
Y <- sample(c(0, 1, NA), size = N, replace = TRUE, prob = c(0.45, 0.45, 0.10))

# Assign 5 folds for cross-fitting
result <- cross_validation(N = N, nfold = 5, A = A, Y = Y)

table(result$R)  # Check number of labelled vs unlabelled
table(result$foldid)  # Check how folds are distributed overall
table(result$R, result$foldid)  # Check fold sizes within each group

