Learn R Programming

causalOT (version 1.0.3)

OTProblem: OTProblem

Description

User-facing constructor for an R6 OTProblem object.

Usage

OTProblem(measure_1, measure_2, ...)

Value

An R6 object of class OTProblem.

Arguments

measure_1

An object of class Measure

measure_2

An object of class Measure

...

Not used at this time

Public fields

device

the torch::torch_device() of the data.

dtype

the torch::torch_dtype of the data.

selected_delta

the delta value selected after choose_hyperparameters

selected_lambda

the lambda value selected after choose_hyperparameters

Active bindings

loss

Prints the current value of the objective. Only available after the solve method has been run

penalty

Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the setup_arguments function.

Methods


Method add()

adds o2 to the OTProblem

Usage

OTProblem$add(o2)

Arguments

o2

A number or object of class OTProblem

Examples

# Combine two OTProblem objects with $add()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Add ot2 to ot1, then inspect both objects again
  ot1$add(ot2)
  print(ot1)
  print(ot2)
}


Method subtract()

subtracts o2 from OTProblem

Usage

OTProblem$subtract(o2)

Arguments

o2

A number or object of class OTProblem

Examples

# Combine two OTProblem objects with $subtract()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Subtract ot2 from ot1, then inspect both objects again
  ot1$subtract(ot2)
  print(ot1)
  print(ot2)
}


Method multiply()

multiplies OTProblem by o2

Usage

OTProblem$multiply(o2)

Arguments

o2

A number or object of class OTProblem

Examples

# Combine two OTProblem objects with $multiply()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Multiply ot1 by ot2, then inspect both objects again
  ot1$multiply(ot2)
  print(ot1)
  print(ot2)
}


Method divide()

divides OTProblem by argument o2

Usage

OTProblem$divide(o2)

Arguments

o2

A number or object of class OTProblem

Examples

# Combine two OTProblem objects with $divide()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Divide ot1 by ot2, then inspect both objects again
  ot1$divide(ot2)
  print(ot1)
  print(ot2)
}


Method print()

prints the OT problem object

Usage

OTProblem$print(...)

Arguments

...

Not used at this time


Method new()

Constructor method

Usage

OTProblem$new(measure_1, measure_2)

Arguments

measure_1

An object of class Measure

measure_2

An object of class Measure

...

Not used at this time

Returns

An R6 object of class OTProblem


Method setup_arguments()

Sets up the OT problems for the OTProblem object. This must be run before solve, which in turn must be run before choose_hyperparameters.

Usage

OTProblem$setup_arguments(
  lambda,
  delta,
  grid.length = 7L,
  cost.function = NULL,
  p = 2,
  cost.online = "auto",
  debias = TRUE,
  diameter = NULL,
  ot_niter = 1000L,
  ot_tol = 0.001
)

Arguments

lambda

The penalty parameters to try for the OTProblem. If not provided, the function will select some.

delta

The constraint parameters to try for the balance function problems, if any.

grid.length

The number of hyperparameters to try if not provided

cost.function

The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance.

p

The power to raise the cost matrix by. Default is 2

cost.online

Should online costs be used? Default is "auto" but "tensorized" stores the cost matrix in memory while "online" will calculate it on the fly.

debias

Should a debiased OTProblem be used? Defaults to TRUE.

diameter

Diameter of the cost function.

ot_niter

Number of iterations to run the solver

ot_tol

The tolerance for convergence of the objective function

Returns

returns the object invisibly

Examples

# Configure an OTProblem with a single penalty value
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # Build the problem and register lambda = 1000
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)
}


Method solve()

Solve the OTProblem at each parameter value. Must run setup_arguments first.

Usage

OTProblem$solve(
  niter = 1000L,
  tol = 1e-05,
  optimizer = c("torch", "frank-wolfe"),
  torch_optim = torch::optim_lbfgs,
  torch_scheduler = torch::lr_reduce_on_plateau,
  torch_args = NULL,
  osqp_args = NULL,
  quick.balance.function = TRUE
)

Arguments

niter

The number of iterations to run the solver at each combination of hyperparameter values

tol

The tolerance for convergence

optimizer

The optimizer to use. One of "torch" or "frank-wolfe"

torch_optim

The torch_optimizer to use. Default is torch::optim_lbfgs

torch_scheduler

The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau

torch_args

Arguments passed to the torch optimizer and scheduler

osqp_args

Arguments passed to osqp::osqpSettings() if appropriate

quick.balance.function

Should osqp::osqp() be used to select balance function constraints (delta)? Default is TRUE.

Returns

returns the object invisibly

Examples

# Configure and solve an OTProblem (one iteration, for speed)
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # setup_arguments() must precede solve()
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)
  ot$solve(
    niter = 1,
    torch_optim = torch::optim_rmsprop
  )
}


Method choose_hyperparameters()

Selects the hyperparameter values through a bootstrap algorithm

Usage

OTProblem$choose_hyperparameters(
  n_boot_lambda = 100L,
  n_boot_delta = 1000L,
  lambda_bootstrap = Inf
)

Arguments

n_boot_lambda

The number of bootstrap iterations to run when selecting lambda

n_boot_delta

The number of bootstrap iterations to run when selecting delta

lambda_bootstrap

The penalty parameter to use when selecting lambda. Higher numbers run faster.

Returns

returns the object invisibly

Examples

# Solve over two lambda values, then pick one by bootstrap
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(
    niter = 1,
    torch_optim = torch::optim_rmsprop
  )
  # Small bootstrap counts keep the example fast
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )
}


Method info()

Provides diagnostics after the solve and choose_hyperparameters methods have been run.

Usage

