# Compare two equivalent ways of calling `ot_distance()`:
#   1. passing the `causalWeights` object returned by `calc_weight()`, and
#   2. passing the control/treated covariate matrices plus their weights.
# The final `all.equal()` reports whether the two results agree.
#
# `torch::torch_is_installed()` itself requires the torch R package, so check
# for the namespace first; otherwise the example errors instead of skipping.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  n <- 10
  d <- 5
  x <- matrix(stats::rnorm(n * d), nrow = n, ncol = d)

  # Balanced random assignment: unlike independent Bernoulli draws, this can
  # never leave the control (0) or treated (1) group empty.
  z <- sample(rep(c(0L, 1L), length.out = n))

  weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")

  ot1 <- ot_distance(
    x1 = weights, penalty = 100, p = 2, debias = TRUE,
    online.cost = "auto", diameter = NULL
  )

  # `drop = FALSE` keeps each subset a matrix even if a group has one row.
  ot2 <- ot_distance(
    x1 = x[z == 0, , drop = FALSE], x2 = x[z == 1, , drop = FALSE],
    a = weights@w0 / sum(weights@w0), b = weights@w1,
    penalty = 100, p = 2, debias = TRUE, online.cost = "auto",
    diameter = NULL
  )

  all.equal(ot1$post, ot2)
}
# Run the code above in your browser using DataLab.