
Compile a Keras model for truncated data under dist

Usage

tf_compile_model(
  inputs,
  intermediate_output,
  dist,
  optimizer,
  censoring = TRUE,
  truncation = TRUE,
  metrics = NULL,
  weighted_metrics = NULL
)

Arguments

inputs

List of Keras input layers.

intermediate_output

Intermediate model layer whose output is used as the input for computing the distribution parameters.
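The Examples section passes the input layer itself as intermediate_output. The following is a minimal sketch of the more common case, where a hidden layer sits between the inputs and the distribution parameters. It assumes the keras3 and reservr packages are installed. The helper name build_hidden_model, the 16-unit ReLU layer and the string optimizer "adam" are illustrative choices and are not part of the reservr API.

```r
# Hypothetical helper (not part of reservr): compiles a model with one hidden
# dense layer between the input and the distribution parameters.
build_hidden_model <- function(dist, n_features = 1L, units = 16L) {
  tf_in <- keras3::layer_input(shape = n_features)
  # Calling layer_dense() on a symbolic tensor returns the layer's output tensor.
  hidden <- keras3::layer_dense(tf_in, units = units, activation = "relu")
  reservr::tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = hidden,
    dist = dist,
    optimizer = "adam", # a string name works as well as an optimizer instance
    censoring = FALSE,
    truncation = FALSE
  )
}

# Building the model needs a working Keras backend, so only do it interactively.
if (interactive()) {
  mod <- build_hidden_model(reservr::dist_exponential())
}
```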

dist

A Distribution to use for compiling the loss and parameter outputs.

optimizer

String (name of optimizer) or optimizer instance. See the optimizer_* family of functions.

censoring

A flag indicating whether the compiled model should support censored observations. Set to FALSE for higher efficiency when the data contain no censored observations; fit(...) will error if the resulting model is used to fit censored observations.

truncation

A flag indicating whether the compiled model should support truncated observations. Set to FALSE for higher efficiency when the data contain no truncated observations; fit(...) will warn if the resulting model is used to fit truncated observations.
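With the defaults censoring = TRUE and truncation = TRUE, the model can be trained on observations that carry censoring bounds (xmin, xmax) and truncation bounds (tmin, tmax). The sketch below assumes reservr's trunc_obs() constructor, where x is NA for censored rows, and assumes that fit() accepts the resulting object as y. The truncation point 0.2 and the censoring point 2 are arbitrary illustrative values.

```r
library(reservr)

dist <- dist_exponential()
N <- 100L
x <- dist$sample(N, with_params = list(rate = 1.0))

# Illustrative left truncation at 0.2: values at or below it are never recorded.
keep <- x > 0.2
x <- x[keep]
rand_input <- runif(length(x))

# Illustrative right censoring at 2: larger values are only known to exceed 2,
# so x is NA and the censoring interval is (2, Inf).
censored <- x > 2
obs <- trunc_obs(
  x = ifelse(censored, NA_real_, x),
  xmin = ifelse(censored, 2, x),
  xmax = ifelse(censored, Inf, x),
  tmin = 0.2
)

if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  # Both flags stay TRUE because obs contains censored and truncated rows.
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    censoring = TRUE,
    truncation = TRUE
  )
  fit(mod, x = k_matrix(rand_input), y = obs, epochs = 10L)
}
```

Setting either flag to FALSE here would make fit(...) error (censoring) or warn (truncation), as described above.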

metrics

List of metrics to be evaluated by the model during training and testing. Each of these can be:

  • a string (name of a built-in function),

  • a function, optionally with a "name" attribute or

  • a Metric() instance. See the metric_* family of functions.

Typically you will use metrics = c('accuracy'). A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a named list, such as metrics = list(a = 'accuracy', b = c('accuracy', 'mse')). You can also pass a list to specify a metric or a list of metrics for each output, such as metrics = list(c('accuracy'), c('accuracy', 'mse')) or metrics = list('accuracy', c('accuracy', 'mse')). When you pass the strings 'accuracy' or 'acc', we convert this to one of metric_binary_accuracy(), metric_categorical_accuracy(), metric_sparse_categorical_accuracy() based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce" as well. The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, you can specify your metrics via the weighted_metrics argument instead.

If providing an anonymous R function, you can customize the printed name during training by assigning attr(<fn>, "name") <- "my_custom_metric_name", or by calling custom_metric("my_custom_metric_name", <fn>).
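For instance, a function metric can be given a readable name through its "name" attribute. The metric below, mean_pred_param (a hypothetical name), simply averages the model output using keras3::op_mean(). That it represents the predicted distribution parameters is an assumption about how reservr passes y_pred to metrics, so treat this as a sketch of the naming mechanism rather than a recommended metric.

```r
# Hypothetical metric: average of the model output over the batch. Assumes
# y_pred holds the predicted distribution parameters (for an exponential
# distribution, the rate).
mean_pred_param <- function(y_true, y_pred) {
  keras3::op_mean(y_pred)
}
# This name is what appears in the training log.
attr(mean_pred_param, "name") <- "mean_pred_param"

if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  mod <- reservr::tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = reservr::dist_exponential(),
    optimizer = keras3::optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE,
    metrics = list(mean_pred_param)
  )
}
```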

weighted_metrics

List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.

Value

A reservr_keras_model that can be trained on truncated and censored observations from dist, using input data from inputs.

Examples

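# Simulate N observations from an exponential distribution with rate 1,
# together with an unrelated uniform covariate.
dist <- dist_exponential()
params <- list(rate = 1.0)
N <- 100L
rand_input <- runif(N)
x <- dist$sample(N, with_params = params)

# Building the model requires a working Keras backend, so only run interactively.
if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  # The input layer itself is used as the intermediate output.
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    # The simulated data are fully observed, so both features can be disabled.
    censoring = FALSE,
    truncation = FALSE
  )
}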
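A minimal continuation of the example, sketching how the compiled model might be trained on the sampled data and then used for prediction. It assumes reservr's k_matrix() for converting the covariate into a matrix tensor, together with the fit() and predict() methods for reservr_keras_model objects; epochs = 10L is an arbitrary choice. Because censoring and truncation are disabled, y is passed as a plain numeric vector.

```r
library(reservr)

dist <- dist_exponential()
N <- 100L
rand_input <- runif(N)
x <- dist$sample(N, with_params = list(rate = 1.0))

if (interactive()) {
  tf_in <- keras3::layer_input(1L)
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  # Train on the fully observed sample, with rand_input as a one-column covariate.
  history <- fit(mod, x = k_matrix(rand_input), y = x, epochs = 10L)
  # Predicted parameters of dist for each input row.
  pred_params <- predict(mod, k_matrix(rand_input))
}
```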