OTProblem R6 class
device: The torch::torch_device() of the data.
dtype: The torch::torch_dtype of the data.
selected_delta: The delta value selected after choose_hyperparameters.
selected_lambda: The lambda value selected after choose_hyperparameters.
loss: Prints the current value of the objective. Only available after the solve method has been run.
penalty: Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the setup_arguments function.
add(): Adds o2 to the OTProblem.
OTProblem_$add(o2)
o2: A number or object of class OTProblem
subtract(): Subtracts o2 from the OTProblem.
OTProblem_$subtract(o2)
o2: A number or object of class OTProblem
multiply(): Multiplies the OTProblem by o2.
OTProblem_$multiply(o2)
o2: A number or object of class OTProblem
divide(): Divides the OTProblem by o2.
OTProblem_$divide(o2)
o2: A number or object of class OTProblem
...: Not used
new(): Constructor method.
OTProblem_$new(measure_1, measure_2)
measure_1: An object of class Measure
measure_2: An object of class Measure
...: Not used at this time
Returns: An R6 object of class "OTProblem"
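As a rough illustration of the constructor, the sketch below builds two Measure objects and passes them to an OTProblem. It assumes this class ships with the causalOT package and that the package exports Measure() and OTProblem() wrappers with an `x` and an `adapt` argument; those names are not stated on this page, so check them against your installed version. The code only runs when the package and the torch backend are available.

```r
# Assumption: the class comes from the causalOT package, which is assumed to
# export Measure() and OTProblem() helpers. Nothing runs if they are missing.
have_deps <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

ot <- NULL
if (have_deps) {
  set.seed(1)
  x <- matrix(rnorm(50 * 3), nrow = 50, ncol = 3)              # source sample
  y <- matrix(rnorm(60 * 3, mean = 0.5), nrow = 60, ncol = 3)  # target sample

  # adapt = "weights" (assumed argument) marks the measure whose weights
  # are to be optimized
  m1 <- causalOT::Measure(x = x, adapt = "weights")
  m2 <- causalOT::Measure(x = y)

  # Build the optimal transport problem between the two measures
  ot <- causalOT::OTProblem(m1, m2)
  print(class(ot))
}
```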
setup_arguments()
OTProblem_$setup_arguments(
  lambda,
  delta,
  grid.length = 7L,
  cost.function = NULL,
  p = 2,
  cost.online = "auto",
  debias = TRUE,
  diameter = NULL,
  ot_niter = 1000L,
  ot_tol = 0.001
)
lambda: The penalty parameters to try for the OT problems. If not provided, the function will select some.
delta: The constraint parameters to try for the balance function problems, if any.
grid.length: The number of hyperparameters to try if none are provided.
cost.function: The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance.
p: The power to raise the cost matrix by. Default is 2.
cost.online: How the cost matrix is computed. One of "auto" (the default), "tensorized" (stores the cost matrix in memory), or "online" (calculates the costs on the fly).
debias: Should debiased OT problems be used? Defaults to TRUE.
diameter: Diameter of the cost function.
ot_niter: Number of iterations to run the OT problems.
ot_tol: The tolerance for convergence of the OT problems.
Returns: NULL
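To make the cost.function contract above more concrete, here is a minimal base R sketch of a function with the x1, x2, p signature. It assumes the function should return an nrow(x1) by nrow(x2) matrix of Euclidean distances raised to the power p, and that the inputs are plain numeric matrices. Both are assumptions for illustration; the class may pass torch tensors instead, in which case the body would need torch operations. The name euclidean_cost is hypothetical.

```r
# Hypothetical cost function: pairwise Euclidean distance between the rows of
# x1 and x2, raised to the power p. Plain numeric matrices are assumed.
euclidean_cost <- function(x1, x2, p = 2) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  stopifnot(ncol(x1) == ncol(x2))
  # Squared distances via ||a||^2 + ||b||^2 - 2 * a.b
  sq <- outer(rowSums(x1^2), rowSums(x2^2), "+") - 2 * tcrossprod(x1, x2)
  sq <- pmax(sq, 0)  # clip tiny negative values caused by rounding
  sq^(p / 2)         # (distance^2)^(p/2) equals distance^p
}

x1 <- matrix(c(0, 0, 3, 4), nrow = 2, byrow = TRUE)  # rows (0, 0) and (3, 4)
x2 <- matrix(c(0, 0), nrow = 1)                      # single row (0, 0)
cost <- euclidean_cost(x1, x2, p = 2)
print(cost)
```

With p = 2 the entries are squared distances, so larger values of p penalize long-range transport more heavily.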
solve(): Solve the OTProblem at each parameter value. Must run setup_arguments first.
OTProblem_$solve(
  niter = 1000L,
  tol = 1e-05,
  optimizer = c("torch", "frank-wolfe"),
  torch_optim = torch::optim_lbfgs,
  torch_scheduler = torch::lr_reduce_on_plateau,
  torch_args = NULL,
  osqp_args = NULL,
  quick.balance.function = TRUE
)
niter: The number of iterations to run the solver at each combination of hyperparameter values.
tol: The tolerance for convergence.
optimizer: The optimizer to use. One of "torch" or "frank-wolfe".
torch_optim: The torch_optimizer to use. Default is torch::optim_lbfgs.
torch_scheduler: The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau.
torch_args: Arguments passed to the torch optimizer and scheduler.
osqp_args: Arguments passed to osqp::osqpSettings() if appropriate.
quick.balance.function: Should osqp::osqp() be used to select balance function constraints (delta) or not? Defaults to TRUE.
choose_hyperparameters(): Selects the hyperparameter values through a bootstrap algorithm.
OTProblem_$choose_hyperparameters(
  n_boot_lambda = 100L,
  n_boot_delta = 1000L,
  lambda_bootstrap = Inf
)
n_boot_lambda: The number of bootstrap iterations to run when selecting lambda.
n_boot_delta: The number of bootstrap iterations to run when selecting delta.
lambda_bootstrap: The penalty parameter to use when selecting lambda. Higher numbers run faster.
info(): Provides diagnostics after the solve and choose_hyperparameters methods have been run.
OTProblem_$info()
Returns: A list with slots
loss: The final loss values
iterations: The number of iterations run for each combination of parameters
balance.function.differences: The final differences in the balance functions
hyperparam.metrics: A list of the bootstrap evaluation for delta and lambda values
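The methods above are meant to be called in order: setup_arguments, then solve, then choose_hyperparameters, then info. The sketch below walks through that sequence on small simulated data. It assumes the class comes from the causalOT package and that Measure() and OTProblem() are its exported constructors; the data, the lambda value, the rmsprop optimizer, and the small iteration and bootstrap counts are arbitrary choices made to keep the run short, not recommended settings.

```r
# Assumption: causalOT provides Measure() and OTProblem(); the guard keeps the
# script harmless when the package or the torch backend is not installed.
have_deps <- requireNamespace("causalOT", quietly = TRUE) &&
  requireNamespace("torch", quietly = TRUE) &&
  torch::torch_is_installed()

info <- NULL
if (have_deps) {
  set.seed(1)
  x <- matrix(rnorm(50 * 3), nrow = 50, ncol = 3)
  y <- matrix(rnorm(60 * 3, mean = 0.5), nrow = 60, ncol = 3)
  m1 <- causalOT::Measure(x = x, adapt = "weights")  # weights to be optimized
  m2 <- causalOT::Measure(x = y)
  ot <- causalOT::OTProblem(m1, m2)

  # 1. Fix the penalty to try; other arguments keep their defaults
  ot$setup_arguments(lambda = 1000)

  # 2. Solve at each hyperparameter value (a single iteration, for speed only)
  ot$solve(niter = 1L, torch_optim = torch::optim_rmsprop)

  # 3. Bootstrap selection of lambda and delta (tiny counts, for speed only)
  ot$choose_hyperparameters(n_boot_lambda = 10L, n_boot_delta = 10L,
                            lambda_bootstrap = Inf)

  # 4. Collect diagnostics
  info <- ot$info()
  print(names(info))
  print(info$loss)
}
```

In practice niter and the bootstrap counts would be left near their defaults; they are reduced here only so the example finishes quickly.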
clone(): The objects of this class are cloneable with this method.
OTProblem_$clone(deep = FALSE)
deep: Whether to make a deep clone.