
Provides a keras callback similar to keras3::callback_reduce_lr_on_plateau(), but one that also restores the weights to the best seen so far whenever a learning rate reduction occurs, and that uses slightly more restrictive improvement detection.

Usage

callback_adaptive_lr(
  monitor = "val_loss",
  factor = 0.1,
  patience = 10L,
  verbose = 0L,
  mode = c("auto", "min", "max"),
  delta_abs = 1e-04,
  delta_rel = 0,
  cooldown = 0L,
  min_lr = 0,
  restore_weights = TRUE
)

Arguments

monitor

Quantity to be monitored.

factor

Factor by which the learning rate will be reduced: new_lr = old_lr * factor.

patience

Number of epochs with no significant improvement after which the learning rate will be reduced.

verbose

Integer. Set to 1 to receive update messages.

mode

Optimisation mode. "auto" detects the mode from the name of monitor. "min" monitors for decreasing metrics. "max" monitors for increasing metrics.

delta_abs

Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.

delta_rel

Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs. A sketch after this argument list illustrates how delta_abs and delta_rel interact.

cooldown

Number of epochs to wait before resuming normal operation after the learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.

min_lr

Lower bound for the learning rate. If a learning rate reduction would lower the learning rate below min_lr, it will be clipped at min_lr instead and no further reductions will be performed.

restore_weights

Bool. If TRUE, the best weights will be restored at each learning rate reduction. This is very useful if the metric oscillates.
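
The following sketch illustrates how delta_abs, delta_rel and patience plausibly interact. It is not the callback's actual implementation: the averaging against the metric value patience epochs ago and the rule combining the two deltas (the stricter one wins) are assumptions for illustration, shown for "min" mode only.

# Hypothetical sketch of the plateau test, not the callback's internals.
# In "min" mode, reduce the learning rate when the average improvement
# per epoch over the last `patience` epochs falls below the threshold.
should_reduce_lr <- function(history, patience = 10L,
                             delta_abs = 1e-4, delta_rel = 0) {
  n <- length(history)
  if (n <= patience) return(FALSE)
  current <- history[n]
  reference <- history[n - patience]
  avg_improvement <- (reference - current) / patience
  # Assumed combination rule: the stricter of the two deltas applies.
  threshold <- max(delta_abs, abs(current) * delta_rel)
  avg_improvement < threshold
}

loss <- c(1.00, 0.80, 0.65, 0.64995, 0.64993, 0.64991)
should_reduce_lr(loss, patience = 3L) # TRUE: ~3e-5 average gain per epoch
should_reduce_lr(c(1.0, 0.9, 0.8, 0.7), patience = 3L) # FALSE: still improving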

Value

A KerasCallback suitable for passing to keras3::fit().

Details

Note that while keras3::callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr', this is currently impossible from an R callback. Thus, if you also want to log the learning rate, you should add keras3::callback_reduce_lr_on_plateau() with a min_lr high enough to effectively disable the callback, so that it never reduces the learning rate but still logs it.
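
For instance, pairing the two callbacks as below keeps 'lr' in the training history. The monitor, factor, patience and min_lr values here are placeholders; the Examples section shows the same pattern in context.

callbacks <- list(
  callback_adaptive_lr("val_loss", factor = 0.5, patience = 5L),
  # min_lr = 1.0 is far above any realistic learning rate, so this
  # callback never reduces anything; it only logs the 'lr' metric.
  callback_reduce_lr_on_plateau("val_loss", min_lr = 1.0)
)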

Examples

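# Simulate exponential data whose rate depends on a binary group, then
# obtain a global (unconditional) fit used below to initialise the network.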
dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)

if (interactive()) {
  library(keras3)
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  tf_initialise_model(mod, global_fit$params)
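  # Train with adaptive learning rate reduction: halve the learning rate and
  # restore the best weights after 2 epochs without significant improvement.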
  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = as_trunc_obs(x),
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )

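  # Plot the training history, including the 'lr' metric logged by the
  # (effectively disabled) plateau callback.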
  plot(fit_history)

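  # Predict the fitted distribution parameters for both groups (0 and 1).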
  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}