Callback to monitor likelihood gradient components
Source: R/callback_debug_dist_gradients.R
Provides a keras callback to monitor the individual components of the censored and truncated likelihood. Useful for debugging TensorFlow implementations of Distributions.
Usage
callback_debug_dist_gradients(
  object,
  data,
  obs,
  keep_grads = FALSE,
  stop_on_na = TRUE,
  verbose = TRUE
)
Arguments
- object: A reservr_keras_model created by tf_compile_model().
- data: Input data for the model.
- obs: Observations associated with data.
- keep_grads: Log the actual gradient values? (memory hungry!); see the sketch after this list.
- stop_on_na: Stop training if any likelihood component has NaN in its gradients?
- verbose: Print a message if training is halted? The message will contain information about which likelihood components have NaN in their gradients.
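For instance, to retain all gradient values for later inspection while letting training continue past NaN gradients, the two flags can be combined. A minimal sketch, assuming mod, data and obs are a compiled reservr_keras_model and its training data as in the Examples below:

cb <- callback_debug_dist_gradients(
  mod, data, obs,
  keep_grads = TRUE,  # retain actual gradient values in cb$gradient_logs (memory hungry!)
  stop_on_na = FALSE  # log NaN gradients without halting training
)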
Value
A KerasCallback suitable for passing to keras3::fit().
Examples
dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)
if (interactive()) {
  library(keras3)
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  # Initialise the network weights from the unconditional global fit.
  tf_initialise_model(mod, global_fit$params)
  gradient_tracker <- callback_debug_dist_gradients(
    mod,
    as_tensor(group, config_floatx()),
    x,
    keep_grads = TRUE
  )
  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = x,
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      gradient_tracker,
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )
  # Inspect the density-component gradients from the 20th log entry.
  gradient_tracker$gradient_logs[[20]]$dens
  plot(fit_history)
  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}
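The contents of each gradient_logs entry depend on which likelihood components are present for the model; the example above reads the density component, dens, but which other fields appear is not shown there. A minimal exploration sketch, reusing gradient_tracker from the example:

if (interactive()) {
  # Summarise one log entry; the exact fields beyond $dens depend on
  # the model's censoring/truncation configuration.
  str(gradient_tracker$gradient_logs[[1]], max.level = 1)
}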