# One-time setup (left commented out): install Miniconda, then create a conda
# environment named "tf-v1" containing the CPU build of TensorFlow 1.x.
# reticulate::install_miniconda()
# tensorflow::install_tensorflow("conda", envname = "tf-v1", version = "1-cpu")

# Point reticulate at the "tf-v1" conda environment. `required = TRUE` asks
# reticulate to fail rather than fall back to another Python. This is placed
# before tensorflow is used, presumably so Python is not initialized against a
# different environment first — NOTE(review): confirm against reticulate docs.
reticulate::use_condaenv("tf-v1", required = TRUE)
library(tensorflow)
library(tfautograph)
# Report which TensorFlow version was actually loaded.
tf$version$VERSION
#> [1] "1.15.0"
# Rebind `tf` to the v1 compatibility module so that v1-style APIs used later
# in this document (e.g. tf$Session()) are reachable through `tf`.
tf <- tf$compat$v1
`autograph()` is fully compatible with both TensorFlow >= 2.0 and TensorFlow 1.15. For the most part, there is no difference between how `autograph()` works in either version, whether in eager mode or in graph mode.
One thing to be aware of, however, is how control dependencies are registered. In TensorFlow version 2, all ops created inside a traced `tf.function()` are executed sequentially. That means that simply inlining a call to `tf$print()` or `tf$Assert()` is sufficient to ensure that the op is executed, and that it is executed in the intended order within a sequence of operations.
However, when TensorFlow is not executing eagerly and is outside a `tf.function()` context, created ops are typically only evaluated if they are manually registered as control dependencies of other tensors. This is the default behavior of TensorFlow version 1, and it is also possible in version 2 after a call to `tf.compat.v1.disable_eager_execution()`.
For example, this block of code never throws an error, even though `x` is negative:
# Graph-mode session used to evaluate tensors throughout this document.
sess <- tf$Session()
x <- tf$constant(-1)
# Creates an Assert op, but nothing depends on it: the function returns the
# original `x`, so evaluating its result never runs the assertion.
stop_if_negative <- function(x) {
  tf_assert(x > 0, "x is positive")
  x
}
# Returns -1 with no error, even though the assertion condition is false.
sess$run(stop_if_negative(x))
#> [1] -1
There are two straightforward solutions to this. One is to wrap the block of code in a `tf.function()`:
# Wrap the same R function with tf_function(). Inside a traced tf.function(),
# ops run sequentially, so the Assert op now executes and fails for x = -1
# (see the error output below).
stop_if_negative2 <- tf_function(stop_if_negative)
sess$run(stop_if_negative2(x))
# [<#> Error in py_call_impl(callable, dots$args, dots$keywords): InvalidArgumentError: assertion failed:
# [<Assert condition R expression>: `x > 0`] [x is positive] [`x` value:] [-1]
# [<R call stack>:]
# [<R call 1>: tf_assert(x > 0, \"x is positive\")]
# [<R call 2>: (function (x) ]
# [<R call 3>: (function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 4>: evalq((function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 5>: evalq((function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 6>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 7>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 8>: tryCatchList(expr, names[-nh], parentenv, handlers[-nh])]
# [<R call 9>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 10>: tryCatchOne(tryCatchList(expr, names[-nh], parentenv, handlers[-nh]), names[nh], parentenv, handlers[[nh]])]
# [<R call 11>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 12>: tryCatch(evalq((function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 13>: py_call_impl(callable, dots$args, dots$keywords)]
# [<R call 14>: stop_if_negative2(x)]
# [<R call 15>: py_resolve_dots(list(...))]
# [<R call 16>: sess$run(stop_if_negative2(x))]
# [<R call 17>: eval(expr, envir, enclos)]
# [<R call 18>: eval(expr, envir, enclos)]
# [<R call 19>: withVisible(eval(expr, envir, enclos))]
# [<R call 20>: withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)]
# [<R call 21>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 22>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 23>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 24>: tryCatch(expr, error = function(e) {]
# [<R call 25>: try(f, silent = TRUE)]
# [<R call 26>: handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler))]
# [<R call 27>: timing_fn(handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)))]
# [<R call 28>: evaluate_call(expr, parsed$src[[i]], envir = envir, enclos = enclos, debug = debug, last = i == length(out), use_try = stop_on_error != 2L, keep_warning = keep_warning, keep_message = keep_message, output_handler = output_handler, include_timing = include_timing)]
# [<R call 29>: evaluate::evaluate(...)]
# [<R call 30>: evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 31>: in_dir(input_dir(), evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 32>: eng_r(options)]
# [<R call 33>: block_exec(params)]
# [<R call 34>: call_block(x)]
# [<R call 35>: process_group.block(group)]
# [<R call 36>: process_group(group)]
# [<R call 37>: withCallingHandlers(if (tangle) process_tangle(group) else process_group(group), error = function(e) {]
# [<R call 38>: process_file(text, output)]
# [<R call 39>: knitr::knit(knit_input, knit_output, envir = envir, quiet = quiet)]
# [<R call 40>: rmarkdown::render(...)]
# [<R call 41>: (function (...) ]
# [<R call 42>: (function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 43>: do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE)]
# [<R call 44>: saveRDS(do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE), file = \"/tmp/RtmpEYn9yM/callr-res-55c930331500\")]
# [<R call 45>: withCallingHandlers({]
# [<R call 46>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 47>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 48>: tryCatchList(expr, names[-nh], parentenv, handlers[-nh])]
# [<R call 49>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 50>: tryCatchOne(tryCatchList(expr, names[-nh], parentenv, handlers[-nh]), names[nh], parentenv, handlers[[nh]])]
# [<R call 51>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 52>: tryCatch(withCallingHandlers({]
#> [[{{node Assert}}]]
#>
#> Detailed traceback:
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
#> run_metadata_ptr)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
#> feed_dict_tensor, options, run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
#> run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
#> raise type(e)(node_def, op, message)
The other is to capture the assert operation as a control dependency of the evaluated tensor:
# Manually wire the assertion into the graph. Outside a tf.function() in graph
# mode, an Assert op only runs if some evaluated op depends on it, so:
#   1. keep a handle to the op returned by tf_assert(), then
#   2. create tf$identity(x) inside a control-dependency context on that op.
# Evaluating the returned tensor therefore forces the assertion to run first.
stop_if_negative3 <- function(x) {
  positivity_check <- tf_assert(x > 0, "x is positive")

  dependency_ctx <- tf$control_dependencies(list(positivity_check))
  with(dependency_ctx, {
    # Assigned inside the `with` body; the binding lands in this function's
    # frame and is returned below.
    guarded_x <- tf$identity(x)
  })

  guarded_x
}
sess$run(stop_if_negative3(x))
# [<#> Error in py_call_impl(callable, dots$args, dots$keywords): InvalidArgumentError: assertion failed:
# [<Assert condition R expression>: `x > 0`] [x is positive] [`x` value:] [-1]
# [<R call stack>:]
# [<R call 1>: tf_assert(x > 0, \"x is positive\")]
# [<R call 2>: stop_if_negative3(x)]
# [<R call 3>: py_resolve_dots(list(...))]
# [<R call 4>: sess$run(stop_if_negative3(x))]
# [<R call 5>: eval(expr, envir, enclos)]
# [<R call 6>: eval(expr, envir, enclos)]
# [<R call 7>: withVisible(eval(expr, envir, enclos))]
# [<R call 8>: withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)]
# [<R call 9>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 10>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 11>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 12>: tryCatch(expr, error = function(e) {]
# [<R call 13>: try(f, silent = TRUE)]
# [<R call 14>: handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler))]
# [<R call 15>: timing_fn(handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)))]
# [<R call 16>: evaluate_call(expr, parsed$src[[i]], envir = envir, enclos = enclos, debug = debug, last = i == length(out), use_try = stop_on_error != 2L, keep_warning = keep_warning, keep_message = keep_message, output_handler = output_handler, include_timing = include_timing)]
# [<R call 17>: evaluate::evaluate(...)]
# [<R call 18>: evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 19>: in_dir(input_dir(), evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 20>: eng_r(options)]
# [<R call 21>: block_exec(params)]
# [<R call 22>: call_block(x)]
# [<R call 23>: process_group.block(group)]
# [<R call 24>: process_group(group)]
# [<R call 25>: withCallingHandlers(if (tangle) process_tangle(group) else process_group(group), error = function(e) {]
# [<R call 26>: process_file(text, output)]
# [<R call 27>: knitr::knit(knit_input, knit_output, envir = envir, quiet = quiet)]
# [<R call 28>: rmarkdown::render(...)]
# [<R call 29>: (function (...) ]
# [<R call 30>: (function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 31>: do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE)]
# [<R call 32>: saveRDS(do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE), file = \"/tmp/RtmpEYn9yM/callr-res-55c930331500\")]
# [<R call 33>: withCallingHandlers({]
# [<R call 34>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 35>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 36>: tryCatchList(expr, names[-nh], parentenv, handlers[-nh])]
# [<R call 37>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 38>: tryCatchOne(tryCatchList(expr, names[-nh], parentenv, handlers[-nh]), names[nh], parentenv, handlers[[nh]])]
# [<R call 39>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 40>: tryCatch(withCallingHandlers({]
#> [[node Assert_1/AssertGuard/Assert (defined at /framework/ops.py:1748) ]]
#>
#> Detailed traceback:
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
#> run_metadata_ptr)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
#> feed_dict_tensor, options, run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
#> run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
#> raise type(e)(node_def, op, message)
Because the approach used in `stop_if_negative3()` is both so common and so cumbersome, `autograph` takes care of it for you when autographing `stopifnot()` calls:
# Let autograph() do the wiring from stop_if_negative3() automatically. Outside
# a tf.function() and in graph mode, autographing stopifnot() registers the
# resulting assertion as a control dependency for the rest of this function's
# scope. Evaluating the returned `x` therefore triggers the check, which fails
# for x = -1 (see the error output below).
stop_if_negative4 <- function(x) {
  autograph(stopifnot(x > 0))
  x
}
sess$run(stop_if_negative4(x))
# [<#> Error in py_call_impl(callable, dots$args, dots$keywords): InvalidArgumentError: assertion failed:
# [<Assert condition R expression>: `x > 0`] [`x` value:] [-1]
# [<R call stack>:]
# [<R call 1>: stopifnot(x > 0)]
# [<R call 2>: fn(...)]
# [<R call 3>: withVisible(fn(...))]
# [<R call 4>: fn()]
# [<R call 5>: autograph(stopifnot(x > 0))]
# [<R call 6>: stop_if_negative4(x)]
# [<R call 7>: py_resolve_dots(list(...))]
# [<R call 8>: sess$run(stop_if_negative4(x))]
# [<R call 9>: eval(expr, envir, enclos)]
# [<R call 10>: eval(expr, envir, enclos)]
# [<R call 11>: withVisible(eval(expr, envir, enclos))]
# [<R call 12>: withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)]
# [<R call 13>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 14>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 15>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 16>: tryCatch(expr, error = function(e) {]
# [<R call 17>: try(f, silent = TRUE)]
# [<R call 18>: handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler))]
# [<R call 19>: timing_fn(handle(ev <- withCallingHandlers(withVisible(eval(expr, envir, enclos)), warning = wHandler, error = eHandler, message = mHandler)))]
# [<R call 20>: evaluate_call(expr, parsed$src[[i]], envir = envir, enclos = enclos, debug = debug, last = i == length(out), use_try = stop_on_error != 2L, keep_warning = keep_warning, keep_message = keep_message, output_handler = output_handler, include_timing = include_timing)]
# [<R call 21>: evaluate::evaluate(...)]
# [<R call 22>: evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 23>: in_dir(input_dir(), evaluate(code, envir = env, new_device = FALSE, keep_warning = !isFALSE(options$warning), keep_message = !isFALSE(options$message), stop_on_error = if (is.numeric(options$error)) options$error else {]
# [<R call 24>: eng_r(options)]
# [<R call 25>: block_exec(params)]
# [<R call 26>: call_block(x)]
# [<R call 27>: process_group.block(group)]
# [<R call 28>: process_group(group)]
# [<R call 29>: withCallingHandlers(if (tangle) process_tangle(group) else process_group(group), error = function(e) {]
# [<R call 30>: process_file(text, output)]
# [<R call 31>: knitr::knit(knit_input, knit_output, envir = envir, quiet = quiet)]
# [<R call 32>: rmarkdown::render(...)]
# [<R call 33>: (function (...) ]
# [<R call 34>: (function (what, args, quote = FALSE, envir = parent.frame()) ]
# [<R call 35>: do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE)]
# [<R call 36>: saveRDS(do.call(do.call, c(readRDS(\"/tmp/RtmpEYn9yM/callr-fun-55c9194f81bb\"), list(envir = .GlobalEnv, quote = TRUE)), envir = .GlobalEnv, quote = TRUE), file = \"/tmp/RtmpEYn9yM/callr-res-55c930331500\")]
# [<R call 37>: withCallingHandlers({]
# [<R call 38>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 39>: tryCatchOne(expr, names, parentenv, handlers[[1L]])]
# [<R call 40>: tryCatchList(expr, names[-nh], parentenv, handlers[-nh])]
# [<R call 41>: doTryCatch(return(expr), name, parentenv, handler)]
# [<R call 42>: tryCatchOne(tryCatchList(expr, names[-nh], parentenv, handlers[-nh]), names[nh], parentenv, handlers[[nh]])]
# [<R call 43>: tryCatchList(expr, classes, parentenv, handlers)]
# [<R call 44>: tryCatch(withCallingHandlers({]
#> [[node Assert_2/AssertGuard/Assert (defined at /framework/ops.py:1748) ]]
#>
#> Detailed traceback:
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
#> run_metadata_ptr)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
#> feed_dict_tensor, options, run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
#> run_metadata)
#> File "/home/tomasz/.local/share/r-miniconda/envs/tf-v1/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
#> raise type(e)(node_def, op, message)
Note that the error was thrown even though we are not using `tf.function()`. How does `autograph(stopifnot(...))` guarantee that the created `tf$Assert()` op is evaluated? It uses a multi-pronged, belt-and-suspenders approach:

1. It calls `tf$identity()` on every symbol involved in the assert expression and then replaces those symbols in the current scope.
2. It leaves a `tf$control_dependencies()` context open for the rest of the current execution scope.
3. It registers an `on.exit()` that calls `tf$identity()` on the return value and then closes the control dependency context.

`autograph(stopifnot(...))` only does these additional things outside of a `tf$function()` and when not executing eagerly. Inside a `tf.function()`, or when executing eagerly, `stopifnot()` merely calls `tf$Assert()` and doesn't attempt to manually open and capture the ops as control dependencies.