if (torch::torch_is_installed()) {
  # Constant-valued data matrices, each with 10 columns
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  z <- matrix(3, nrow = 102, ncol = 10)

  # Wrap the data as measures; m2 is created with adapt = "weights"
  m1 <- Measure(x = x)
  m2 <- Measure(x = y, adapt = "weights")
  m3 <- Measure(x = z)

  # Two OT problems, both using m2 as their second measure
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # Arithmetic on OTProblem objects (here scalar multiplication and
  # addition) produces a new, combined OTProblem
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)

  # Set the hyperparameter lambda before optimizing
  ot$setup_arguments(lambda = 1000)

  # Optimize the combined objective for one iteration using RMSprop
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
}
## ------------------------------------------------
## Method `OTProblem$add`
## ------------------------------------------------
# example code
if (torch::torch_is_installed()) {
  # Three constant data matrices (10 columns each)
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  z <- matrix(3, nrow = 102, ncol = 10)

  # One measure per matrix
  m1 <- Measure(x = x)
  m2 <- Measure(x = y)
  m3 <- Measure(x = z)

  # Two OT problems that share m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # State of both problems before the call
  print(ot1)
  print(ot2)

  # Call `add()` on ot1 with ot2; the return value is discarded, and the
  # prints below show each object after the call
  ot1$add(ot2)
  print(ot1)
  print(ot2)
}
## ------------------------------------------------
## Method `OTProblem$subtract`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Input data: constant matrices with 10 columns
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  z <- matrix(3, nrow = 102, ncol = 10)

  # Measures built from that data
  m1 <- Measure(x = x)
  m2 <- Measure(x = y)
  m3 <- Measure(x = z)

  # A pair of OT problems, each paired with m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # Show both problems first
  print(ot1)
  print(ot2)

  # Subtract ot2 via ot1's method (return value not kept), then re-inspect
  ot1$subtract(ot2)
  print(ot1)
  print(ot2)
}
## ------------------------------------------------
## Method `OTProblem$multiply`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Constant-valued matrices used as sample data
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  z <- matrix(3, nrow = 102, ncol = 10)

  # Turn each matrix into a measure
  m1 <- Measure(x = x)
  m2 <- Measure(x = y)
  m3 <- Measure(x = z)

  # Two OT problems; m2 appears in both
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # Before: print each problem
  print(ot1)
  print(ot2)

  # Multiply via ot1's method (return value not kept); after: print again
  ot1$multiply(ot2)
  print(ot1)
  print(ot2)
}
## ------------------------------------------------
## Method `OTProblem$divide`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Sample data: three constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  z <- matrix(3, nrow = 102, ncol = 10)

  # Measures over the sample data
  m1 <- Measure(x = x)
  m2 <- Measure(x = y)
  m3 <- Measure(x = z)

  # Build two OT problems against m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # Print the inputs
  print(ot1)
  print(ot2)

  # Divide via ot1's method (return value not kept), then print both again
  ot1$divide(ot2)
  print(ot1)
  print(ot2)
}
## ------------------------------------------------
## Method `OTProblem$setup_arguments`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Two constant data matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)

  # m2 is created with adapt = "weights"
  m1 <- Measure(x = x)
  m2 <- Measure(x = y, adapt = "weights")

  # Create the problem and configure lambda
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)
}
## ------------------------------------------------
## Method `OTProblem$solve`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Data for the two measures
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)

  # Second measure is created with adapt = "weights"
  m1 <- Measure(x = x)
  m2 <- Measure(x = y, adapt = "weights")

  # Problem definition and hyperparameter
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)

  # Run a single RMSprop iteration
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
}
## ------------------------------------------------
## Method `OTProblem$choose_hyperparameters`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Constant data matrices for the two measures
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)

  # m2 is created with adapt = "weights"
  m1 <- Measure(x = x)
  m2 <- Measure(x = y, adapt = "weights")

  # Supply two candidate lambda values, then solve once
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)

  # Select among the candidates using the bootstrap settings below
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )
}
## ------------------------------------------------
## Method `OTProblem$info`
## ------------------------------------------------
if (torch::torch_is_installed()) {
  # Build the problem here rather than relying on an `ot` object left over
  # from a previous example: run on its own, the original `ot$info()` call
  # failed with "object 'ot' not found".
  x <- matrix(1, nrow = 100, ncol = 10)
  y <- matrix(2, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  m2 <- Measure(x = y, adapt = "weights")
  ot <- OTProblem(m1, m2)

  # Same workflow as the `choose_hyperparameters` example, so `info()`
  # reports on a problem that has been solved and tuned
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )

  ot$info()
}
# Run the code above in your browser using DataLab