## ------------------------------------------------------------
## Two common plots:
## (A) one patient, multiple measurements -> interval band vs X1
## (B) multiple patients, one measurement -> intervals by patient (sorted by X1)
## ------------------------------------------------------------
## Grid of candidate response values handed to hcp_predict_targets() below.
y_grid <- seq(-6, 10, length.out = 201)

## Training sample: a single covariate (d = 1) drawn uniformly on [0, 10],
## generated with a 30% target missingness rate.
## NOTE(review): n and m presumably are the number of patients and the
## number of measurements per patient -- confirm against the
## generate_clustered_mar() documentation.
dat_train <- generate_clustered_mar(
  n = 200,
  m = 20,
  d = 1,
  x_dist = "uniform",
  x_params = list(min = 0, max = 10),
  hetero_gamma = 2.5,
  target_missing = 0.30,
  seed = 1
)

## Test sample: same covariate design, different seed, and no explicit
## target_missing (the generator's default applies). Its Y_full column is
## used further down as the "true" response to overlay on the plots.
dat_test <- generate_clustered_mar(
  n = 100,
  m = 20,
  d = 1,
  x_dist = "uniform",
  x_params = list(min = 0, max = 10),
  hetero_gamma = 2.5,
  seed = 999
)
## ---------- Case A: P=1, M>1 (one patient, multiple measurements) ----------
## Use the patient that owns the first row of the test data and collect the
## row indices of all of that patient's measurements.
pid <- dat_test$id[1]
idx <- which(dat_test$id == pid)

## Keep at most the 10 measurements with the smallest X1, in increasing X1
## order. head() is used rather than `[1:10]` because indexing past the end
## of a vector pads it with NA, which would turn into NA rows in the test
## frame whenever the patient has fewer than 10 measurements.
idx <- head(idx[order(dat_test$X1[idx])], 10)

## Test frame for the predictor: patient id, covariate, and the latent
## response (Y_full) kept only so it can be drawn on the plot.
test_1M <- data.frame(
  pid = pid,
  X1 = dat_test$X1[idx],
  y_true = dat_test$Y_full[idx]
)

## NOTE(review): S = 2 and B = 2 look like deliberately small settings to
## keep the example fast -- confirm their meaning in hcp_predict_targets().
out_1M <- hcp_predict_targets(
  dat = dat_train, test = test_1M,
  x_cols = "X1", y_grid = y_grid,
  alpha = 0.1,
  S = 2, B = 2,
  seed = 1
)

## Draw the intervals as a band against X1 and overlay the true responses.
plot_hcp_intervals(
  out_1M$pred, mode = "band", x_col = "X1",
  y_true_col = "y_true", show_true = TRUE,
  main = "Case A: one patient, multiple time points (band vs time)"
)
## ---------- Case B: P>1, M=1 (multiple patients, one measurement each) ----------
## First 20 distinct patient ids, in order of appearance in the test data.
pids <- unique(dat_test$id)[seq_len(20)]

## One row per selected patient: the measurement with j == 1. which() drops
## rows where the condition is NA, matching what subset() does. The three
## retained columns are renamed to the names used in the calls below.
test_P1 <- setNames(
  dat_test[
    which(dat_test[["id"]] %in% pids & dat_test[["j"]] == 1),
    c("id", "X1", "Y_full"),
    drop = FALSE
  ],
  c("pid", "X1", "y_true")
)

## Same predictor settings as in Case A, applied to the per-patient frame.
out_P1 <- hcp_predict_targets(
  dat = dat_train,
  test = test_P1,
  x_cols = "X1",
  y_grid = y_grid,
  alpha = 0.1,
  S = 2,
  B = 2,
  seed = 1
)

## One interval per patient, ordered by X1, with the true response overlaid.
plot_hcp_intervals(
  out_P1$pred,
  mode = "pid",
  pid_col = "pid",
  x_sort_col = "X1",
  y_true_col = "y_true",
  show_true = TRUE,
  main = "Case B: multiple patients, one time point (by patient)"
)
## Run the code above in your browser using DataLab