rsparse (version 0.3.3.2)

FactorizationMachine: Creates a FactorizationMachine model.

Description

Creates a second-order Factorization Machines model.
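For reference, the textbook second-order FM model scores a feature vector x ∈ R^n as shown below. Here w_0 is the intercept, w_i are the linear weights and v_i ∈ R^k is the latent factor of feature i, with k corresponding to the rank argument. The second line is the standard identity that lets the pairwise term be computed in O(kn) time rather than O(kn^2); whether rsparse evaluates exactly this form internally is an assumption. For family = "binomial" the score is then passed through the logistic link.

```latex
\hat{y}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j

\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
  = \frac{1}{2} \sum_{f=1}^{k} \left[ \Big( \sum_{i=1}^{n} v_{i,f} \, x_i \Big)^2 - \sum_{i=1}^{n} v_{i,f}^2 \, x_i^2 \right]
```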

Usage

FactorizationMachine

Format

R6Class object.

Usage

For usage details, see the Methods, Arguments and Examples sections.

fm = FactorizationMachine$new(learning_rate_w = 0.2, rank = 4, lambda_w = 0, lambda_v = 0,
  family = c("binomial", "gaussian"), intercept = TRUE, learning_rate_v = learning_rate_w)
fm$partial_fit(x, y, ...)
fm$predict(x, ...)

Methods

FactorizationMachine$new(learning_rate_w = 0.2, rank = 4, lambda_w = 0, lambda_v = 0, family = c("binomial", "gaussian"), intercept = TRUE, learning_rate_v = learning_rate_w)

Constructor for the FactorizationMachine model. For a description of the arguments, see the Arguments section.

$partial_fit(x, y, ...)

Fits or updates the model given input matrix x and target vector y. x has shape (n_samples, n_features).

$predict(x, ...)

Predicts the output for input matrix x.
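As a rough illustration of what prediction involves, the sketch below computes second-order FM scores in base R using the O(n·k) pairwise identity and applies the logistic link when family is "binomial". The function name fm_predict_sketch and its parameter layout (w0, a weight vector w, and a matrix V with one row per feature and rank columns) are hypothetical; rsparse's internal representation and exact outputs may differ.

```r
# Hypothetical helper illustrating a second-order FM forward pass.
# It is NOT part of the rsparse API; x is densified for clarity.
fm_predict_sketch <- function(x, w0, w, V, family = c("binomial", "gaussian")) {
  family <- match.arg(family)
  x <- as.matrix(x)                       # n_samples x n_features
  linear <- w0 + drop(x %*% w)            # intercept + linear terms
  xv <- x %*% V                           # n_samples x rank: sum_i v_{i,f} x_i
  x2v2 <- (x^2) %*% (V^2)                 # n_samples x rank: sum_i v_{i,f}^2 x_i^2
  pairwise <- 0.5 * rowSums(xv^2 - x2v2)  # pairwise interactions in O(n * k)
  score <- linear + pairwise
  # logistic link for classification, identity for regression
  if (family == "binomial") 1 / (1 + exp(-score)) else score
}
```

The identity avoids looping over all feature pairs, which is what keeps FMs practical on wide sparse inputs.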

Arguments

fm

A FactorizationMachine object.

x

Input sparse matrix; the native format is Matrix::RsparseMatrix. If x is in a different format, the model will try to convert it with an as(x, "RsparseMatrix") call.

learning_rate_w

Learning rate for the linear weights in AdaGrad SGD.

learning_rate_v

Learning rate for the interaction (latent factor) weights in AdaGrad SGD. Defaults to learning_rate_w.

rank

Rank (dimension) of the latent factors used to model pairwise interactions.

lambda_w

Regularization parameter for the linear terms.

lambda_v

Regularization parameter for the interaction terms.

intercept

Logical flag specifying whether the model is allowed to have a non-zero intercept (bias).

family

A description of the error distribution and link function to be used in the model: either "gaussian" (for regression) or "binomial" (for classification).
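The arguments above map onto a fairly standard AdaGrad SGD update for an FM. The following is a minimal sketch under stated assumptions, not rsparse's implementation: it processes one dense sample, uses logistic loss for "binomial" (y in {0, 1}) and squared loss for "gaussian", applies lambda_w and lambda_v as L2 penalties, leaves the intercept unregularized and updates it with learning_rate_w. The helpers fm_init_state and fm_adagrad_step are hypothetical.

```r
# Hypothetical state: parameters plus AdaGrad squared-gradient accumulators.
# V starts small and random: with V all zero the interaction gradients are
# zero and V would never move.
fm_init_state <- function(n_features, rank = 4) {
  list(w0 = 0, w = numeric(n_features),
       V = matrix(rnorm(n_features * rank, sd = 0.01), n_features, rank),
       G_w0 = 0, G_w = numeric(n_features),
       G_V = matrix(0, n_features, rank))
}

# Hypothetical single AdaGrad SGD step on one dense sample x with target y.
fm_adagrad_step <- function(state, x, y, learning_rate_w = 0.2,
                            learning_rate_v = learning_rate_w,
                            lambda_w = 0, lambda_v = 0,
                            family = c("binomial", "gaussian"),
                            intercept = TRUE) {
  family <- match.arg(family)
  xv <- drop(x %*% state$V)                   # per-factor sums: sum_j v_{j,f} x_j
  score <- state$w0 + sum(state$w * x) +
    0.5 * sum(xv^2 - drop((x^2) %*% (state$V^2)))
  # d(loss)/d(score): logistic loss (binomial) or squared loss (gaussian)
  g <- if (family == "binomial") 1 / (1 + exp(-score)) - y else score - y

  # gradients w.r.t. parameters, with L2 penalties (assumed placement)
  grad_w <- g * x + lambda_w * state$w
  grad_V <- g * (outer(x, xv) - (x^2) * state$V) + lambda_v * state$V

  eps <- 1e-8  # avoids division by zero for coordinates with no gradient yet
  if (intercept) {
    state$G_w0 <- state$G_w0 + g^2
    state$w0 <- state$w0 - learning_rate_w * g / sqrt(state$G_w0 + eps)
  }
  state$G_w <- state$G_w + grad_w^2
  state$w <- state$w - learning_rate_w * grad_w / sqrt(state$G_w + eps)
  state$G_V <- state$G_V + grad_V^2
  state$V <- state$V - learning_rate_v * grad_V / sqrt(state$G_V + eps)
  state
}
```

AdaGrad divides each step by the root of that coordinate's accumulated squared gradients, so rarely active sparse features keep larger effective learning rates than frequent ones.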

Examples

library(rsparse)
library(Matrix)

# Factorization Machines can fit the XOR function!
x = rbind(
  c(0, 0),
  c(0, 1),
  c(1, 0),
  c(1, 1)
)
y = c(0, 1, 1, 0)

x = as(x, "RsparseMatrix")
fm = FactorizationMachine$new(learning_rate_w = 10, rank = 2, lambda_w = 0,
  lambda_v = 0, family = 'binomial', intercept = TRUE)
res = fm$fit(x, y, n_iter = 100)
preds = fm$predict(x)
all(preds[c(1, 4)] < 0.01)
all(preds[c(2, 3)] > 0.99)
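A linear model cannot separate XOR, but the FM pairwise term x1 * x2 supplies exactly the missing interaction. The sketch below uses base R and hand-picked parameters (hypothetical values, not what fm$fit learns) to check that a rank-1 FM with the logistic link reproduces the XOR targets.

```r
# Hand-picked FM parameters: intercept, linear weights, rank-1 latent factors
w0 <- -5
w <- c(10, 10)
V <- matrix(c(sqrt(20), -sqrt(20)), nrow = 2)  # <v_1, v_2> = -20
x <- rbind(c(0, 0), c(0, 1), c(1, 0), c(1, 1))

interaction <- sum(V[1, ] * V[2, ]) * x[, 1] * x[, 2]
score <- w0 + drop(x %*% w) + interaction  # c(-5, 5, 5, -5)
preds <- 1 / (1 + exp(-score))             # logistic link, as for "binomial"
all(preds[c(1, 4)] < 0.01)  # TRUE
all(preds[c(2, 3)] > 0.99)  # TRUE
```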
