ruta (version 1.0.2)

train.ruta_autoencoder: Train a learner object with data

Description

This function compiles the neural network described by the learner object and trains it with the input data.

Usage

# S3 method for ruta_autoencoder
train(learner, data, validation_data = NULL,
  metrics = NULL, epochs = 20, optimizer = keras::optimizer_rmsprop(),
  ...)

train(learner, ...)

Arguments

learner

A "ruta_autoencoder" object

data

Training data: columns are attributes and rows are instances

validation_data

Additional numeric data matrix that is not used for training; the loss and any metrics are also computed against it

metrics

Optional list of metrics which will evaluate the model but won't be optimized. See keras::compile

epochs

The number of complete passes over the training data

optimizer

The optimizer used to train the model. Can be any optimizer object defined by Keras (e.g. keras::optimizer_adam())

...

Additional parameters for keras::fit. Some useful parameters:

  • batch_size: The number of examples grouped for each gradient update. Use a smaller batch size for more frequent weight updates or a larger one for faster optimization.

  • shuffle: Whether to shuffle the training data before each epoch. Defaults to TRUE.
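
To make the interaction between batch_size and epochs concrete, the following base-R sketch counts gradient updates for the MNIST training set (60,000 rows). It assumes the usual Keras behaviour of keeping a final, smaller batch when the number of rows is not a multiple of batch_size. The helper name updates_per_epoch is hypothetical and is not part of ruta or keras.

```r
# Hypothetical helper (not part of ruta or keras): number of gradient
# updates in one epoch, assuming the last partial batch is kept.
updates_per_epoch <- function(n_rows, batch_size) {
  ceiling(n_rows / batch_size)
}

n_rows <- 60000  # rows in the MNIST training set
epochs <- 1

small <- updates_per_epoch(n_rows, batch_size = 64)
large <- updates_per_epoch(n_rows, batch_size = 512)

cat("batch_size = 64: ", small * epochs, "updates\n")
cat("batch_size = 512:", large * epochs, "updates\n")
```

Under these assumptions a batch size of 64 yields 938 updates per epoch and a batch size of 512 yields 118, which illustrates the trade-off between update frequency and the amount of work per pass.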

Value

Same autoencoder passed as parameter, with trained internal models
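
Since the trained autoencoder is the return value, it is safest to assign the result of train(). The toy below sketches why, under the assumption (not confirmed by this page) that a learner is an ordinary list-based S3 object, so changes made inside a method are only visible in the returned copy. The class toy_learner, its fields, and the generic defined here are hypothetical stand-ins and do not reflect ruta's internals. Note that defining train here would mask ruta's generic if ruta were loaded.

```r
# Toy generic and method; "toy_learner" is a hypothetical class used only
# to illustrate S3 dispatch and return-by-value semantics.
train <- function(learner, ...) UseMethod("train")

train.toy_learner <- function(learner, data, epochs = 20, ...) {
  # Pretend "training": record what was seen instead of fitting a network.
  learner$trained <- TRUE
  learner$epochs  <- epochs
  learner$n_rows  <- nrow(data)
  invisible(learner)
}

learner <- structure(list(trained = FALSE), class = "toy_learner")
x <- as.matrix(iris[, 1:4])

train(learner, x, epochs = 5)             # result discarded
untrained_flag <- learner$trained         # original object is unchanged

learner <- train(learner, x, epochs = 5)  # keep the returned object
trained_flag <- learner$trained

print(c(before = untrained_flag, after = trained_flag))
```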

See Also

autoencoder

Examples

library(ruta)

# Minimal example ================================================
iris_model <- train(autoencoder(2), as.matrix(iris[, 1:4]))

# Simple example with MNIST ======================================
library(keras)

# Load and normalize MNIST
mnist <- dataset_mnist()
x_train <- array_reshape(
  mnist$train$x, c(dim(mnist$train$x)[1], 784)
)
x_train <- x_train / 255.0
x_test <- array_reshape(
  mnist$test$x, c(dim(mnist$test$x)[1], 784)
)
x_test <- x_test / 255.0

# Autoencoder with layers: 784-256-36-256-784
learner <- autoencoder(c(256, 36), "binary_crossentropy")

# Keep the returned learner, which holds the trained models
learner <- train(
  learner,
  x_train,
  epochs = 1,
  optimizer = "rmsprop",
  batch_size = 64,
  validation_data = x_test,
  metrics = list("binary_accuracy")
)
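
When no separate test set is available, a validation matrix can be carved out of the training data before calling train(). The following base-R sketch prepares the iris measurements in the layout the data and validation_data arguments expect (rows are instances, columns are numeric attributes). The 20% hold-out ratio, the seed, and the min-max scaling are illustrative choices rather than anything ruta prescribes.

```r
set.seed(42)  # reproducible split

# Rows are instances, columns are numeric attributes.
x <- as.matrix(iris[, 1:4])

# Hold out 20% of the rows for validation (arbitrary ratio).
val_idx <- sample(nrow(x), size = round(0.2 * nrow(x)))
x_val   <- x[val_idx, , drop = FALSE]
x_train <- x[-val_idx, , drop = FALSE]

# Min-max scale both matrices using training-set statistics only.
mins    <- apply(x_train, 2, min)
ranges  <- apply(x_train, 2, max) - mins
x_train <- sweep(sweep(x_train, 2, mins), 2, ranges, "/")
x_val   <- sweep(sweep(x_val, 2, mins), 2, ranges, "/")

# With ruta and a working Keras backend installed, the matrices could
# then be passed on, for example:
# model <- ruta::train(ruta::autoencoder(2), x_train, validation_data = x_val)
```

Scaling with statistics from the training rows alone keeps the validation loss an honest estimate, because no information from the held-out rows influences preprocessing.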
