Keras Callback for adaptive learning rate with weight restoration
Source: R/callback_adaptive_lr.R
Provides a Keras callback similar to keras3::callback_reduce_lr_on_plateau(), but which also restores the weights
to the best seen so far whenever a learning rate reduction occurs, and which uses slightly more restrictive
improvement detection.
Usage
callback_adaptive_lr(
monitor = "val_loss",
factor = 0.1,
patience = 10L,
verbose = 0L,
mode = c("auto", "min", "max"),
delta_abs = 1e-04,
delta_rel = 0,
cooldown = 0L,
min_lr = 0,
restore_weights = TRUE
)
Arguments
- monitor
quantity to be monitored.
- factor
factor by which the learning rate will be reduced. new_lr = old_lr * factor.
- patience
number of epochs with no significant improvement after which the learning rate will be reduced.
- verbose
integer. Set to 1 to receive update messages.
- mode
Optimisation mode. "auto" detects the mode from the name of monitor. "min" monitors for decreasing metrics. "max" monitors for increasing metrics.
- delta_abs
Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.
- delta_rel
Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs (see the illustrative sketch after this list).
- cooldown
number of epochs to wait before resuming normal operation after the learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.
- min_lr
lower bound for the learning rate. If a learning rate reduction would lower the learning rate below min_lr, it will be clipped at min_lr instead and no further reductions will be performed.
- restore_weights
Bool. If TRUE, the best weights will be restored at each learning rate reduction. This is very useful if the metric oscillates.
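The interaction of delta_abs and delta_rel can be pictured with a small sketch. The helper below is purely illustrative and hypothetical (it is not part of the package); it assumes a single comparison counts as a significant improvement only when the monitored value beats the best seen so far by at least max(delta_abs, |best| * delta_rel), whereas callback_adaptive_lr() averages the improvement over patience epochs as described above.

# Hypothetical one-step version of the improvement test; not exported
# by the package, and the exact combination of both deltas may differ.
improved <- function(current, best, delta_abs = 1e-4, delta_rel = 0,
                     mode = "min") {
  threshold <- max(delta_abs, abs(best) * delta_rel)
  if (mode == "min") best - current > threshold else current - best > threshold
}
improved(0.49995, 0.5) # FALSE: a gain of 5e-5 is below delta_abs = 1e-4
improved(0.49, 0.5)    # TRUE: a gain of 0.01 clears the threshold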
Value
A KerasCallback suitable for passing to keras3::fit().
Details
Note that while keras3::callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr',
this is currently impossible to replicate from R.
Thus, if you also want to log the learning rate, you should add a keras3::callback_reduce_lr_on_plateau()
with a high min_lr to effectively disable the callback but still monitor the learning rate, as in the snippet below.
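For example (mirroring the Examples below), a plateau callback whose min_lr is set to a high value never actually reduces the learning rate but still records it as the 'lr' metric:

# min_lr = 1.0 effectively disables reductions while keeping the lr log.
callback_reduce_lr_on_plateau(monitor = "loss", min_lr = 1.0)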
Examples
dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)
if (interactive()) {
library(keras3)
l_in <- layer_input(shape = 1L)
mod <- tf_compile_model(
inputs = list(l_in),
intermediate_output = l_in,
dist = dist,
optimizer = optimizer_adam(),
censoring = FALSE,
truncation = FALSE
)
tf_initialise_model(mod, global_fit$params)
fit_history <- fit(
mod,
x = as_tensor(group, config_floatx()),
y = as_trunc_obs(x),
epochs = 20L,
callbacks = list(
callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
)
)
plot(fit_history)
predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}