OTProblem R6 class
device
The torch::torch_device() of the data.
dtype
The torch::torch_dtype of the data.
selected_delta
The delta value selected by the choose_hyperparameters() method.
selected_lambda
The lambda value selected by the choose_hyperparameters() method.
loss
Prints the current value of the objective. Only available after the solve() method has been run.
penalty
Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the setup_arguments() method.
add()
Adds o2 to the OTProblem.
OTProblem_$add(o2)
o2
A number or object of class OTProblem
subtract()
Subtracts o2 from the OTProblem.
OTProblem_$subtract(o2)
o2
A number or object of class OTProblem
multiply()
Multiplies the OTProblem by o2.
OTProblem_$multiply(o2)
o2
A number or object of class OTProblem
divide()
Divides the OTProblem by o2.
OTProblem_$divide(o2)
o2
A number or object of class OTProblem
...
Not used
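The four arithmetic methods build a combined objective from a problem and o2, which may be either a number or another OTProblem. As a conceptual analogue only (not causalOT's implementation), the sketch below models each problem's objective as an R function of a weight vector and combines them with hypothetical helpers as_objective() and combine_loss(). It assumes that adding two problems sums their losses and that multiplying by a number scales the loss.

```r
# Hypothetical analogue: an "objective" is a function of a weight vector w.
# o2 may be a number or another objective, mirroring the o2 argument above.
as_objective <- function(o) {
  if (is.function(o)) o else function(w) o  # a number becomes a constant objective
}

combine_loss <- function(o1, o2, op = c("add", "subtract", "multiply", "divide")) {
  op <- match.arg(op)
  f1 <- as_objective(o1)
  f2 <- as_objective(o2)
  switch(op,
    add      = function(w) f1(w) + f2(w),
    subtract = function(w) f1(w) - f2(w),
    multiply = function(w) f1(w) * f2(w),
    divide   = function(w) f1(w) / f2(w)
  )
}

# Two toy objectives standing in for two OT problems (assumption)
loss_a <- function(w) sum((w - 0.5)^2)
loss_b <- function(w) sum(abs(w))

total  <- combine_loss(loss_a, loss_b, "add")   # loss_a + loss_b
scaled <- combine_loss(total, 2, "multiply")    # 2 * (loss_a + loss_b)
w <- c(0.25, 0.75)
scaled(w)
```

Representing a plain number as a constant objective lets one code path handle both kinds of o2, which is the same flexibility the documented methods advertise.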
new()
Constructor method
OTProblem_$new(measure_1, measure_2)
measure_1
An object of class Measure
measure_2
An object of class Measure
...
Not used at this time
Returns an R6 object of class "OTProblem".
setup_arguments()
OTProblem_$setup_arguments(
lambda,
delta,
grid.length = 7L,
cost.function = NULL,
p = 2,
cost.online = "auto",
debias = TRUE,
diameter = NULL,
ot_niter = 1000L,
ot_tol = 0.001
)
lambda
The penalty parameters to try for the OT problems. If not provided, the function will select some.
delta
The constraint parameters to try for the balance function problems, if any.
grid.length
The number of hyperparameter values to try when they are not provided.
cost.function
The cost function for the data. Can be any function that takes arguments x1, x2, and p. Defaults to the Euclidean distance.
p
The power to raise the cost matrix by. Default is 2
cost.online
Should online costs be used? One of "auto" (the default), "tensorized", or "online": "tensorized" stores the full cost matrix in memory, while "online" computes the costs on the fly.
debias
Should debiased OT problems be used? Defaults to TRUE
diameter
Diameter of the cost function.
ot_niter
Maximum number of iterations to run for each OT problem
ot_tol
The tolerance for convergence of the OT problems
Returns NULL; the arguments are stored in the object for use by solve().
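To make cost.function, p, debias, ot_niter, and ot_tol more concrete, here is a textbook log-domain Sinkhorn sketch in base R. It assumes that lambda acts as the entropic regularization strength, that x1 and x2 are numeric matrices with one observation per row, and that "debiased" means the Sinkhorn divergence, OT(a, b) minus half of OT(a, a) and half of OT(b, b). The function names are hypothetical, and causalOT's own torch-based solver may differ in scaling and stopping rules.

```r
# Cost with the documented signature (x1, x2, p): Euclidean distance raised to p.
euclidean_cost <- function(x1, x2, p = 2) {
  sq <- outer(rowSums(x1^2), rowSums(x2^2), "+") - 2 * tcrossprod(x1, x2)
  sqrt(pmax(sq, 0))^p  # pmax guards against tiny negative rounding errors
}

# Numerically stable row-wise log-sum-exp of a matrix
row_lse <- function(M) {
  m <- apply(M, 1, max)
  m + log(rowSums(exp(M - m)))
}

# Entropic OT value between weighted point clouds (x, a) and (y, b)
entropic_ot <- function(x, y, a, b, lambda, p = 2,
                        ot_niter = 1000L, ot_tol = 1e-3) {
  C <- euclidean_cost(x, y, p)
  f <- numeric(nrow(x))
  g <- numeric(nrow(y))
  for (it in seq_len(ot_niter)) {
    # f_i = -lambda * log sum_j b_j exp((g_j - C_ij) / lambda)
    f_new <- -lambda * row_lse(sweep(-C / lambda, 2, g / lambda + log(b), "+"))
    # g_j = -lambda * log sum_i a_i exp((f_i - C_ij) / lambda)
    g <- -lambda * row_lse(t(-C / lambda + (f_new / lambda + log(a))))
    converged <- max(abs(f_new - f)) < ot_tol
    f <- f_new
    if (converged) break
  }
  sum(a * f) + sum(b * g)  # dual value at the final potentials
}

# Debiased version (Sinkhorn divergence): removes the entropic self-bias
debiased_ot <- function(x, y, a, b, lambda, ...) {
  entropic_ot(x, y, a, b, lambda, ...) -
    0.5 * entropic_ot(x, x, a, a, lambda, ...) -
    0.5 * entropic_ot(y, y, b, b, lambda, ...)
}

set.seed(1)
x <- matrix(rnorm(40), ncol = 2)
y <- x + 3                       # the same cloud shifted by (3, 3)
a <- rep(1 / nrow(x), nrow(x))
b <- rep(1 / nrow(y), nrow(y))
debiased_ot(x, y, a, b, lambda = 0.5)
debiased_ot(x, x, a, a, lambda = 0.5)  # zero by construction
```

Without debiasing, entropic_ot(x, x, a, a, lambda) is generally not zero, which is the bias the debiased form subtracts away; ot_niter caps the Sinkhorn loop and ot_tol stops it early once the potentials stop changing.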
solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
OTProblem_$solve(
niter = 1000L,
tol = 1e-05,
optimizer = c("torch", "frank-wolfe"),
torch_optim = torch::optim_lbfgs,
torch_scheduler = torch::lr_reduce_on_plateau,
torch_args = NULL,
osqp_args = NULL,
quick.balance.function = TRUE
)
niter
The number of iterations to run the solver at each combination of hyperparameter values
tol
The tolerance for convergence
optimizer
The optimizer to use. One of "torch" or "frank-wolfe"
torch_optim
The torch optimizer to use. Default is torch::optim_lbfgs
torch_scheduler
The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau
torch_args
Arguments passed to the torch optimizer and scheduler
osqp_args
Arguments passed to osqp::osqpSettings(), if appropriate
quick.balance.function
Should osqp::osqp() be used to select the balance function constraints (delta)? Default is TRUE.
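For intuition about optimizer = "frank-wolfe", here is a generic Frank-Wolfe loop over the probability simplex in base R. It assumes (this is not stated above) that the quantity being optimized is a vector of weights that are nonnegative and sum to one. The helper name frank_wolfe_simplex(), the toy quadratic objective, and the 2/(k + 2) step size are illustrative choices; niter and tol play the same roles as the documented arguments.

```r
# Minimize a smooth convex f(w) over the simplex {w >= 0, sum(w) = 1}.
frank_wolfe_simplex <- function(grad, n, niter = 1000L, tol = 1e-5) {
  w <- c(1, rep(0, n - 1))            # start at a vertex of the simplex
  for (it in seq_len(niter)) {
    gk <- grad(w)
    s <- numeric(n)
    s[which.min(gk)] <- 1             # linear minimization oracle: best vertex
    gap <- sum(gk * (w - s))          # Frank-Wolfe gap, an upper bound on f(w) - f*
    if (gap < tol) break
    step <- 2 / (it + 1)              # the classic 2 / (k + 2) schedule, k = it - 1
    w <- (1 - step) * w + step * s    # convex combination stays in the simplex
  }
  w
}

# Toy objective (assumption): squared distance to a target weight vector
target <- c(0.1, 0.2, 0.3, 0.4)
f    <- function(w) sum((w - target)^2)
grad <- function(w) 2 * (w - target)
w_hat <- frank_wolfe_simplex(grad, n = 4, niter = 2000L)
w_hat
```

Frank-Wolfe never needs a projection step: every iterate is a convex combination of simplex vertices, so the weights remain valid probabilities throughout.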
choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
OTProblem_$choose_hyperparameters(
n_boot_lambda = 100L,
n_boot_delta = 1000L,
lambda_bootstrap = Inf
)
n_boot_lambda
The number of bootstrap iterations to run when selecting lambda
n_boot_delta
The number of bootstrap iterations to run when selecting delta
lambda_bootstrap
The penalty parameter to use when selecting lambda. Higher numbers run faster.
info()
Provides diagnostics after the solve() and choose_hyperparameters() methods have been run.
OTProblem_$info()
Returns a list with slots:
loss
The final loss values
iterations
The number of iterations run for each combination of parameters
balance.function.differences
The final differences in the balance functions
hyperparam.metrics
A list of the bootstrap evaluations for the delta and lambda values
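The balance.function.differences slot and the delta constraints are easiest to picture with a small base-R sketch. It assumes that a balance function difference is the weighted mean of each balance function in the sample minus its target value, and that a weighting satisfies delta when every absolute difference is at most delta. The helper names are hypothetical, and causalOT may standardize or scale these quantities differently.

```r
# Weighted mean of each balance function (column of B) minus its target
balance_differences <- function(B, w, target) {
  stopifnot(abs(sum(w) - 1) < 1e-8)   # weights assumed to be normalized
  as.vector(crossprod(B, w)) - target
}

# Does a weighting satisfy |difference_k| <= delta for every function k?
satisfies_delta <- function(B, w, target, delta) {
  all(abs(balance_differences(B, w, target)) <= delta)
}

# Toy data (assumption): two balance functions evaluated on four units
B <- cbind(x1 = c(20, 30, 40, 50), x2 = c(1, 0, 1, 0))
target <- c(33, 0.5)

w_uniform <- rep(0.25, 4)
w_tuned   <- c(0.3, 0.3, 0.2, 0.2)

balance_differences(B, w_uniform, target)      # first function is off by 2
satisfies_delta(B, w_uniform, target, delta = 0.1)
satisfies_delta(B, w_tuned, target, delta = 0.1)
```

A smaller delta forces the weighted balance functions closer to their targets, usually at the cost of more uneven weights, which is why delta is treated as a hyperparameter to select.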
clone()
The objects of this class are cloneable with this method.
OTProblem_$clone(deep = FALSE)
deep
Whether to make a deep clone.
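Because an OTProblem is built from two Measure objects, the deep argument can matter. The sketch below uses two small stand-in R6 classes (hypothetical, not causalOT classes) to show standard R6 behavior: a shallow clone shares any R6 objects stored in its fields, while deep = TRUE clones those objects too. It assumes the measures are held as R6 fields of the problem.

```r
library(R6)

# Hypothetical stand-in for a Measure-like object
Holder <- R6Class("Holder", public = list(
  value = NULL,
  initialize = function(value) self$value <- value
))

# Hypothetical stand-in for a problem that stores such an object in a field
Problem <- R6Class("Problem", public = list(
  m = NULL,
  initialize = function(m) self$m <- m
))

prob    <- Problem$new(Holder$new(1))
shallow <- prob$clone()             # deep = FALSE: the Holder is shared
deep    <- prob$clone(deep = TRUE)  # deep = TRUE: the Holder is copied

prob$m$value <- 99                  # modify the original's nested object
shallow$m$value                     # sees the change (same Holder)
deep$m$value                        # unaffected (its own Holder)
```

Use deep = TRUE when a copy should be modified without side effects on the original problem's nested objects.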