if (FALSE) {
  # Example: draw masks from a fixed set of candidate masks with given
  # sampling weights, and check that the empirical frequency of each mask
  # matches its normalized weight.

  # Three candidate masks over four features. Each row has a different number
  # of ones (1, 2 and 4), so a drawn mask can be identified by its row sum.
  # Every torch call is namespace-qualified so the example also runs when
  # torch is installed but not attached.
  masks <- torch::torch_tensor(matrix(
    c(
      0, 0, 1, 0,
      1, 0, 1, 0,
      1, 1, 1, 1
    ),
    nrow = 3, ncol = 4, byrow = TRUE
  ))

  # Relative (unnormalized) sampling weight of each row of `masks`.
  masks_probs <- c(3, 1, 6)

  mask_gen <- specified_masks_mask_generator(
    masks = masks,
    masks_probs = masks_probs
  )

  # Input batch with one row per draw and one column per feature.
  # NOTE(review): presumably the generator only uses the batch's shape and
  # draws one mask per row -- confirm against its definition.
  n_draws <- 10000
  batch <- torch::torch_randn(c(n_draws, ncol(masks)))

  # Sum each drawn mask over its last dimension to recover which candidate
  # row it was. table() sorts the sums (1, 2, 4), which is the same order as
  # the rows of `masks`, so the counts line up with `masks_probs`.
  empirical_prob <- table(as.array(mask_gen(batch)$sum(-1)))

  # Inside a braced block only the last value auto-prints, so print both
  # explicitly. The two results should be approximately equal.
  print(empirical_prob / sum(empirical_prob))
  print(masks_probs / sum(masks_probs))
}
Run the code above in your browser using DataLab