Provides a keras callback similar to keras::callback_reduce_lr_on_plateau() that additionally restores the weights to the best seen so far whenever the learning rate is reduced, and that uses slightly more restrictive improvement detection.


callback_adaptive_lr(
  monitor = "val_loss",
  factor = 0.1,
  patience = 10L,
  verbose = 0L,
  mode = c("auto", "min", "max"),
  delta_abs = 1e-04,
  delta_rel = 0,
  cooldown = 0L,
  min_lr = 0,
  restore_weights = TRUE
)



monitor
Quantity to be monitored.


factor
Factor by which the learning rate will be reduced: new_lr = old_lr * factor.


patience
Number of epochs with no significant improvement after which the learning rate will be reduced.


verbose
Integer. Set to 1 to receive update messages.


mode
Optimisation mode. "auto" detects the mode from the name of monitor. "min" treats a decrease of the metric as an improvement. "max" treats an increase of the metric as an improvement.
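The exact rule behind "auto" is not spelled out here. As a rough illustration only, the sketch below resolves the mode with the substring heuristic used by the Keras plateau callback (metric names containing "acc" are maximised, everything else is minimised). The helper resolve_mode() is hypothetical and is not part of this package; callback_adaptive_lr() may use a different rule.

```r
# Hypothetical helper, not part of the package.
# Assumption: "auto" follows the Keras-style substring heuristic.
resolve_mode <- function(monitor, mode = c("auto", "min", "max")) {
  mode <- match.arg(mode)
  if (mode != "auto") {
    return(mode) # an explicit mode always wins
  }
  # Metrics whose name contains "acc" are assumed to be "higher is better"
  if (grepl("acc", monitor, fixed = TRUE)) "max" else "min"
}

resolve_mode("val_loss")
resolve_mode("val_accuracy")
resolve_mode("val_loss", mode = "max")
```

If the monitored metric has an unusual name, passing "min" or "max" explicitly avoids relying on any name-based detection.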


delta_abs
Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.


delta_rel
Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs.
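One plausible reading of the two thresholds, for a metric in "min" mode, is sketched below. The helper is_plateau() is hypothetical. Comparing against the value from patience epochs ago and combining delta_abs and delta_rel by taking the larger threshold are both assumptions made for illustration, not the documented implementation.

```r
# Hypothetical helper, not part of the package.
# `history` holds the monitored metric ("min" mode), one value per epoch.
is_plateau <- function(history, patience, delta_abs = 1e-4, delta_rel = 0) {
  n <- length(history)
  if (n <= patience) {
    return(FALSE) # not enough epochs observed yet
  }
  reference <- history[n - patience] # metric `patience` epochs ago
  current <- history[n]
  # Average improvement per epoch over the last `patience` epochs
  avg_improvement <- (reference - current) / patience
  # Assumption: the stricter of the absolute and relative thresholds applies
  threshold <- max(delta_abs, abs(reference) * delta_rel)
  avg_improvement < threshold
}

loss <- c(1.00, 0.80, 0.65, 0.6499, 0.6498, 0.6498)
is_plateau(loss[1:3], patience = 2L)                  # still improving quickly
is_plateau(loss, patience = 2L)                       # improvement has stalled
is_plateau(loss[1:3], patience = 2L, delta_rel = 0.2) # a 20% per-epoch target is not met
```

With delta_rel = 0 (the default) only the absolute threshold matters; a positive delta_rel scales the required improvement with the magnitude of the metric.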


cooldown
Number of epochs to wait before resuming normal operation after the learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.
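To make the patience + cooldown spacing concrete, the sketch below lists the earliest epochs at which reductions could occur if the metric never improves. The helper earliest_reductions() is hypothetical, and it assumes that the first reduction happens after exactly patience epochs; the real callback may count epochs slightly differently.

```r
# Hypothetical helper, not part of the package.
# Assumption: the metric never improves, and the first reduction is
# triggered after `patience` epochs. Requires patience >= 1.
earliest_reductions <- function(n_epochs, patience, cooldown = 0L) {
  if (patience > n_epochs) {
    return(integer(0)) # training ends before any reduction can trigger
  }
  # Subsequent reductions are spaced at least patience + cooldown apart
  as.integer(seq(from = patience, to = n_epochs, by = patience + cooldown))
}

earliest_reductions(20L, patience = 2L, cooldown = 3L)
```

Here the spacing between consecutive entries is patience + cooldown = 5 epochs.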


min_lr
Lower bound for the learning rate. If a reduction would lower the learning rate below min_lr, it will be clipped at min_lr instead and no further reductions will be performed.
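The interaction of factor and min_lr can be sketched with plain arithmetic. The helper reduce_lr() below is hypothetical, and the starting learning rate of 1e-3 is an assumed value; factor = 0.5 and min_lr = 1e-4 are chosen to match the example call further down.

```r
# Hypothetical helper, not part of the package:
# one reduction step, new_lr = old_lr * factor, clipped at min_lr.
reduce_lr <- function(lr, factor, min_lr) {
  max(lr * factor, min_lr)
}

lr <- 1e-3 # assumed starting learning rate
schedule <- lr
for (i in 1:5) {
  lr <- reduce_lr(lr, factor = 0.5, min_lr = 1e-4)
  schedule <- c(schedule, lr)
}
schedule
```

The fourth reduction would give 6.25e-5, so it is clipped to 1e-4. In this sketch later steps are no-ops, which mirrors the callback performing no further reductions once min_lr is reached.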


restore_weights
Logical. If TRUE, the best weights will be restored at each learning rate reduction. This is very useful if the metric oscillates.


A KerasCallback suitable for passing to keras::fit().


Note that while callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr', this is currently impossible from R. Thus, if you also want to log the learning rate, add callback_reduce_lr_on_plateau() with a high min_lr alongside this callback: the high min_lr effectively disables its reductions while the learning rate is still logged.


dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)

if (interactive() && keras::is_keras_available()) {
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  # Start from the globally fitted parameters
  tf_initialise_model(mod, global_fit$params)
  fit_history <- fit(
    mod,
    x = k_constant(group),
    y = as_trunc_obs(x),
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )

  # Predict for both groups
  predicted_means <- predict(mod, data = k_constant(c(0, 1)))
}