tf.map_fn()
Thin wrapper around tf.map_fn() with the following differences:
Accepts purrr-style ~ lambda syntax to define the function fn.
The order of elems and fn is switched to make it more pipe (%>%) friendly and consistent with the R mappers lapply() and purrr::map().
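As a minimal sketch of the two differences above, the snippet below calls tf_map() once with an ordinary R function and once with a ~ lambda through a pipe. It assumes that tf_map() is exported by the tfautograph package (this page does not name the package), that TensorFlow is running eagerly, and that the ~ formula is converted the way purrr converts it, so that .x refers to the single argument.

```r
library(tensorflow)
library(tfautograph)  # assumption: the package that provides tf_map()
library(magrittr)     # provides the %>% pipe

# A rank-1 tensor; tf_map() unpacks it along its first dimension.
x <- tf$constant(c(1, 2, 3))

# elems comes first and fn second, as with lapply() and purrr::map().
squares_fn <- tf_map(x, function(e) e * e)

# The same mapping written as a purrr-style lambda and piped.
# Assumption: .x names the single argument passed to fn.
squares_lambda <- x %>% tf_map(~ .x * .x)
```

Because elems is the first argument, the tensor can be piped straight into tf_map() without naming any arguments.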
tf_map(
elems,
fn,
dtype = NULL,
parallel_iterations = NULL,
back_prop = TRUE,
swap_memory = FALSE,
infer_shape = TRUE,
name = NULL
)
elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. The nested sequence of the resulting slices will be passed to fn.
fn: An R function, specified using purrr-style ~ syntax, a character string, a Python function (or more generally, any Python object with a __call__ method), or anything coercible via as.function(). The function will be called with one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as dtype if one is provided; otherwise it must have the same structure as elems.
dtype: (optional) The output type(s) of fn. If fn returns a structure of tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn.
parallel_iterations: (optional) The number of iterations allowed to run in parallel. When graph building, the default value is 10. While executing eagerly, the default value is 1.
back_prop: (optional) TRUE enables support for back propagation.
swap_memory: (optional) TRUE enables GPU-CPU memory swapping.
infer_shape: (optional) FALSE disables tests for consistent output shapes.
name: (optional) Name prefix for the returned tensors.
A tensor or (possibly nested) sequence of tensors. Each tensor packs the results of applying fn to tensors unpacked from elems along the first dimension, from first to last.
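The following sketch illustrates how elems is unpacked along the first dimension and how the results are packed back together, along with one case where dtype has to be given. It assumes tf_map() comes from the tfautograph package, that TensorFlow is executing eagerly, and that .x refers to the single argument of a ~ lambda. The variable names are illustrative only.

```r
library(tensorflow)
library(tfautograph)  # assumption: the package that provides tf_map()

# A 2 x 3 tensor; tf_map() hands fn one row (a length-3 tensor) at a time.
m <- tf$constant(
  matrix(c(1, 2, 3, 4, 5, 6), nrow = 2, byrow = TRUE),
  dtype = "float32"
)

# fn returns one float32 scalar per row, so the result has shape (2).
# The output type matches elems, so dtype can be left as NULL.
row_sums <- tf_map(m, ~ tf$reduce_sum(.x))

# Here fn returns int32 values although elems is float32, so the
# output type is declared through dtype.
row_sums_int <- tf_map(
  m,
  function(row) tf$cast(tf$reduce_sum(row), tf$int32),
  dtype = tf$int32
)
```

Recent TensorFlow releases may emit a deprecation warning for the dtype argument of tf.map_fn(); the call is still expected to run.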