s2net (version 1.0)

s2netR: Trains a generalized extended linear joint trained model using semi-supervised data.

Description

This function is a wrapper for the C++ class s2net: it creates the C++ object and fits the model to the input data.

Usage

s2netR(data, params, loss = "default", frame = "ExtJT", proj = "auto", 
        fista = NULL, S3 = TRUE)

Arguments

data

An s2Data object with the (training) data.

params

An s2Params object with the model hyperparameters.

loss

Loss function: one of "default" (inferred from the data), "linear" or "logit".

frame

The semi-supervised framework: "ExtJT" (the extended linear joint trained model) or "JT" (the linear joint trained model of Ryan and Culp, 2015).

proj

Should the unlabeled data be shifted to remove the model's effect? One of "no", "yes" or "auto" (with "auto", the unlabeled data are shifted only if the angle between beta and the center of the data is significant).

fista

FISTA (fast iterative shrinkage-thresholding algorithm) setup parameters: an object of class s2Fista.

S3

Logical: should the function return an S3 object (default) or a C++ object?
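To make the quantity behind proj = "auto" concrete, the angle θ between the coefficient vector β and the center of the data can be written as follows. This is a sketch under an assumption: it takes "the center of the data" to mean the column mean x̄_U of the (preprocessed) unlabeled covariate matrix xU with n_U rows. The threshold s2net uses to decide when the angle is significant is not given on this page.

```latex
\cos\theta = \frac{\beta^\top \bar{x}_U}{\lVert \beta \rVert_2 \, \lVert \bar{x}_U \rVert_2},
\qquad
\bar{x}_U = \frac{1}{n_U} \sum_{i=1}^{n_U} x_{U,i}
```

The numerator β^T x̄_U is the contribution of the unlabeled center to the linear predictor. It vanishes when θ = 90°, in which case shifting the unlabeled data along x̄_U would not change the fitted values.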

Value

Returns an object of S3 class s2netR (if S3 = TRUE) or a C++ object of class s2net (if S3 = FALSE).

References

Ryan, K. J., & Culp, M. V. (2015). On semi-supervised linear regression in covariate shift problems. The Journal of Machine Learning Research, 16(1), 3183-3217.

See Also

s2net

Examples

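# NOT RUN {
# Load the example data and build a training set from the labeled (xL, yL)
# and unlabeled (xU) partitions
data("auto_mpg")
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)

# Fit the extended linear joint trained model with a linear loss
model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))

# Validation set: the unlabeled covariates with their held-out responses,
# preprocessed in the same way as the training data
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)

# Plot predicted vs. observed responses; the dashed line marks perfect predictions
if (require(ggplot2)) {
  ggplot() +
    aes(x = ypred, y = valid$yL) + geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}
# }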
