s2net (version 1.0)

predict.s2netR: S3 Methods for s2netR objects.

Description

Generic predict method. Wrapper for the C++ class method s2net$predict.

Usage

# S3 method for s2netR
predict(object, newX, type = "default", ...)

Arguments

object

An s2netR object, as returned by s2netR.

newX

A matrix with the data for which to make predictions. It should be on the same scale as the original (training) data. See s2Data for how to format the data.

type

Type of predictions. One of "default" (inferred from the training data), "response", "probs", or "class".

...

Value

A column matrix with predictions.

See Also

s2netR, s2net

Examples

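library(s2net)

# Load the example data and build the training set from the labeled
# (xL, yL) and unlabeled (xU) parts of the first partition.
data("auto_mpg")
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)

# Fit the model with a linear loss.
model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))

# Build the validation set, using the training set for preprocessing.
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)

# Plot predicted against observed values; the dashed line is y = x.
if (require(ggplot2)) {
  ggplot() +
    aes(x = ypred, y = valid$yL) + geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}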
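
The type argument can also be set explicitly. The following is a minimal sketch, not part of the original examples: it repeats the setup above so that it runs on its own, then compares the "default" prediction with an explicit "response" prediction. The variable names y_default and y_response are illustrative. That "response" is a sensible choice for a model fitted with loss = "linear" is an assumption here, so no expected values are shown.

```r
library(s2net)

# Same setup as in the example above.
data("auto_mpg")
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)
model = s2netR(train,
               s2Params(lambda1 = 0.1, lambda2 = 0,
                        gamma1 = 0.1, gamma2 = 100, gamma3 = 0.1),
               loss = "linear", frame = "ExtJT", proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)

# Let predict() infer the prediction type from the training data.
y_default = predict(model, valid$xL)

# Request the response explicitly (assumed appropriate for a linear loss).
y_response = predict(model, valid$xL, type = "response")

# Both results are documented to be column matrices; inspect their shapes.
dim(y_default)
dim(y_response)
```

The remaining types, "probs" and "class", are presumably intended for classification models; this page does not spell that out, so check s2netR before relying on them.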
