Compile a Keras model for truncated data under dist
Usage
tf_compile_model(
  inputs,
  intermediate_output,
  dist,
  optimizer,
  censoring = TRUE,
  truncation = TRUE,
  metrics = NULL,
  weighted_metrics = NULL
)
Arguments
- inputs
List of Keras input layers.
- intermediate_output
Intermediate model layer to be used as input to the distribution parameters.
- dist
A Distribution to use for compiling the loss and parameter outputs.
- optimizer
String (name of optimizer) or optimizer instance. See the optimizer_* family.
- censoring
A flag indicating whether the compiled model should support censored observations. Set to FALSE for higher efficiency; fit(...) will error if the resulting model is used to fit censored observations.
- truncation
A flag indicating whether the compiled model should support truncated observations. Set to FALSE for higher efficiency; fit(...) will warn if the resulting model is used to fit truncated observations.
- metrics
List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function (optionally with a "name" attribute), or a Metric() instance; see the metric_* family of functions.
Typically you will use metrics = c('accuracy'). A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a named list, such as metrics = list(a = 'accuracy', b = c('accuracy', 'mse')). You can also pass a list to specify a metric or a list of metrics for each output, such as metrics = list(c('accuracy'), c('accuracy', 'mse')) or metrics = list('accuracy', c('accuracy', 'mse')). The strings 'accuracy' and 'acc' are converted to one of metric_binary_accuracy(), metric_categorical_accuracy(), or metric_sparse_categorical_accuracy(), based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce". The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, specify your metrics via the weighted_metrics argument instead.
If providing an anonymous R function, you can customize the printed name during training by assigning attr(<fn>, "name") <- "my_custom_metric_name", or by calling custom_metric("my_custom_metric_name", <fn>).
- weighted_metrics
List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
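As a sketch of the naming mechanism described under metrics, the snippet below attaches a "name" attribute to a custom metric function. The name my_metric and its body are hypothetical placeholders; a real metric has to use keras3 tensor operations and interpret y_true and y_pred according to the encoding tf_compile_model() uses for observations and distribution parameters, which this page does not describe.

```r
# Hypothetical metric: name and body are placeholders for illustration.
# The body only runs inside Keras, where y_true and y_pred are tensors.
my_metric <- function(y_true, y_pred) {
  keras3::op_mean(keras3::op_abs(y_true - y_pred))
}

# Attach a readable name that Keras can show in the training log.
attr(my_metric, "name") <- "my_custom_metric_name"

# It could then be supplied as, for example:
#   tf_compile_model(..., metrics = list(my_metric))
# or wrapped via keras3::custom_metric("my_custom_metric_name", my_metric).
```

Defining the function does not evaluate its body, so attaching the attribute works without a Keras backend; the keras3 calls only run once the metric is used during training.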
Value
A reservr_keras_model that can be trained on truncated and censored observations from dist, using the input data fed through inputs.
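This page does not state the loss explicitly. Assuming it is the usual negative log-likelihood for interval-censored, truncated data, a base-R sketch for an exponential distribution follows. The helper exp_trunc_cens_nll() and its argument names are hypothetical and are not reservr's internal code; they model an observation known to lie in [xmin, xmax] that was only recorded because it fell inside [tmin, tmax].

```r
# Hypothetical helper: negative log-likelihood of possibly censored and
# truncated observations under an exponential distribution with rate `rate`.
# Observation i is known to lie in [xmin[i], xmax[i]] (xmin == xmax means it
# was observed exactly) and was only recorded because it fell in
# [tmin[i], tmax[i]].
exp_trunc_cens_nll <- function(rate, xmin, xmax,
                               tmin = rep(0, length(xmin)),
                               tmax = rep(Inf, length(xmin))) {
  exact <- xmin == xmax
  # Log-probability of the truncation window (the normalising constant).
  log_norm <- log(pexp(tmax, rate) - pexp(tmin, rate))
  # Exact observations contribute the log-density, censored ones the
  # log-probability of their censoring interval.
  log_lik <- ifelse(
    exact,
    dexp(xmin, rate, log = TRUE),
    log(pexp(xmax, rate) - pexp(xmin, rate))
  )
  -sum(log_lik - log_norm)
}

set.seed(1)
x <- rexp(5, rate = 1)

# No censoring, no truncation: identical to the plain exponential NLL.
nll_plain <- exp_trunc_cens_nll(1, xmin = x, xmax = x)

# Right-censored at 2: contributes -log(P(X > 2)) = 2 for rate 1.
nll_cens <- exp_trunc_cens_nll(1, xmin = 2, xmax = Inf)

# Observed at 1.5 but left-truncated at 1: by memorylessness this equals
# the NLL of an untruncated observation at 0.5, i.e. 0.5 for rate 1.
nll_trunc <- exp_trunc_cens_nll(1, xmin = 1.5, xmax = 1.5, tmin = 1)
```

With exact observations and no truncation, the extra terms vanish and only the log-density remains, which is plausibly why setting censoring = FALSE or truncation = FALSE allows a cheaper loss.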
Examples
dist <- dist_exponential()
params <- list(rate = 1.0)
N <- 100L
rand_input <- runif(N)
x <- dist$sample(N, with_params = params)

if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
}
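Because the example sets censoring = FALSE and truncation = FALSE and samples from an exponential distribution with rate 1, a useful base-R reference point is the ordinary maximum-likelihood estimate of the rate. The sketch below uses rexp() in place of dist$sample() and assumes the loss reduces to the plain exponential negative log-likelihood in this setting; since rand_input carries no information about x, the rates predicted by a trained model would plausibly end up near this value.

```r
# Simulate data comparable to the example: 100 draws, rate = 1.
set.seed(42)
x <- rexp(100L, rate = 1.0)

# Plain exponential negative log-likelihood as a function of the rate.
nll <- function(rate) -sum(dexp(x, rate = rate, log = TRUE))

# Numerical minimiser over a bracketing interval ...
rate_num <- optimize(nll, interval = c(1e-3, 10))$minimum

# ... agrees with the closed-form maximum-likelihood estimate 1 / mean(x).
rate_mle <- 1 / mean(x)
```

The closed form follows from setting the derivative of n * log(rate) - rate * sum(x) to zero; the optimize() call only confirms it numerically.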