Provides a function that performs a single GAN update step, i.e. one discriminator update and one generator update on a mini batch, so that it can be used during GAN training.
gan_update_step(
data,
batch_size,
noise_dim,
sample_noise,
device = "cpu",
g_net,
d_net,
g_optim,
d_optim,
value_function,
weight_clipper
)
data: The data set to train on. Must be a matrix, array, torch::torch_tensor or torch::dataset.
batch_size: The number of training samples selected into the mini batch for training. Defaults to 50.
noise_dim: The dimensions of the GAN noise vector z. Defaults to 2.
sample_noise: A function to sample noise into a torch::torch_tensor.
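To illustrate the shape of such a sampler, here is a minimal Python analogue; in the R API this role would typically be filled by a function wrapping torch::torch_randn, and the names below are illustrative, not taken from the package:

```python
import random

def sample_noise(batch_size, noise_dim):
    # Draw standard-normal noise of shape (batch_size, noise_dim); in R this
    # would typically be torch::torch_randn(batch_size, noise_dim).
    return [[random.gauss(0.0, 1.0) for _ in range(noise_dim)]
            for _ in range(batch_size)]

z = sample_noise(50, 2)  # matches the defaults batch_size = 50, noise_dim = 2
```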
device: The device (e.g. "cpu" or "cuda") on which training should be done. Defaults to "cpu".
g_net: The generator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which will create a simple fully connected neural network.
d_net: The discriminator network. Expects a neural network provided as a torch::nn_module. Default is NULL, which will create a simple fully connected neural network.
g_optim: The optimizer for the generator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which will set up torch::optim_adam(g_net$parameters, lr = base_lr).
d_optim: The optimizer for the discriminator network. Expects a torch::optim_xxx function, e.g. torch::optim_adam(). Default is NULL, which will set up torch::optim_adam(d_net$parameters, lr = base_lr * ttur_factor).
value_function: The value function for GAN training. Expects a function that takes discriminator scores of real and fake data as input and returns a list with the discriminator loss and generator loss. For convenience, three loss functions "original", "wasserstein" and "f-wgan" are already implemented. Defaults to "original".
weight_clipper: The Wasserstein GAN puts constraints on the weights of the discriminator, therefore weights are clipped during training. Expects a function that performs the clipping.
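To make the value_function and weight_clipper contracts concrete, here is a hedged Python sketch of the "original" GAN loss computed from discriminator scores, plus a WGAN-style clipper. The function names and the exact formulation are illustrative assumptions, not taken from the package, whose versions operate on torch tensors in R:

```python
import math

def original_value_function(real_scores, fake_scores):
    # "Original" GAN loss from raw discriminator probabilities: the
    # discriminator maximizes log D(x) + log(1 - D(G(z))); the generator
    # uses the non-saturating variant -log D(G(z)).
    d_loss = (-sum(math.log(s) for s in real_scores) / len(real_scores)
              - sum(math.log(1.0 - s) for s in fake_scores) / len(fake_scores))
    g_loss = -sum(math.log(s) for s in fake_scores) / len(fake_scores)
    # Returned as a dict, mirroring the R list(d_loss, g_loss) contract.
    return {"d_loss": d_loss, "g_loss": g_loss}

def weight_clipper(weight, clip_value=0.01):
    # WGAN-style clipping of a single weight to [-clip_value, clip_value].
    return max(-clip_value, min(clip_value, weight))

losses = original_value_function([0.9, 0.8], [0.2, 0.1])
```

A custom value function following this contract can be passed to value_function in place of the built-in "original", "wasserstein" or "f-wgan" losses.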