# Simulate a small multiple-instance dataset and fit an MI-SMM model to it.
# NOTE(review): generate_mild_df() and mismm() are presumably provided by the
# mildsvm package, which must be attached before this script runs — confirm.
# generate_mild_df() presumably draws random numbers, so fix the RNG seed to
# make the simulated bags (and the predictions printed below) reproducible.
set.seed(8)
mil_data <- generate_mild_df(
  nbag = 15,
  nsample = 20,
  positive_prob = 0.15,
  sd_of_mean = rep(0.1, 3)
)
# `sigma` is passed through the `control` list; presumably a kernel width
# parameter — see ?mismm.
mdl1 <- mismm(mil_data, control = list(sigma = 1 / 5))
# Bag-level predictions ----
# Attach the class prediction (`.pred_class`) and the raw prediction (`.pred`)
# to every row of `mil_data`. Then keep only the distinct combinations of bag
# identity, true bag label and the two predictions. This presumably leaves one
# row per bag.
library(dplyr)
mil_data %>%
  bind_cols(
    predict(mdl1, mil_data, type = "class"),
    predict(mdl1, mil_data, type = "raw")
  ) %>%
  distinct(bag_name, bag_label, .pred_class, .pred)
# Instance-level predictions ----
# Same as the bag-level summary, but the predictions are requested with
# `layer = "instance"`, and `instance_name` is kept among the distinct columns.
mil_data %>%
  bind_cols(
    predict(mdl1, mil_data, type = "class", layer = "instance"),
    predict(mdl1, mil_data, type = "raw", layer = "instance")
  ) %>%
  distinct(bag_name, instance_name, bag_label, .pred_class, .pred)
# Run the code above in your browser using DataLab.