# Candidate cost values: 3 points evenly spaced on the log2 scale, 2^-5 .. 2^7
cost_seq <- 2^seq(-5, 7, length.out = 3)

# Simulate a small dataset of 20 bags; the seed makes the example reproducible
set.seed(8)
mil_data <- generate_mild_df(
  nbag = 20,
  positive_prob = 0.15,
  dist = rep("mvnormal", 3),
  mean = list(rep(1, 10), rep(2, 10)),
  sd_of_mean = rep(0.1, 3)
)

# Build instance-level features, passing 10 evenly spaced probabilities
# between 0.05 and 0.95
quantile_grid <- seq(0.05, 0.95, length.out = 10)
df <- build_instance_feature(mil_data, quantile_grid)
# Fit cross-validated MI-SVM models over `cost_seq` with 3 folds.
# NOTE(review): columns 4:123 are assumed to be the feature columns produced by
# build_instance_feature(); update this range if the feature grid changes.
# The slice is taken once so the heuristic and MIP fits use identical inputs.
x_feat <- df[, 4:123]

# Heuristic method, matrix interface
mdl1 <- cv_misvm(x = x_feat, y = df$bag_label,
                 bags = df$bag_name, cost_seq = cost_seq,
                 n_fold = 3, method = "heuristic")

# Formula interface on three of the feature columns (default method)
mdl2 <- cv_misvm(mi(bag_label, bag_name) ~ X1_mean + X2_mean + X3_mean,
                 data = df, cost_seq = cost_seq, n_fold = 3)

# The MIP method needs the optional gurobi solver. requireNamespace() checks
# availability without attaching the package, unlike require(), and skips
# this fit cleanly when gurobi is not installed.
if (requireNamespace("gurobi", quietly = TRUE)) {
  mdl3 <- cv_misvm(x = x_feat, y = df$bag_label,
                   bags = df$bag_name, cost_seq = cost_seq,
                   n_fold = 3, method = "mip")
}
# Raw bag-layer scores from the heuristic matrix-interface model
predict(mdl1, new_data = df, type = "raw", layer = "bag")

# Summarize predictions at the bag layer: attach the class and raw predictions
# from the formula model, then keep one row per distinct bag/prediction combo
suppressWarnings(library(dplyr))
pred_class <- predict(mdl2, df, type = "class")
pred_raw <- predict(mdl2, df, type = "raw")
df %>%
  bind_cols(pred_class, pred_raw) %>%
  distinct(bag_name, bag_label, .pred_class, .pred)
# Run the code above in your browser using DataLab