
causalOT (version 1.0.2)

ot_distance: Optimal Transport Distance

Description

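Computes the entropically regularized (Sinkhorn) optimal transport distance between two samples, optionally debiased. For a causalWeights object, the distance between the two treatment groups is computed before and after weighting.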

Usage

ot_distance(
  x1,
  x2 = NULL,
  a = NULL,
  b = NULL,
  penalty,
  p = 2,
  cost = NULL,
  debias = TRUE,
  online.cost = "auto",
  diameter = NULL,
  niter = 1000,
  tol = 1e-07
)

# S3 method for causalWeights
ot_distance(x1, x2 = NULL, a = NULL, b = NULL, penalty, p = 2, cost = NULL, debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

# S3 method for matrix
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL, debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

# S3 method for array
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL, debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

# S3 method for torch_tensor
ot_distance(x1, x2, a = NULL, b = NULL, penalty, p = 2, cost = NULL, debias = TRUE, online.cost = "auto", diameter = NULL, niter = 1000, tol = 1e-07)

Value

For objects of class matrix, a numeric value giving the optimal transport distance. For objects of class causalWeights, a list with elements pre and post giving the distance before and after adjustment, respectively.

Arguments

x1

Either an object of class causalWeights or a matrix (or array or torch_tensor) of the covariates in the first sample.

x2

A matrix of the covariates in the second sample. May be NULL when x1 is a causalWeights object, in which case both samples are taken from that object.

a

Empirical measure (observation weights) of the first sample. If NULL, each observation receives equal mass. Ignored for objects of class causalWeights.

b

Empirical measure (observation weights) of the second sample. If NULL, each observation receives equal mass. Ignored for objects of class causalWeights.

penalty

The penalty of the optimal transport distance. If missing or NULL, the function tries to guess a suitable value depending on whether debias is TRUE or FALSE.

p

The power p of the \(L_p\) distance used in the cost.

cost

A user-supplied cost function. It should take the arguments x1, x2, and p.

debias

TRUE or FALSE. Whether to use the debiased optimal transport distance.

online.cost

How to calculate the distance matrix. One of "auto", "tensorized", or "online".

diameter

The diameter of the metric space, if known. Default is NULL.

niter

The maximum number of iterations for the Sinkhorn updates.

tol

The convergence tolerance for the Sinkhorn updates.
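The arguments penalty, p, cost, debias, niter and tol together describe an entropic (Sinkhorn) optimal transport computation. The base-R sketch below illustrates that computation under stated assumptions; it is not causalOT's implementation, which runs on torch and also handles diameter and online.cost (not modeled here). The names lp_cost, sinkhorn_ot and sinkhorn_divergence are hypothetical. We assume that eps plays the role of penalty (the package may scale it differently), that the cost is \(\sum_k |x_k - y_k|^p\), and that debiasing means the Sinkhorn divergence OT(a, b) - 0.5 OT(a, a) - 0.5 OT(b, b).

```r
# Hypothetical cost with the (x1, x2, p) signature described for `cost`.
# Assumption: c(x, y) = sum_k |x_k - y_k|^p; the package default may differ.
lp_cost <- function(x1, x2, p) {
  x1 <- as.matrix(x1)
  x2 <- as.matrix(x2)
  cmat <- matrix(0, nrow(x1), nrow(x2))
  for (k in seq_len(ncol(x1))) {
    cmat <- cmat + abs(outer(x1[, k], x2[, k], "-"))^p
  }
  cmat
}

# Numerically stable log(sum(exp(v))).
logsumexp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# Log-domain Sinkhorn iterations for entropic OT between weights a and b.
# Returns the dual value <a, f> + <b, g> (the regularized OT cost).
sinkhorn_ot <- function(C, a, b, eps, niter = 1000, tol = 1e-7) {
  f <- rep(0, length(a))
  g <- rep(0, length(b))
  log_a <- log(a)
  log_b <- log(b)
  for (it in seq_len(niter)) {
    f_old <- f
    # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
    f <- -eps * apply(C, 1, function(ci) logsumexp(log_b + (g - ci) / eps))
    # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
    g <- -eps * apply(C, 2, function(cj) logsumexp(log_a + (f - cj) / eps))
    # stop once the potentials no longer change by more than tol
    if (max(abs(f - f_old)) < tol) break
  }
  sum(a * f) + sum(b * g)
}

# Hypothetical wrapper: NULL weights mean equal mass per observation;
# debias = TRUE subtracts the two self-transport terms.
sinkhorn_divergence <- function(x1, x2, a = NULL, b = NULL, eps, p = 2,
                                debias = TRUE, niter = 1000, tol = 1e-7) {
  if (is.null(a)) a <- rep(1 / nrow(x1), nrow(x1))
  if (is.null(b)) b <- rep(1 / nrow(x2), nrow(x2))
  ot_xy <- sinkhorn_ot(lp_cost(x1, x2, p), a, b, eps, niter, tol)
  if (!debias) return(ot_xy)
  ot_xx <- sinkhorn_ot(lp_cost(x1, x1, p), a, a, eps, niter, tol)
  ot_yy <- sinkhorn_ot(lp_cost(x2, x2, p), b, b, eps, niter, tol)
  ot_xy - 0.5 * ot_xx - 0.5 * ot_yy
}

set.seed(1)
x1 <- matrix(rnorm(20 * 3), 20, 3)
x2 <- matrix(rnorm(15 * 3, mean = 1), 15, 3)
sinkhorn_divergence(x1, x2, eps = 1)
```

This also shows why debiasing matters. Without it, the regularized cost of a sample against itself is strictly positive. With it, that value is exactly zero.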

Methods (by class)

  • ot_distance(causalWeights): method for causalWeights class

  • ot_distance(matrix): method for matrices

  • ot_distance(array): method for arrays

  • ot_distance(torch_tensor): method for torch_tensors
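ot_distance is an S3 generic, so R picks the method from the class of x1. The sketch below shows this dispatch rule using a hypothetical generic, toy_distance, and a hypothetical class, toyWeights, rather than causalOT's own code. A matrix is dispatched to the matrix method before the array method. An array with more than two dimensions uses the array method. An object carrying an explicit class attribute uses that class's method.

```r
# Hypothetical generic illustrating S3 dispatch on the class of x1.
toy_distance <- function(x1, ...) UseMethod("toy_distance")

toy_distance.toyWeights <- function(x1, ...) "toyWeights method"
toy_distance.matrix     <- function(x1, ...) "matrix method"
toy_distance.array      <- function(x1, ...) "array method"
toy_distance.default    <- function(x1, ...) "default method"

toy_distance(structure(list(), class = "toyWeights"))  # "toyWeights method"
toy_distance(matrix(1:4, 2))                           # "matrix method"
toy_distance(array(1:8, c(2, 2, 2)))                   # "array method"
toy_distance(1:3)                                      # "default method"
```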

Examples

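library(causalOT)

if (torch::torch_is_installed()) {
  # simulate 10 observations with 5 covariates and a random binary treatment
  x <- matrix(stats::rnorm(10 * 5), 10, 5)
  z <- stats::rbinom(10, 1, 0.5)

  # estimate ATT weights by logistic regression
  weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")

  # distance from the causalWeights object (list with pre and post)
  ot1 <- ot_distance(x1 = weights, penalty = 100,
                     p = 2, debias = TRUE, online.cost = "auto",
                     diameter = NULL)

  # same post-adjustment distance from the covariate matrices,
  # passing the normalized weights as the empirical measures a and b
  ot2 <- ot_distance(x1 = x[z == 0, ], x2 = x[z == 1, ],
                     a = weights@w0 / sum(weights@w0), b = weights@w1,
                     penalty = 100, p = 2, debias = TRUE,
                     online.cost = "auto", diameter = NULL)

  # the two computations should agree
  all.equal(ot1$post, ot2)
}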
