## ------------------------------------------------
## Method `OTProblem(measure_1, measure_2)`
## ------------------------------------------------
# Example walk-through of an `OTProblem` object: build three measures,
# combine two OT problems into one objective, then call the object's
# `setup_arguments`, `solve`, `choose_hyperparameters` and `info` methods.
#
# The whole example is skipped unless torch is usable. The namespace check
# comes first because `torch::torch_is_installed()` itself raises an error
# when the torch R package is not installed; `&&` short-circuits so the
# second call is only made when the namespace can be loaded.
if (requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  # setup measures
  # x, y, z are constant matrices (all 1s, 2s, 3s) with 10 columns each;
  # z has 102 rows while x and y have 100.
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x)

  y <- matrix(2, 100, 10)
  # NOTE(review): `adapt = "weights"` presumably marks the weights of this
  # measure as the quantity to be optimised -- confirm against the
  # `Measure` documentation.
  m2 <- Measure(x = y, adapt = "weights")

  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z)

  # setup OT problems
  # Both problems share `m2`; `ot` is their equally weighted (0.5 / 0.5) sum.
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)

  ## ------------------------------------------------
  ## Method `OTProblem$setup_arguments`
  ## ------------------------------------------------
  # NOTE(review): `lambda` is presumably a penalty / regularisation
  # parameter of the problem -- confirm against the method documentation.
  ot$setup_arguments(lambda = 1000)

  ## ------------------------------------------------
  ## Method `OTProblem$solve`
  ## ------------------------------------------------
  # A single iteration (`niter = 1`) with torch's RMSprop optimiser;
  # presumably kept this small so the example runs quickly.
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)

  ## ------------------------------------------------
  ## Method `OTProblem$choose_hyperparameters`
  ## ------------------------------------------------
  ot$choose_hyperparameters(
    n_boot_lambda = 1,
    n_boot_delta = 1,
    lambda_bootstrap = Inf
  )

  ## ------------------------------------------------
  ## Method `OTProblem$info`
  ## ------------------------------------------------
  ot$info()
}
# Run the code above in your browser using DataLab