s2net (version 1.0)

Rcpp_s2net-class: Class s2net

Description

This is the main class of this library. It is implemented in C++ and exposed to R through Rcpp modules. It can be used directly from R, and some generic S4 methods are also provided to make it easier to work with from R.


Methods

predict

signature(object = "Rcpp_s2net"): See predict_Rcpp_s2net

Fields

beta:

Object of class matrix. The fitted model coefficients.

intercept:

The model intercept.
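For a model fitted with the linear loss, beta and intercept can plausibly be combined into predictions in the usual way, yhat = intercept + X %*% beta. The sketch below assumes that convention and uses hypothetical toy values in place of obj$beta and obj$intercept. It is not taken from the s2net source, and the package's own predict method may also apply preprocessing to newX.

```r
# Hypothetical stand-ins for obj$beta (a one-column matrix) and obj$intercept
beta <- matrix(c(0.5, -1.0, 2.0), ncol = 1)
intercept <- 1.5

# Two new observations with three predictors each
newX <- matrix(c(1, 2, 3,
                 0, 1, 0), nrow = 2, byrow = TRUE)

# Assumed linear predictor: intercept plus newX %*% beta, flattened to a vector
eta <- intercept + drop(newX %*% beta)
print(eta)  # 6.0 0.5
```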

Class-Based Methods

initialize(data, loss):

data

An s2Data object with the training data.

loss

Loss function: 0 = linear (regression), 1 = logit (classification).
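Because loss is passed as an integer code, a small lookup from readable names can make constructor calls self-documenting. The helper below, s2net_loss_code, is hypothetical (it is not part of s2net) and only encodes the two codes listed above.

```r
# Hypothetical helper: map a readable loss name to the integer code expected
# by the s2net constructor (0 = linear, 1 = logit, as documented above)
s2net_loss_code <- function(loss = c("linear", "logit")) {
  loss <- match.arg(loss)
  match(loss, c("linear", "logit")) - 1L
}

s2net_loss_code("linear")  # 0
s2net_loss_code("logit")   # 1
```

With the package loaded, new(s2net, train, s2net_loss_code("linear")) would then be equivalent to new(s2net, train, 0) in the example.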

setupFista(s2Fista):

Configures the internal FISTA (Fast Iterative Shrinkage-Thresholding Algorithm) solver using the settings in s2Fista.
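FISTA (Beck and Teboulle's accelerated proximal gradient method) alternates a gradient step on the smooth part of the objective, a proximal soft-thresholding step for the l1 penalty, and a momentum extrapolation. The sketch below is a generic textbook FISTA for the plain lasso, 0.5 * ||y - X b||^2 + lambda * ||b||_1, with a fixed step size 1/L. It is not s2net's internal solver, whose objective and settings (configured through s2Fista) may differ; the names soft_threshold and fista_lasso are hypothetical.

```r
# Soft-thresholding: the proximal operator of t * ||b||_1
soft_threshold <- function(z, t) sign(z) * pmax(abs(z) - t, 0)

# Generic FISTA for the lasso objective 0.5 * ||y - X b||^2 + lambda * ||b||_1
# (hypothetical illustration, not the s2net solver)
fista_lasso <- function(X, y, lambda, n_iter = 500) {
  # Lipschitz constant of the smooth part's gradient: largest eigenvalue of X'X
  L <- max(eigen(crossprod(X), symmetric = TRUE, only.values = TRUE)$values)
  beta <- numeric(ncol(X))  # current iterate
  z <- beta                 # extrapolated point
  t <- 1                    # momentum parameter
  for (k in seq_len(n_iter)) {
    grad <- drop(crossprod(X, X %*% z - y))               # gradient X'(Xz - y)
    beta_new <- soft_threshold(z - grad / L, lambda / L)  # proximal gradient step
    t_new <- (1 + sqrt(1 + 4 * t^2)) / 2
    z <- beta_new + ((t - 1) / t_new) * (beta_new - beta) # momentum extrapolation
    beta <- beta_new
    t <- t_new
  }
  beta
}

# With an orthonormal design the lasso solution is soft_threshold(y, lambda)
fista_lasso(diag(3), c(3, -0.5, 1), lambda = 1)  # 2 0 0
```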

predict(newX, type):

newX

Matrix of new data for which to make predictions.

type

0 = default, 1 = response, 2 = probs, 3 = class
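For the logit loss, probability and class outputs are typically obtained by applying the logistic function to the linear predictor and then thresholding it. The sketch below illustrates that standard transformation on hypothetical linear-predictor values. Which quantity each type code returns, and the 0.5 cutoff, are assumptions rather than documented s2net behavior.

```r
# Hypothetical linear-predictor values for three observations
eta <- c(-2, 0, 3)

# Logistic link, 1 / (1 + exp(-eta)): assumed meaning of the probability output
probs <- plogis(eta)

# Assumed class rule: predict 1 when the probability exceeds 0.5
classes <- as.integer(probs > 0.5)
classes  # 0 0 1
```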

fit(params, frame, proj):

params

An s2Params object with the model hyperparameters.

frame

0 = "JT", 1 = "ExtJT"

proj

0 = no, 1 = yes, 2 = auto
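The frame and proj arguments are also integer codes. The helper below, s2net_fit_codes, is hypothetical (it is not part of s2net) and translates readable names into the codes listed above.

```r
# Hypothetical helper: translate readable fit() options into integer codes
# frame: 0 = "JT", 1 = "ExtJT"; proj: 0 = no, 1 = yes, 2 = auto
s2net_fit_codes <- function(frame = c("JT", "ExtJT"),
                            proj = c("no", "yes", "auto")) {
  frame <- match.arg(frame)
  proj <- match.arg(proj)
  list(frame = match(frame, c("JT", "ExtJT")) - 1L,
       proj  = match(proj, c("no", "yes", "auto")) - 1L)
}

codes <- s2net_fit_codes("ExtJT", "auto")
str(codes)
```

With the package loaded, obj$fit(params, codes$frame, codes$proj) would correspond to the obj$fit(..., 1, 2) call in the example.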

Examples

# NOT RUN {
library(s2net)
data("auto_mpg")
train = s2Data(xL = auto_mpg$P1$xL, yL = auto_mpg$P1$yL, xU = auto_mpg$P1$xU)

# Create the C++ object by calling its constructor through new()
obj = new(s2net, train, 0) # 0 = linear loss (regression)
obj

# Call the $fit method of obj directly (frame = 1, "ExtJT"; proj = 2, auto)
obj$fit(s2Params(lambda1 = 0.01, 
                   lambda2 = 0.01, 
                   gamma1 = 0.05, 
                   gamma2 = 100, 
                   gamma3 = 0.05), 1, 2)
# Fitted coefficients
obj$beta

# Evaluate the model on the unlabeled data, using the training data's preprocessing
test = s2Data(xL = auto_mpg$P1$xU, yL = auto_mpg$P1$yU,  preprocess = train)
ypred = obj$predict(test$xL, 0)

# }
# NOT RUN {
if(require(ggplot2)){
  ggplot() + 
    aes(x = ypred, y = test$yL) + geom_point() + 
    geom_abline(intercept = 0, slope = 1, linetype = 2)
}
# }
