Compile a Keras model for truncated data under dist
Usage
tf_compile_model(
  inputs,
  intermediate_output,
  dist,
  optimizer,
  censoring = TRUE,
  truncation = TRUE,
  metrics = NULL,
  weighted_metrics = NULL
)
Arguments
- inputs
List of keras input layers
- intermediate_output
Intermediate model layer to be used as input to distribution parameters
- dist
A Distribution to use for compiling the loss and parameter outputs
- optimizer
String (name of optimizer) or optimizer instance. See the optimizer_* family.
- censoring
A flag, whether the compiled model should support censored observations. Set to FALSE for higher efficiency. fit(...) will error if the resulting model is used to fit censored observations.
- truncation
A flag, whether the compiled model should support truncated observations. Set to FALSE for higher efficiency. fit(...) will warn if the resulting model is used to fit truncated observations.
- metrics
List of metrics to be evaluated by the model during training and testing. Each of these can be:
  - a string (name of a built-in function),
  - a function, optionally with a "name" attribute, or
  - a Metric() instance. See the metric_* family of functions.
Typically you will use metrics = c('accuracy'). A function is any callable with the signature result = fn(y_true, y_pred).
To specify different metrics for different outputs of a multi-output model, you could also pass a named list, such as metrics = list(a = 'accuracy', b = c('accuracy', 'mse')). You can also pass a list to specify a metric or a list of metrics for each output, such as metrics = list(c('accuracy'), c('accuracy', 'mse')) or metrics = list('accuracy', c('accuracy', 'mse')).
When you pass the strings 'accuracy' or 'acc', we convert this to one of metric_binary_accuracy(), metric_categorical_accuracy(), metric_sparse_categorical_accuracy() based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce" as well.
The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the weighted_metrics argument instead.
If providing an anonymous R function, you can customize the printed name during training by assigning attr(<fn>, "name") <- "my_custom_metric_name", or by calling custom_metric("my_custom_metric_name", <fn>).
- weighted_metrics
List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
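The naming convention for custom metric functions described under metrics can be sketched in base R. The function mean_abs_error below is hypothetical and is only evaluated on plain numeric vectors here; inside a Keras model the arguments would be tensors, so the body would likely need tensor operations, and whether such a metric is meaningful for the outputs of a model compiled with tf_compile_model is not established by this page.

```r
# Hypothetical custom metric with the signature fn(y_true, y_pred).
# Illustrated on plain numeric vectors only (an assumption for this sketch).
mean_abs_error <- function(y_true, y_pred) {
  mean(abs(y_true - y_pred))
}

# Attach the name that should be printed during training.
attr(mean_abs_error, "name") <- "my_custom_metric_name"

# The attribute does not change how the function is called.
mean_abs_error(c(1, 2, 3), c(1, 2, 5))
attr(mean_abs_error, "name")
```

The function could then be passed in a list to the metrics argument, e.g. metrics = list(mean_abs_error), subject to the caveats above.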
Value
A reservr_keras_model that can be fitted to truncated and censored observations from dist, based on input data from inputs.
Examples
dist <- dist_exponential()
params <- list(rate = 1.0)
N <- 100L
rand_input <- runif(N)
x <- dist$sample(N, with_params = params)
if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
}