Demo – MNIST Training Loop using autograph()

Here is a full MNIST training loop implemented in R using tfautograph (originally adapted from here).

library(magrittr)
library(purrr, warn.conflicts = FALSE)

library(tensorflow)
library(tfdatasets)
library(keras)

library(tfautograph)
tf$version$VERSION
#> [1] "2.2.0"

First, some helpers so we can capture tf$print() output in this R Markdown vignette.

# tf$print() emits its output at the C++ level, which knitr can't capture,
# so we redirect it to a temp file and read it back explicitly
TEMPFILE <- tempfile("tf-print-out", fileext = ".txt")

# read back (and by default clear) everything tf_print() has written so far
print_tempfile <- function(clear_after_read = TRUE) {
  if (clear_after_read) on.exit(unlink(TEMPFILE))
  writeLines(readLines(TEMPFILE, warn = FALSE))
}

tf_print <- function(...)
  tf$print(..., output_stream = sprintf("file://%s", TEMPFILE))
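
As a quick check that the redirection works (this snippet is illustrative and not part of the original training code):

tf_print("hello, world")
print_tempfile()  # should echo the line back and clear the file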

Prepare the dataset

# keep only the training split; the . discards the test split
c(c(x_train, y_train), .) %<-% tf$keras$datasets$mnist$load_data()

train_dataset <- list(x_train, y_train) %>%
  tensor_slices_dataset() %>%
  dataset_map(function(x, y) {
    x <- tf$cast(x, tf$float32) / 255  # scale pixel values to [0, 1]
    y <- tf$cast(y, tf$int64)
    list(x, y)
  }) %>%
  dataset_take(20000) %>%
  dataset_shuffle(20000) %>%
  dataset_batch(100)
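
Before training starts, we can pull a single batch eagerly to confirm the pipeline's output (a sanity check, not in the original code; reticulate::as_iterator() works on a TF 2.x dataset object):

batch <- reticulate::iter_next(reticulate::as_iterator(train_dataset))
str(batch)
# expect x of shape (100, 28, 28) and dtype float32, y of shape (100) and dtype int64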

Define the model

new_model_and_optimizer <- function() {
  model <- keras_model_sequential() %>%
    layer_reshape(target_shape = c(28 * 28),
                  input_shape = shape(28, 28)) %>%
    layer_dense(100, activation = 'relu') %>%
    layer_dense(100, activation = 'relu') %>%
    layer_dense(10)   # raw logits, one per digit class
  model$build()
  optimizer <- tf$keras$optimizers$Adam()
  list(model, optimizer)
}
c(model, optimizer) %<-% new_model_and_optimizer()
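
As a quick check on the architecture: the three dense layers contribute (784 * 100 + 100) + (100 * 100 + 100) + (100 * 10 + 10) = 89,610 trainable parameters, which summary() should confirm:

summary(model)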

Define the training loop

compute_loss <- tf$keras$losses$SparseCategoricalCrossentropy(from_logits = TRUE)
compute_accuracy <- tf$keras$metrics$SparseCategoricalAccuracy()

train_one_step <- function(model, optimizer, x, y) {
  # record the forward pass on the tape so gradients can be computed
  with(tf$GradientTape() %as% tape, {
    logits <- model(x)
    loss <- compute_loss(y, logits)
  })

  grads <- tape$gradient(loss, model$trainable_variables)
  # apply_gradients() expects a list of (gradient, variable) pairs
  optimizer$apply_gradients(
    transpose(list(grads, model$trainable_variables)))

  compute_accuracy(y, logits)  # update the running accuracy metric
  loss
}
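
The transpose() above is purrr::transpose(): it turns the two parallel lists (gradients and variables) into a list of (gradient, variable) pairs, which is the structure apply_gradients() expects. A tiny illustration with placeholder values:

str(transpose(list(list("g1", "g2"), list("v1", "v2"))))
# a list of two pairs: list("g1", "v1") and list("g2", "v2")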

train <- autograph(function(model, optimizer) {
  step <- 0L
  loss <- 0
  # autograph() translates the for loop, if, and break below so that
  # they also work on tensors when this function is traced into a graph
  for (batch in train_dataset) {
    c(x, y) %<-% batch
    step %<>% add(1L)
    loss <- train_one_step(model, optimizer, x, y)
    if (compute_accuracy$result() > 0.8) {
      tf_print("Accuracy over 0.8; breaking early")
      break
    } else if (step %% 10L == 0L)
      tf_print('Step', step, ': loss', loss, '; accuracy', compute_accuracy$result())
  }
  tf_print('Final step', step, ": loss", loss, "; accuracy", compute_accuracy$result())
  list(step, loss)
})
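
To see what autograph() does in isolation, here is a minimal sketch with a toy function (not part of the training code): the R while loop runs as-is in eager mode and is translated to a tf.while_loop when the function is traced into a graph:

count_down <- autograph(function(n) {
  while (n > 0L) {
    tf_print(n)
    n <- n - 1L
  }
  n
})

count_down(tf$constant(3L))                # works eagerly
tf_function(count_down)(tf$constant(3L))   # and in graph mode
print_tempfile()                           # read back (and clear) the printed counts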

Train in Eager mode

# autograph also works in eager mode

c(model, optimizer) %<-% new_model_and_optimizer()
c(step, loss) %<-% train(model, optimizer)

print_tempfile()
#> Step 10 : loss 1.75744426 ; accuracy 0.377
#> Step 20 : loss 1.05870402 ; accuracy 0.5325
#> Step 30 : loss 0.873054206 ; accuracy 0.615
#> Step 40 : loss 0.504798 ; accuracy 0.6625
#> Step 50 : loss 0.602258563 ; accuracy 0.6972
#> Step 60 : loss 0.527141631 ; accuracy 0.723833323
#> Step 70 : loss 0.375314593 ; accuracy 0.749571443
#> Step 80 : loss 0.44427222 ; accuracy 0.76875
#> Step 90 : loss 0.218258798 ; accuracy 0.782666683
#> Step 100 : loss 0.517769277 ; accuracy 0.7928
#> Accuracy over 0.8; breaking early
#> Final step 107 : loss 0.260477364 ; accuracy 0.800467312

Train in Graph mode

c(model, optimizer) %<-% new_model_and_optimizer()

# tf_function() traces the autographed `train` into a TensorFlow graph
train_on_graph <- tf_function(train)
c(step, loss) %<-% train_on_graph(model, optimizer)

print_tempfile()
#> Step 10 : loss 1.80349636 ; accuracy 0.766410232
#> Step 20 : loss 0.99037528 ; accuracy 0.760236204
#> Step 30 : loss 0.781103075 ; accuracy 0.760437965
#> Step 40 : loss 0.421598345 ; accuracy 0.764625847
#> Step 50 : loss 0.497002035 ; accuracy 0.769681513
#> Step 60 : loss 0.457768053 ; accuracy 0.775748491
#> Step 70 : loss 0.298231125 ; accuracy 0.781864405
#> Step 80 : loss 0.400309563 ; accuracy 0.786898375
#> Step 90 : loss 0.219272882 ; accuracy 0.79228425
#> Step 100 : loss 0.313125879 ; accuracy 0.797729492
#> Accuracy over 0.8; breaking early
#> Final step 104 : loss 0.283688635 ; accuracy 0.800473928
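
One detail worth noting when comparing the two runs: compute_accuracy is a single stateful metric object shared by both, which is why the graph-mode run above already reports ~0.77 accuracy at step 10 rather than starting from scratch. For a fully independent second run, reset the metric first (reset_states() is the standard TF 2.x metric method):

compute_accuracy$reset_states()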