library(tensorflow)
library(keras)

library(tfautograph)
tf$version$VERSION
#> [1] "2.6.0"

autograph() works without needing to provide any hints. It can automatically translate expressions like if and while into tf.cond() and tf.while_loop() calls.

However, compared to the flexibility of R control flow, the tensorflow control flow functions have much narrower constraints on how the control flow nodes are constructed. The constraints of functions like tf.cond(), tf.while_loop() and dataset.reduce() require that autograph() do some sophisticated inference from the R code to anticipate what the outcome of evaluating a block of code will be, so that it can “fix up” the outcome to satisfy the constraints tensorflow sets out.

If you want to bypass these inference routines (for example, because the inferred outcome produces a graph with more nodes than strictly necessary, or for a minor performance benefit), you can provide hints to autograph to tell it what it would otherwise try to infer.

There are two main types of hints, corresponding to the two main control flow functions: if and while.

Since these hints are only relevant in graph mode, we’ll disable eager execution for the rest of the vignette. In your own code, you’re more likely to enter graph mode only when tracing a tf.function().
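For instance, the more typical way of entering graph mode looks something like the following minimal sketch (purely illustrative and not evaluated in this vignette; fn and its body are made up for the example):

fn <- tf_function(function(x) {
  autograph({
    if(x > 0)
      x <- x + 1
  })
  x
})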

tf$compat$v1$disable_eager_execution()
sess <- tf$compat$v1$Session()

Passing hints to if

Take the following expression:

x <- tf$compat$v1$placeholder('float32')
y <- tf$constant(0)
y
#> Tensor("Const:0", shape=(), dtype=float32)
autograph({
  if(x > 0)
    y <- y + 1
})
y
#> Tensor("cond/Identity:0", shape=(), dtype=float32)

Notice that y points to a different tensor after the autograph() call. The new tensor has 'cond' in the name, which is a big hint that autograph() took that expression and called tf.cond() with it.

If you were to inspect the graph (e.g., by reading the GraphDef or inspecting it in tensorboard, perhaps by calling tfautograph::view_function_graph()), you would see that, in essence, the expression got translated into the equivalent of:

y <- tf$cond(x > 0, 
             true_fn = function() y + 1,
             false_fn = function() y)

tf.cond() takes the true and false branches as functions true_fn and false_fn, and requires that the two functions return the same structure of tensors.

How did autograph() know how to balance the two branches? It traced each branch as a tf.function and then reconciled the differences. Then, within the tf.cond() call, it called the traced function for each branch, along with some routines that “fix up” the actual outcome of the branch into a balanced final outcome.

If you wanted to avoid tracing each branch as a separate ConcreteFunction and balancing the output structures, you could supply a hint to autograph about what outcomes you expect from the branches by calling ag_if_vars() right before the if.

autograph({
  ag_if_vars(y)
  if(x > 0)
    y <- y + 1
})

Here, y is the only variable modified, so it’s the only symbol that we want returned from the tf.cond().
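The same pattern extends to branches that modify more than one variable; a minimal sketch (the extra tensor z is introduced purely for illustration):

z <- tf$constant(0)
autograph({
  ag_if_vars(y, z)  # hint: both y and z are modified by the branch
  if(x > 0) {
    y <- y + 1
    z <- z + 1
  }
})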

Besides modified symbols, ag_if_vars() can help you specify a few other types of outcomes that should be captured by the tf.cond(), including modified values in nested structures, return values, the number of control flow events (e.g., break and/or next statements), and local variables.

args(ag_if_vars)
#> function (..., modified = list(), return = FALSE, undefs = NULL, 
#>     control_flow = 0) 
#> NULL

See ?ag_if_vars() for details.

Passing hints to loops: while and for

Take the expression:

x <- tf$constant(0)
n <- tf$constant(10)
autograph({
  while(n > 0) {
    x <- x + 1
    n <- n - 1
  }
})

If you were to build the equivalent tf.while_loop() call by hand, it would look like this:

c(x, n) %<-% tf$while_loop(
  cond = function(x, n)
    n > 0,
  body = function(x, n) {
    x <- x + 1
    n <- n - 1
    tuple(x, n)
  },
  loop_vars = tuple(x, n)
)

As you can see, there are a few things that autograph needs to know to translate an R while loop into a tf.while_loop() call. In particular, it needs to know which variables might be modified by the loop body, both so that it can pass them to the loop_vars argument and so that it can construct cond and body functions with the appropriate formals and return values.

How does autograph() know which symbols belong in loop_vars, which are meant to be local-only variables, and which are only read and never modified? There are a couple of different ways to approach this, but autograph() does it by statically inferring them from the R loop body expression. That is, without actually evaluating the code, it inspects which symbols are the targets of common assignment operators like <-, ->, =, %<-%, %->% and %<>%. If you are autographing a for loop like for(var in seq), then the symbol var is also included in the set of symbols statically inferred as assignment targets.
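For example, in the following minimal sketch (the integer tensors are chosen only so that the dtypes agree), both the assignment target x and the loop variable i end up in the statically inferred set:

x <- tf$constant(0L)
autograph({
  for(i in tf$range(10L))
    x <- x + i
})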

After all the assigned symbols are inferred, they are classified as either belonging to loop_vars or undefs based on whether the symbol already exists.
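For example, a symbol that is first assigned inside the loop body does not yet exist when the loop is entered, so it is classified as an undef rather than a loop var; a minimal sketch (the name delta is purely illustrative):

x <- tf$constant(0)
n <- tf$constant(10)
autograph({
  while(n > 0) {
    delta <- tf$constant(1)  # first assigned inside the loop body
    x <- x + delta
    n <- n - 1
  }
})

Here x and n become loop vars, while delta becomes an undef after the loop: trying to access it would throw an error with a helpful message.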

This approach is quicker and cheaper than tracing, but it can sometimes promote symbols to loop_vars that could otherwise have been local-only variables.

You can provide hints to autograph about the loop_vars like so:

autograph({
  ag_loop_vars(x, n)
  while(n > 0) {
    x <- x + 1
    n <- n - 1
  }
})

Because the most common motivation for providing a loop hint will be to exclude a certain symbol from being captured as a loop var, there is special syntax for that. Prefix the variable name with a minus sign (-) like so:

tmp <- tf$constant(99)
autograph({
  ag_loop_vars(-tmp)
  while(n > 0) {
    tmp <- x + 1
    x <- tmp
    n <- n - 1
  }
})

Note that by default ag_loop_vars() exports excluded variables as undefs: symbols that throw an error (with a helpful error message) if you try to access them. To prevent autograph from creating undefs, you can pass undefs = NULL like so:

ag_loop_vars(-tmp, undefs = NULL)
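In context, the hint might look like this minimal sketch (reusing the loop from above); with undefs = NULL, autograph simply leaves tmp alone rather than replacing it with an undef:

autograph({
  ag_loop_vars(-tmp, undefs = NULL)
  while(n > 0) {
    tmp <- x + 1
    x <- tmp
    n <- n - 1
  }
})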

You can also use ag_loop_vars() to specify additional symbols to be included alongside the statically inferred ones by prefixing them with a +, as in the sketch below. See ?ag_loop_vars() for additional examples.
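A minimal sketch of the + syntax (here the + is redundant, since x would be inferred anyway; it is shown only to illustrate that + adds to the inferred set rather than replacing it):

autograph({
  ag_loop_vars(+x)  # '+' adds x to the statically inferred loop vars
  while(n > 0) {
    x <- x + 1
    n <- n - 1
  }
})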