OTProblem$info()

Returns

a list with slots

  • loss the final loss values

  • iterations The number of iterations run for each combination of parameters

  • balance.function.differences The final differences in the balance functions

  • hyperparam.metrics A list of the bootstrap evaluation for delta and lambda values

Examples

# Report diagnostics for a solved OTProblem.
# The example is self-contained: info() is only meaningful after
# solve() and choose_hyperparameters() have been run, so build and
# run the full pipeline first instead of relying on an `ot` object
# left over from another example.
if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # set up, solve, and select hyperparameters
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )

  # list with loss, iterations, balance.function.differences and
  # hyperparam.metrics
  ot$info()
}


Method clone()

The objects of this class are cloneable with this method.

Usage

OTProblem$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Details

An R6 class for creating optimal transport problems with two Measure objects.

Use OTProblem() to construct an object of class OTProblem. The component objects must be of class Measure.

The process of solving an OT problem involves three steps. (1) Set up the problem: create Measure objects, pass pairs of them to OTProblem() to create OTProblem objects, and optionally combine those OTProblem objects (or numbers) using the $add(), $subtract(), $multiply(), and $divide() methods or the corresponding arithmetic operators (e.g. 0.5 * ot1 + 0.5 * ot2). (2) Choose the candidate hyperparameters (lambda, delta) and solver options by calling the $setup_arguments() method on the OTProblem object. (3) Solve the problem by calling the $solve() method, which minimizes the objective function at each hyperparameter value. Afterwards, $choose_hyperparameters() can select among the candidate values, and $info() reports diagnostics.

Examples

Run this code
if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, 100, 10)
  m1 <- Measure(x = x)

  y <- matrix(2, 100, 10)
  m2 <- Measure(x = y, adapt = "weights")

  z <- matrix(3, 102, 10)
  m3 <- Measure(x = z)

  # setup OT problems
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)

  # OTProblem objects can be combined with arithmetic operators
  # (here scalar multiplication and addition) into a new OTProblem
  ot <- 0.5 * ot1 + 0.5 * ot2
  print(ot)

  # Then set up the candidate hyperparameters (lambda) and solver
  # arguments; setup_arguments() must be run before solve()
  ot$setup_arguments(lambda = 1000)

  # then you can solve the objective function
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
}

## ------------------------------------------------
## Method `OTProblem$add`
## ------------------------------------------------

# Combine two OTProblem objects with $add()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Add ot2 to ot1, then inspect both objects again
  ot1$add(ot2)
  print(ot1)
  print(ot2)
}

## ------------------------------------------------
## Method `OTProblem$subtract`
## ------------------------------------------------

# Combine two OTProblem objects with $subtract()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Subtract ot2 from ot1, then inspect both objects again
  ot1$subtract(ot2)
  print(ot1)
  print(ot2)
}

## ------------------------------------------------
## Method `OTProblem$multiply`
## ------------------------------------------------

# Combine two OTProblem objects with $multiply()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Multiply ot1 by ot2, then inspect both objects again
  ot1$multiply(ot2)
  print(ot1)
  print(ot2)
}

## ------------------------------------------------
## Method `OTProblem$divide`
## ------------------------------------------------

# Combine two OTProblem objects with $divide()
if (torch::torch_is_installed()) {
  # Three measures built from constant matrices
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y)
  z <- matrix(3, nrow = 102, ncol = 10)
  m3 <- Measure(x = z)

  # Two OT problems that share the measure m2
  ot1 <- OTProblem(m1, m2)
  ot2 <- OTProblem(m3, m2)
  print(ot1)
  print(ot2)

  # Divide ot1 by ot2, then inspect both objects again
  ot1$divide(ot2)
  print(ot1)
  print(ot2)
}

## ------------------------------------------------
## Method `OTProblem$setup_arguments`
## ------------------------------------------------

# Configure an OTProblem with a single penalty value
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # Build the problem and register lambda = 1000
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)
}

## ------------------------------------------------
## Method `OTProblem$solve`
## ------------------------------------------------

# Configure and solve an OTProblem (one iteration, for speed)
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # setup_arguments() must precede solve()
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = 1000)
  ot$solve(
    niter = 1,
    torch_optim = torch::optim_rmsprop
  )
}

## ------------------------------------------------
## Method `OTProblem$choose_hyperparameters`
## ------------------------------------------------

# Solve over two lambda values, then pick one by bootstrap
if (torch::torch_is_installed()) {
  # Source measure with fixed data and target measure with adaptable weights
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(
    niter = 1,
    torch_optim = torch::optim_rmsprop
  )
  # Small bootstrap counts keep the example fast
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )
}

## ------------------------------------------------
## Method `OTProblem$info`
## ------------------------------------------------

# Report diagnostics for a solved OTProblem.
# The example is self-contained: info() is only meaningful after
# solve() and choose_hyperparameters() have been run, so build and
# run the full pipeline first instead of relying on an `ot` object
# left over from another example.
if (torch::torch_is_installed()) {
  # setup measures
  x <- matrix(1, nrow = 100, ncol = 10)
  m1 <- Measure(x = x)
  y <- matrix(2, nrow = 100, ncol = 10)
  m2 <- Measure(x = y, adapt = "weights")

  # set up, solve, and select hyperparameters
  ot <- OTProblem(m1, m2)
  ot$setup_arguments(lambda = c(1, 1000))
  ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
  ot$choose_hyperparameters(
    n_boot_lambda = 2,
    n_boot_delta = 10,
    lambda_bootstrap = 100
  )

  # list with loss, iterations, balance.function.differences and
  # hyperparam.metrics
  ot$info()
}

Run the code above in your browser using DataLab