# NOT RUN {
### Example logistic regression with randomly-generated data
library(stochQN)
### Will sample data y ~ Bernoulli(sigm(X %*% true_coefs)), where sigm is the
### logistic function 1 / (1 + exp(-z)) and X holds standard-normal features
### Coefficients used to generate the data (one per column of X)
true_coefs <- c(1.12, 5.34, -6.123)
### Draw one synthetic batch for a logistic regression problem.
###   true_coefs: numeric vector of generating coefficients (one per feature)
###   n:          number of observations to draw
### Returns a list with `X` (n x length(true_coefs) matrix of standard-normal
### draws) and `y` (numeric 0/1 labels, one per row of X).
generate_data_batch <- function(true_coefs, n = 100) {
  n_features <- length(true_coefs)
  X <- matrix(rnorm(n_features * n), nrow = n, ncol = n_features)
  # Probability of the positive class under the logistic model
  prob <- 1 / (1 + exp(-as.numeric(X %*% true_coefs)))
  # Label is 1 when the probability is at least a fresh uniform draw
  labels <- as.numeric(prob >= runif(n))
  list(X = X, y = labels)
}
### Logistic regression likelihood/loss
### Regularized logistic loss: mean log-loss plus lambda * (coefs . coefs).
###   coefs:   numeric coefficient vector
###   X:       feature matrix (one row per observation)
###   y:       0/1 labels, one per row of X
###   weights: accepted for interface compatibility; not used here
###   lambda:  strength of the squared-L2 penalty
### Returns a single numeric value.
eval_fun <- function(coefs, X, y, weights=NULL, lambda=1e-5) {
  linear_score <- as.numeric(X %*% coefs)
  prob <- 1 / (1 + exp(-linear_score))
  per_obs_loss <- -(y * log(prob) + (1 - y) * log(1 - prob))
  ridge_penalty <- lambda * as.numeric(coefs %*% coefs)
  mean(per_obs_loss) + ridge_penalty
}
### Gradient of the regularized logistic loss with respect to `coefs`.
###   coefs:   numeric coefficient vector
###   X:       feature matrix (one row per observation)
###   y:       0/1 labels, one per row of X
###   weights: accepted for interface compatibility; not used here
###   lambda:  strength of the squared-L2 penalty
### Returns a numeric vector with one entry per coefficient.
eval_grad <- function(coefs, X, y, weights=NULL, lambda=1e-5) {
  pred <- 1 / (1 + exp(-as.numeric(X %*% coefs)))
  # Row i of X is scaled by its residual (pred[i] - y[i]) before averaging
  grad <- colMeans(X * (pred - y))
  # The penalty is lambda * sum(coefs^2), so its gradient is 2 * lambda * coefs
  # (linear in coefs, not squared)
  grad <- grad + 2 * lambda * as.numeric(coefs)
  as.numeric(grad)
}
### Product of the Hessian of the regularized logistic loss with a vector.
### The Hessian is t(X) %*% D %*% X / n + 2 * lambda * I, with
### D = diag(pred * (1 - pred)); it is never formed explicitly.
###   coefs:   numeric coefficient vector at which the Hessian is evaluated
###   vec:     numeric vector to multiply (same length as coefs)
###   X:       feature matrix (one row per observation)
###   y:       accepted for interface compatibility; not used here
###   weights: accepted for interface compatibility; not used here
###   lambda:  strength of the squared-L2 penalty
### Returns a numeric vector with one entry per coefficient.
eval_Hess_vec <- function(coefs, vec, X, y, weights=NULL, lambda=1e-5) {
  pred <- 1 / (1 + exp(-as.numeric(X %*% coefs)))
  curvature <- pred * (1 - pred)
  # Weight the length-n vector X %*% vec per observation, then project back
  # with t(X). Multiplying t(X) (p x n) elementwise by a length-n vector would
  # recycle the weights down its columns and pair them with the wrong entries.
  Hp <- crossprod(X, curvature * as.numeric(X %*% vec))
  Hp <- Hp / NROW(X) + 2 * lambda * vec
  as.numeric(Hp)
}
### Predicted probability of the positive class for each row of X.
###   X:     feature matrix (one row per observation)
###   coefs: numeric coefficient vector
###   ...:   extra arguments are accepted and ignored
### Returns a numeric vector of probabilities, one per row of X.
pred_fun <- function(X, coefs, ...) {
  linear_score <- as.numeric(X %*% coefs)
  1 / (1 + exp(-linear_score))
}
### Seed the RNG before any data is sampled, so that the validation set, the
### training batches and the printed trace are all reproducible
set.seed(1)
### Initialize optimizer from arbitrary values
x0 <- c(1, 1, 1)
optimizer <- SQN(x0, grad_fun=eval_grad, pred_fun=pred_fun,
                 hess_vec_fun=eval_Hess_vec, initial_step=1e-0)
### Held-out batch used only to report the loss during fitting
val_data <- generate_data_batch(true_coefs, n=100)
### Fit to 250 batches of data, 100 observations each
for (i in seq_len(250)) {
  new_batch <- generate_data_batch(true_coefs, n=100)
  # The return value is discarded and the optimizer is queried afterwards, so
  # partial_fit presumably updates `optimizer` in place - see stochQN docs
  partial_fit(optimizer, new_batch$X, new_batch$y, lambda=1e-5)
  x_curr <- get_curr_x(optimizer)
  i_curr <- get_iteration_number(optimizer)
  # Report progress whenever the optimizer's iteration count is a multiple of 10
  if ((i_curr %% 10) == 0) {
    cat(sprintf("Iteration %3d - E[f(x)]: %f - values of x: [%f, %f, %f]\n",
                i_curr, eval_fun(x_curr, val_data$X, val_data$y, lambda=1e-5),
                x_curr[1], x_curr[2], x_curr[3]))
  }
}
### Predict for new data
new_batch <- generate_data_batch(true_coefs, n=10)
yhat <- predict(optimizer, new_batch$X)
# }
# Run the code above in your browser using DataLab