This bijector implements a continuous dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and image (Y), i.e. state(initial_time) = X and state(final_time) = Y.
tfb_ffjord(
  state_time_derivative_fn,
  ode_solve_fn = NULL,
  trace_augmentation_fn = tfp$bijectors$ffjord$trace_jacobian_hutchinson,
  initial_time = 0,
  final_time = 1,
  validate_args = FALSE,
  dtype = tf$float32,
  name = "ffjord"
)
Returns a bijector instance.
state_time_derivative_fn: function taking arguments time (a scalar representing time) and state (a Tensor representing the state at the given time), returning the time derivative of the state at the given time.
ode_solve_fn: function taking arguments ode_fn (same as state_time_derivative_fn above), initial_time (a scalar representing the initial time of integration), initial_state (a Tensor of floating dtype representing the initial state) and solution_times (a 1D Tensor of floating dtype representing the times at which to obtain the solution), returning a Tensor of shape [time_axis, initial_state$shape]. It is called with [final_time] as the solution_times argument and state_time_derivative_fn as the ode_fn argument. If NULL, a DormandPrince solver from tfp$math$ode is used. Default value: NULL.
trace_augmentation_fn: function taking arguments ode_fn (a function, same as state_time_derivative_fn above), state_shape (TensorShape of the state) and dtype (same as the dtype of the state), and returning a function taking arguments time (a scalar representing the time at which the function is evaluated) and state (a Tensor representing the state at the given time) that computes a tuple (ode_fn(time, state), jacobian_trace_estimation). jacobian_trace_estimation should represent the trace of the Jacobian of ode_fn with respect to state. state_time_derivative_fn will be passed as the ode_fn argument. Default value: tfp$bijectors$ffjord$trace_jacobian_hutchinson.
initial_time: Scalar float representing the time to which the x value of the bijector corresponds. Passed as initial_time to ode_solve_fn. For the default solver, this can be a float or a floating scalar Tensor. Default value: 0.
final_time: Scalar float representing the time to which the y value of the bijector corresponds. Passed as solution_times to ode_solve_fn. For the default solver, this can be a float or a floating scalar Tensor. Default value: 1.
validate_args: Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE and the inputs are invalid, correct behavior is not guaranteed.
dtype: tf$DType to prefer when converting args to Tensors. Otherwise, we fall back to a common dtype inferred from the args, finally falling back to float32.
name: name prefixed to Ops created by this class.
d/dt[state(t)] = state_time_derivative_fn(t, state(t))
state(initial_time) = X
state(final_time) = Y
For this transformation, the value of log_det_jacobian follows another differential equation, which reduces its computation to that of the trace of the Jacobian along the trajectory:
state_time_derivative = state_time_derivative_fn(t, state(t))
d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
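Integrating the equation above over the trajectory gives the forward log-determinant in closed form. The following is a sketch in LaTeX notation; the symbols z(t) for state(t), f for state_time_derivative_fn, and t_0, t_1 for initial_time and final_time are shorthand introduced here, not names used by the package. The second line is the usual change-of-variables consequence for densities, stated under the assumption that log_det_jac(t_0) = 0.

```latex
\log\left|\det\frac{\partial Y}{\partial X}\right|
  = \int_{t_0}^{t_1} \operatorname{Tr}\!\left(\frac{\partial f\bigl(t, z(t)\bigr)}{\partial z(t)}\right)\,dt,
\qquad
\log p_Y(Y) = \log p_X(X)
  - \int_{t_0}^{t_1} \operatorname{Tr}\!\left(\frac{\partial f\bigl(t, z(t)\bigr)}{\partial z(t)}\right)\,dt .
```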
The FFJORD constructor takes two function arguments, ode_solve_fn and trace_augmentation_fn, that customize the integration of the differential equation and the trace estimation.
Differential equation integration is performed by a call to ode_solve_fn. A custom ode_solve_fn must accept the following arguments:
ode_fn(time, state): Differential equation to be solved.
initial_time: Scalar float or floating Tensor representing the initial time.
initial_state: Floating Tensor representing the initial state.
solution_times: 1D floating Tensor of solution times.
and return a Tensor of shape [solution_times$shape, initial_state$shape] representing state values evaluated at solution_times. In addition, ode_solve_fn must support nested structures. For more details see the interface of tfp$math$ode$Solver$solve().
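The argument order and output shape described above can be illustrated with a small base-R sketch. It is only an illustration of the calling contract: the fixed-step RK4 integrator, the name rk4_ode_solve_fn and the n_steps argument are hypothetical, and the function works on plain numeric vectors rather than Tensors, so it is not a drop-in ode_solve_fn for tfb_ffjord(). It also assumes solution_times is increasing and not earlier than initial_time.

```r
# Illustrative fixed-step RK4 solver using the ode_solve_fn argument order.
# Assumption: states are plain numeric vectors, not Tensors.
rk4_ode_solve_fn <- function(ode_fn, initial_time, initial_state,
                             solution_times, n_steps = 100L) {
  # One row per solution time, one column per state dimension.
  out <- matrix(NA_real_, nrow = length(solution_times),
                ncol = length(initial_state))
  current_time <- initial_time
  state <- initial_state
  for (i in seq_along(solution_times)) {
    h <- (solution_times[i] - current_time) / n_steps
    for (s in seq_len(n_steps)) {
      k1 <- ode_fn(current_time, state)
      k2 <- ode_fn(current_time + h / 2, state + h / 2 * k1)
      k3 <- ode_fn(current_time + h / 2, state + h / 2 * k2)
      k4 <- ode_fn(current_time + h, state + h * k3)
      state <- state + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
      current_time <- current_time + h
    }
    out[i, ] <- state
  }
  out
}

# Exponential decay d/dt state = -state has the exact solution state0 * exp(-t).
decay_fn <- function(time, state) -state
solution <- rk4_ode_solve_fn(decay_fn, 0, c(1, 2), solution_times = c(0.5, 1))
dim(solution)  # rows index solution_times, columns index state entries
```

The leading axis of the result indexes solution_times and the remaining axis matches the shape of initial_state, mirroring the [solution_times$shape, initial_state$shape] layout described above.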
Trace estimation is computed simultaneously with state_time_derivative using augmented_state_time_derivative_fn, which is generated by trace_augmentation_fn. trace_augmentation_fn takes state_time_derivative_fn, state$shape and state$dtype as arguments and returns an augmented_state_time_derivative_fn callable that computes both state_time_derivative and the unreduced trace_estimation.
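To give some intuition for stochastic trace estimation of the kind the default trace_jacobian_hutchinson name refers to, here is a base-R sketch of a Hutchinson-style estimator, Tr(J) ≈ mean(eps' J eps), applied to an explicit Jacobian matrix. The function name hutchinson_trace, the use of Rademacher (±1) noise and the explicit matrix are assumptions made for illustration; the package's default works on Tensors and may use a different noise distribution.

```r
set.seed(1)

# Hutchinson-style estimate of the trace of a square matrix:
# average of eps' J eps over random sign vectors eps (Rademacher noise).
hutchinson_trace <- function(jacobian, num_samples = 1000L) {
  d <- nrow(jacobian)
  estimates <- vapply(seq_len(num_samples), function(i) {
    eps <- sample(c(-1, 1), d, replace = TRUE)
    as.numeric(crossprod(eps, jacobian %*% eps))
  }, numeric(1))
  mean(estimates)
}

# Jacobian of the linear dynamics d/dt state = A %*% state is A itself.
A <- matrix(c(-1.0, 0.3, 0.2, -0.5), nrow = 2)
estimate <- hutchinson_trace(A)
exact <- sum(diag(A))

# For a diagonal matrix, sign noise gives the exact trace from any sample.
diag_estimate <- hutchinson_trace(diag(c(2, 3)), num_samples = 10L)
```

The estimate only needs products of the Jacobian with a vector, which is why this approach scales to cases where forming the full Jacobian would be too expensive.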
Custom ode_solve_fn and trace_augmentation_fn examples:
# custom_solver_fn: `function(f, t_initial, t_solutions, y_initial, ...)`
# ... : Additional arguments to pass to custom_solver_fn.
ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times) {
  custom_solver_fn(ode_fn, initial_time, solution_times, initial_state, ...)
}
ffjord <- tfb_ffjord(state_time_derivative_fn, ode_solve_fn = ode_solve_fn)

# state_time_derivative_fn: `function(time, state)`
# trace_jac_fn: `function(time, state)` unreduced jacobian trace function
trace_augmentation_fn <- function(ode_fn, state_shape, state_dtype) {
  augmented_ode_fn <- function(time, state) {
    list(ode_fn(time, state), trace_jac_fn(time, state))
  }
  augmented_ode_fn
}
ffjord <- tfb_ffjord(state_time_derivative_fn, trace_augmentation_fn = trace_augmentation_fn)
For more details on FFJORD and continuous normalizing flows see Chen et al. (2018), Grathwohl et al. (2018).
Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in neural information processing systems (pp. 6571-6583).
For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
Other bijectors:
tfb_absolute_value(), tfb_affine_linear_operator(), tfb_affine_scalar(), tfb_affine(), tfb_ascending(), tfb_batch_normalization(), tfb_blockwise(), tfb_chain(), tfb_cholesky_outer_product(), tfb_cholesky_to_inv_cholesky(), tfb_correlation_cholesky(), tfb_cumsum(), tfb_discrete_cosine_transform(), tfb_expm1(), tfb_exp(), tfb_fill_scale_tri_l(), tfb_fill_triangular(), tfb_glow(), tfb_gompertz_cdf(), tfb_gumbel_cdf(), tfb_gumbel(), tfb_identity(), tfb_inline(), tfb_invert(), tfb_iterated_sigmoid_centered(), tfb_kumaraswamy_cdf(), tfb_kumaraswamy(), tfb_lambert_w_tail(), tfb_masked_autoregressive_default_template(), tfb_masked_autoregressive_flow(), tfb_masked_dense(), tfb_matrix_inverse_tri_l(), tfb_matvec_lu(), tfb_normal_cdf(), tfb_ordered(), tfb_pad(), tfb_permute(), tfb_power_transform(), tfb_rational_quadratic_spline(), tfb_rayleigh_cdf(), tfb_real_nvp_default_template(), tfb_real_nvp(), tfb_reciprocal(), tfb_reshape(), tfb_scale_matvec_diag(), tfb_scale_matvec_linear_operator(), tfb_scale_matvec_lu(), tfb_scale_matvec_tri_l(), tfb_scale_tri_l(), tfb_scale(), tfb_shifted_gompertz_cdf(), tfb_shift(), tfb_sigmoid(), tfb_sinh_arcsinh(), tfb_sinh(), tfb_softmax_centered(), tfb_softplus(), tfb_softsign(), tfb_split(), tfb_square(), tfb_tanh(), tfb_transform_diagonal(), tfb_transpose(), tfb_weibull_cdf(), tfb_weibull()