# Simulate 100 exponential observations whose rate depends on a binary
# covariate: rate 1 when group == 0 and rate 2 when group == 1.
dist <- dist_exponential()
group <- sample(x = c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(
  100,
  with_params = list(rate = 1 + group)
)
# Baseline fit that ignores `group`; its parameters are used to initialise
# the keras model in the interactive section.
global_fit <- fit(dist, x)
# Fit a keras-backed model of `dist` whose parameters depend on `group`,
# letting callback_adaptive_lr() halve the learning rate when the loss
# plateaus. Runs only interactively and only when the keras package is
# installed (checked first, so a missing package does not raise an error)
# and its backend is available.
if (interactive() &&
  requireNamespace("keras", quietly = TRUE) &&
  keras::is_keras_available()) {
  library(keras)
  # The single covariate (`group`) is fed directly to the distribution head.
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  # Initialise the network from the covariate-free fit's parameters.
  tf_initialise_model(mod, global_fit$params)
  fit_history <- fit(
    mod,
    x = k_constant(group),
    y = as_trunc_obs(x),
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr(
        "loss",
        factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4
      ),
      # Included only to track the lr in the fit history. Adam's default
      # learning rate (0.001 per the keras docs) is below min_lr = 1.0, so
      # this callback never lowers the lr itself.
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0)
    )
  )
  plot(fit_history)
  # Predictions for group = 0 and group = 1. These are presumably the fitted
  # distribution parameters per group -- confirm against reservr's predict().
  predicted_means <- predict(mod, data = k_constant(c(0, 1)))
}
# Run the code above in your browser using DataLab