Skip to contents

Provides a keras callback to monitor the individual components of the censored and truncated likelihood. Useful for debugging TensorFlow implementations of Distributions.

Usage

callback_debug_dist_gradients(
  object,
  data,
  obs,
  keep_grads = FALSE,
  stop_on_na = TRUE,
  verbose = TRUE
)

Arguments

object

A reservr_keras_model created by tf_compile_model().

data

Input data for the model.

obs

Observations associated with data.

keep_grads

Log actual gradients? (memory hungry!)

stop_on_na

Stop if any likelihood component has NaN in its gradients?

verbose

Print a message if training is halted? The message will contain information about which likelihood components have NaN in their gradients.

Value

A KerasCallback suitable for passing to keras3::fit().

Examples

# Simulate 100 exponential observations whose rate depends on a binary
# covariate: `group` is 0 or 1, so the rate passed to the sampler is 1 or 2.
dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
# Fit the distribution to `x` alone (no `group` passed); the resulting
# parameters are used below to initialise the network.
global_fit <- fit(dist, x)

# The model-building and training part only runs in interactive sessions.
if (interactive()) {
  library(keras3)
  # Single-feature input layer, passed directly as the intermediate output
  # of the compiled model (no hidden layers in between).
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  # Start training from the parameters of the covariate-free fit.
  tf_initialise_model(mod, global_fit$params)
  # Build the debugging callback on the same data and observations used for
  # training. keep_grads = TRUE so that `gradient_logs` can be inspected
  # after the fit (see below).
  gradient_tracker <- callback_debug_dist_gradients(
    mod,
    as_tensor(group, config_floatx()),
    x,
    keep_grads = TRUE
  )
  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = x,
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      gradient_tracker,
      # NOTE(review): min_lr = 1.0 presumably keeps this callback from ever
      # reducing the learning rate, so it only records it -- confirm.
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )
  # Inspect the `dens` entry of the 20th gradient log; with epochs = 20L this
  # is presumably the log of the final epoch -- TODO confirm logging cadence.
  gradient_tracker$gradient_logs[[20]]$dens

  plot(fit_history)

  # Predict for both covariate values (group = 0 and group = 1).
  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}