gnn (version 0.0-2)

loss: Loss Function

Description

Implementation of various loss functions for measuring the statistical discrepancy between two datasets.

Usage

loss(x, y, type = c("MSE", "binary.cross", "MMD"), ...)
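The vector default for type follows a common R idiom in which match.arg() selects the first entry when type is not supplied and accepts unique partial matches otherwise. The snippet below assumes loss() resolves type this way, which would make "MSE" the default. The helper pick_type() is hypothetical and only mirrors the type argument of the signature above.

```r
## Hypothetical stand-in (not part of gnn) mirroring the 'type' argument of
## loss(); assumes the argument is resolved with match.arg().
pick_type <- function(type = c("MSE", "binary.cross", "MMD")) match.arg(type)

pick_type()        # "MSE" (first entry is the default)
pick_type("MMD")   # "MMD"
pick_type("bin")   # "binary.cross" (unique partial match)
```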

Arguments

x

2d tensor with shape (batch size, dimension of input dataset).

y

2d tensor with shape (batch size, dimension of input dataset).

type

character string indicating the type of loss used. Currently available are the mean squared error ("MSE"), binary cross entropy ("binary.cross") and (kernel) maximum mean discrepancy ("MMD").
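For intuition, the first two criteria can be written in a few lines of base R that operate on ordinary matrices instead of tensors. This is a minimal sketch, not the package's implementation. The helper names mse() and binary_cross() are hypothetical. Both helpers average over all entries, although the package may reduce differently. For the binary cross entropy, y is assumed to hold targets in [0, 1] and x the predicted probabilities.

```r
## Hypothetical helpers (not part of gnn); x and y are numeric matrices of
## shape (batch size, dimension), mirroring the 2d tensors described above.

## Mean squared error, averaged over all entries
mse <- function(x, y) {
    stopifnot(all(dim(x) == dim(y)))
    mean((x - y)^2)
}

## Binary cross entropy, averaged over all entries; x holds predicted
## probabilities and y the targets (an assumed convention). Probabilities
## are clipped away from 0 and 1 to avoid log(0).
binary_cross <- function(x, y, eps = 1e-7) {
    stopifnot(all(dim(x) == dim(y)))
    p <- pmin(pmax(x, eps), 1 - eps)
    -mean(y * log(p) + (1 - y) * log(1 - p))
}

x <- matrix(c(0.9, 0.2, 0.4, 0.7), nrow = 2)
y <- matrix(c(1, 0, 0, 1), nrow = 2)
mse(x, y)          # 0.075
binary_cross(x, y)
```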

...

additional arguments passed to the underlying loss function; at the moment, this only affects type = "MMD", for which "bandwidth" can be provided.
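The "MMD" criterion compares the two samples through a kernel, and the bandwidth sets the kernel's scale. The following base R sketch computes the biased (V-statistic) estimate of the squared MMD. It uses a single Gaussian kernel k(u, v) = exp(-||u - v||^2 / (2 h^2)), where h is the bandwidth. The kernel choice, the helper name mmd() and the default bandwidth = 1 are assumptions made for illustration. The package may use a different kernel, estimator or default.

```r
## Hypothetical helper (not part of gnn): biased estimate of the squared
## maximum mean discrepancy between the rows of x (n x d) and y (m x d),
##   mean k(x_i, x_j) + mean k(y_i, y_j) - 2 * mean k(x_i, y_j),
## using a Gaussian kernel with bandwidth h.
mmd <- function(x, y, bandwidth = 1) {
    stopifnot(ncol(x) == ncol(y), bandwidth > 0)
    ## Matrix of kernel values k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 h^2))
    gauss_kernel <- function(a, b) {
        sq_dist <- outer(rowSums(a^2), rowSums(b^2), "+") - 2 * tcrossprod(a, b)
        exp(-pmax(sq_dist, 0) / (2 * bandwidth^2)) # pmax() guards against rounding below 0
    }
    mean(gauss_kernel(x, x)) + mean(gauss_kernel(y, y)) - 2 * mean(gauss_kernel(x, y))
}

set.seed(271)
x <- matrix(rnorm(200), ncol = 2)           # 100 draws from N(0, I)
y <- matrix(rnorm(200), ncol = 2)           # 100 further draws from N(0, I)
z <- matrix(rnorm(200, mean = 2), ncol = 2) # 100 draws from a shifted distribution
mmd(x, y, bandwidth = 1) # small: same distribution
mmd(x, z, bandwidth = 1) # noticeably larger: shifted distribution
```

Smaller bandwidths make the kernel sensitive to fine, local differences between the samples. Larger ones emphasize coarse differences such as shifts in location.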

Value

loss() returns a 0d tensor containing the loss.

References

Kingma, D. P. and Welling, M. (2014). Stochastic gradient VB and the variational auto-encoder. Second International Conference on Learning Representations (ICLR). See https://keras.rstudio.com/articles/examples/variational_autoencoder.html

See Also

GMMN_model() and VAE_model(), where loss() is used.