s2net (version 1.0.7)

predict.s2netR: S3 Predict Method for s2netR Objects

Description

S3 method for the generic predict, for objects of class s2netR. It is a wrapper for the C++ class method s2net$predict.
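The sketch below shows, in base R only, how an S3 predict method can simply forward to a prediction function stored inside the fitted object, which is the wrapper pattern described above. The class "toyNet", the constructor new_toyNet and the field predict_fun are hypothetical names used for illustration; they are not part of s2net, and the stored function is a plain R linear predictor standing in for the compiled C++ method.

```r
# Hypothetical sketch: an S3 predict method that wraps a function stored in
# the fitted object. "toyNet", new_toyNet and predict_fun are illustrative
# names only; they are NOT part of the s2net package.
new_toyNet <- function(beta, intercept) {
  obj <- list(
    beta = beta,
    intercept = intercept,
    # Stand-in for a compiled method: linear predictor X %*% beta + intercept
    predict_fun = function(newX) newX %*% beta + intercept
  )
  class(obj) <- "toyNet"
  obj
}

# predict(object, ...) dispatches here for objects of class "toyNet"
predict.toyNet <- function(object, newX, ...) {
  object$predict_fun(as.matrix(newX))
}

model <- new_toyNet(beta = c(2, -1), intercept = 0.5)
newX <- matrix(c(1, 0,
                 0, 1,
                 1, 1), ncol = 2, byrow = TRUE)
pred <- predict(model, newX)  # a 3 x 1 column matrix
print(pred)
```

Because dispatch is based only on class(object), the same call predict(model, newX) works for any class that defines a predict.<class> method.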

Usage

# S3 method for s2netR
predict(object, newX, type = "default", ...)

Value

A column matrix with the predictions, one row per row of newX.

Arguments

object

An s2netR object, i.e. a model fitted with s2netR.

newX

A matrix with the data on which to make predictions. It must be on the same scale as the training data; see s2Data for how to format and preprocess it.
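As a rough illustration of the "same scale" requirement, the base R sketch below standardizes a training matrix and then applies the training centers and scales, not statistics recomputed from the new data, to a new matrix. It assumes a simple center-and-scale transform for illustration; it is not the code s2Data runs internally, and in the example further down that step is handled by s2Data(..., preprocess = train).

```r
# Hypothetical sketch, base R only: reuse the TRAINING statistics when
# transforming new data. This assumes a center-and-scale preprocessing and
# is not s2Data's actual implementation.
set.seed(1)
xTrain <- matrix(rnorm(20, mean = 10, sd = 3), ncol = 2)
xNew   <- matrix(rnorm(6,  mean = 10, sd = 3), ncol = 2)

# scale() stores the statistics it used as attributes on the result
xTrainScaled <- scale(xTrain)
centers <- attr(xTrainScaled, "scaled:center")
scales  <- attr(xTrainScaled, "scaled:scale")

# Apply the training centers/scales to the new data (do not rescale xNew
# with its own mean and standard deviation)
xNewScaled <- scale(xNew, center = centers, scale = scales)
print(xNewScaled)
```

Rescaling xNew with its own statistics would silently shift its values relative to what the fitted coefficients expect, which is why the new data must go through the training transformation.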

type

Type of predictions. One of "default" (inferred from the training data), "response", "probs" or "class".
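For a binary classification model, the prediction types are often related as sketched below: the linear predictor, its logistic transform as a probability, and a class obtained by thresholding that probability at 0.5. This is a generic logistic-loss convention shown for illustration only. The helper to_prediction_type, the 0/1 class coding and the 0.5 threshold are assumptions, not s2net's documented behaviour, and how "default" is resolved is not modelled here.

```r
# Hypothetical sketch of a common logistic-loss convention; the helper name,
# 0/1 class coding and 0.5 threshold are assumptions, not s2net's code.
to_prediction_type <- function(eta, type = c("response", "probs", "class")) {
  type <- match.arg(type)
  switch(type,
    response = eta,                      # linear predictor, unchanged
    probs    = plogis(eta),              # 1 / (1 + exp(-eta))
    class    = (plogis(eta) > 0.5) * 1L  # threshold probabilities at 0.5
  )
}

eta <- matrix(c(-2, 0.3, 1.5), ncol = 1)
print(to_prediction_type(eta, "probs"))
print(to_prediction_type(eta, "class"))
```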

...

Other parameters passed to predict.

See Also

s2netR, s2net

Examples

library(s2net)
data("auto_mpg")

# Labeled (xL, yL) and unlabeled (xU) training data
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)

model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))

# Preprocess the validation data with the same transformation as train
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)

# Not run: plot predictions against the observed responses
if (FALSE) {
  if (require(ggplot2)) {
    ggplot() +
      aes(x = ypred, y = valid$yL) + geom_point() +
      geom_abline(intercept = 0, slope = 1, linetype = 2)
  }
}
