s2netR
Trains a generalized extended linear joint trained model using semi-supervised data.
This function is a wrapper for the class s2net. It creates the C++ object and fits the model using the input data.
Usage
s2netR(data, params, loss = "default", frame = "ExtJT", proj = "auto",
       fista = NULL, S3 = TRUE)
Arguments
- data
An s2Data object with the (training) data.
- params
An s2Params object with the model hyper-parameters.
- loss
Loss function. One of "default" (inferred from the data), "linear" or "logit".
- frame
The semi-supervised frame: "ExtJT" (the extended linear joint trained model, the default) or "JT" (the linear joint trained model of Ryan and Culp, 2015).
- proj
Should the unlabeled data be shifted to remove the model's effect? One of "no", "yes" or "auto" (the default). With "auto", the unlabeled data are shifted only if the angle between beta and the center of the data is large.
- fista
FISTA setup parameters: an object of class s2Fista.
- S3
Logical: should the method return an S3 object (TRUE, the default) or a C++ object (FALSE)?
Value
Returns an object of S3 class s2netR (if S3 = TRUE) or a C++ object of class s2net (if S3 = FALSE).
References
Ryan, K. J., & Culp, M. V. (2015). On semi-supervised linear regression in covariate shift problems. The Journal of Machine Learning Research, 16(1), 3183-3217.
Examples
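data("auto_mpg")
# Training data: labeled (xL, yL) and unlabeled (xU) observations
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)
# Fit the extended linear joint trained model with a linear loss
model = s2netR(train,
               s2Params(lambda1 = 0.1,
                        lambda2 = 0,
                        gamma1 = 0.1,
                        gamma2 = 100,
                        gamma3 = 0.1),
               loss = "linear",
               frame = "ExtJT",
               proj = "auto",
               fista = s2Fista(5000, 1e-7, 1, 0.8))
# Validation data: the unlabeled points with their observed responses,
# preprocessed in the same way as the training data
valid = s2Data(auto_mpg$P1$xU, auto_mpg$P1$yU, preprocess = train)
ypred = predict(model, valid$xL)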
# Predicted vs. observed responses; the dashed line marks perfect prediction
if (require(ggplot2)) {
  ggplot() +
    aes(x = ypred, y = valid$yL) +
    geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}
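The scatter plot gives a visual check of the fit; a single summary number can complement it. The sketch below defines a small root-mean-squared-error helper in base R. The name rmse() is our own and is not exported by s2net, and the commented usage line assumes that ypred and valid$yL from the example above are numeric vectors (or one-column matrices) of the same length.

```r
# Hypothetical helper (not part of s2net): root mean squared error between
# predictions and observed responses, ignoring pairs with missing values.
rmse <- function(pred, obs) {
  pred <- as.numeric(pred)  # flattens one-column matrices to vectors
  obs <- as.numeric(obs)
  stopifnot(length(pred) == length(obs))
  ok <- !is.na(pred) & !is.na(obs)
  sqrt(mean((pred[ok] - obs[ok])^2))
}

# Toy check: errors 1, -1, 0 and 2 give sqrt((1 + 1 + 0 + 4) / 4) = sqrt(1.5)
rmse(c(2, 1, 3, 6), c(1, 2, 3, 4))  # 1.224745

# With the objects from the example above:
# rmse(ypred, valid$yL)
```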