This bijector implements a continuous dynamics transformation parameterized by a differential equation, where the initial and terminal conditions correspond to the domain (X) and image (Y), respectively.
tfb_ffjord(
  state_time_derivative_fn,
  ode_solve_fn = NULL,
  trace_augmentation_fn = tfp$bijectors$ffjord$trace_jacobian_hutchinson,
  initial_time = 0,
  final_time = 1,
  validate_args = FALSE,
  dtype = tf$float32,
  name = "ffjord"
)
state_time_derivative_fn: function taking arguments time
(a scalar representing time) and state (a Tensor representing the
state at the given time), returning the time derivative of the state at
the given time.
ode_solve_fn: function taking arguments ode_fn (same as
state_time_derivative_fn above), initial_time (a scalar representing
the initial time of integration), initial_state (a Tensor of floating
dtype representing the initial state) and solution_times (a 1D Tensor of
floating dtype representing the times at which to obtain the solution),
returning a Tensor of shape [time_axis, initial_state$shape]. The bijector
passes [final_time] as the solution_times argument and
state_time_derivative_fn as the ode_fn argument.
If NULL, a DormandPrince solver from tfp$math$ode is used.
Default value: NULL
trace_augmentation_fn: function taking arguments ode_fn (a
function, same as state_time_derivative_fn above),
state_shape (TensorShape of the state) and dtype (same as the dtype of
the state), returning a function that takes arguments time
(a scalar representing the time at which the function is evaluated) and
state (a Tensor representing the state at the given time) and computes
a tuple (ode_fn(time, state), jacobian_trace_estimation).
jacobian_trace_estimation should represent the trace of the Jacobian of
ode_fn with respect to state. state_time_derivative_fn will be
passed as the ode_fn argument.
Default value: tfp$bijectors$ffjord$trace_jacobian_hutchinson
initial_time: Scalar float representing the time to which the x value of the
bijector corresponds. Passed as initial_time to ode_solve_fn.
For the default solver, this can be a float or a floating scalar Tensor.
Default value: 0.
final_time: Scalar float representing the time to which the y value of the
bijector corresponds. Passed as solution_times to ode_solve_fn.
For the default solver, this can be a float or a floating scalar Tensor.
Default value: 1.
validate_args: Logical, default FALSE. Whether to validate input with asserts. If validate_args is FALSE and the inputs are invalid, correct behavior is not guaranteed.
dtype: tf$DType to prefer when converting args to Tensors. Otherwise, the
bijector falls back to a common dtype inferred from the args, and finally
to float32.
name: name prefixed to Ops created by this class.
Returns a bijector instance.
The transformation is defined by:
  d/dt[state(t)] = state_time_derivative_fn(t, state(t))
  state(initial_time) = X
  state(final_time) = Y
For this transformation, the value of log_det_jacobian follows another
differential equation, which reduces its computation to the trace of the
Jacobian along the trajectory:
  state_time_derivative = state_time_derivative_fn(t, state(t))
  d/dt[log_det_jac(t)] = Tr(jacobian(state_time_derivative, state(t)))
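As a rough illustration of the two coupled equations above, the following base R sketch integrates a state and its log_det_jac together for linear dynamics, where the Jacobian is a known constant matrix. It does not use TensorFlow or the bijector itself; the matrix A, the helper integrate_augmented and the fixed-step Euler scheme are assumptions made for this example only.

```r
# Linear dynamics d/dt state = A %*% state; the Jacobian of the right-hand side is A.
A <- matrix(c(-0.5, 0.2,
               0.1, 0.3), nrow = 2, byrow = TRUE)
state_time_derivative_fn <- function(time, state) as.vector(A %*% state)

# Augmented dynamics: returns the state derivative and the Jacobian trace.
augmented_fn <- function(time, state) {
  list(state_time_derivative_fn(time, state), sum(diag(A)))
}

# Hypothetical helper: fixed-step Euler integration of state and log_det_jac.
integrate_augmented <- function(x, initial_time = 0, final_time = 1, n_steps = 1000) {
  dt <- (final_time - initial_time) / n_steps
  state <- x
  log_det_jac <- 0
  for (i in seq_len(n_steps)) {
    time <- initial_time + (i - 1) * dt
    d <- augmented_fn(time, state)
    state <- state + dt * d[[1]]
    log_det_jac <- log_det_jac + dt * d[[2]]
  }
  list(y = state, log_det_jac = log_det_jac)
}

result <- integrate_augmented(c(1, 2))

# For linear dynamics the exact flow map is the matrix exponential of
# A * (final_time - initial_time), whose log-determinant is the trace of A
# times the elapsed time.
expected_log_det <- sum(diag(A)) * (1 - 0)
```

Comparing result$log_det_jac with expected_log_det shows that integrating the trace along the trajectory recovers the log-determinant of the forward map in this simple case.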
The FFJORD constructor takes two function arguments, ode_solve_fn and
trace_augmentation_fn, that customize the integration of the
differential equation and the trace estimation.
Differential equation integration is performed by a call to ode_solve_fn.
Custom ode_solve_fn must accept the following arguments:
ode_fn(time, state): Differential equation to be solved.
initial_time: Scalar float or floating Tensor representing the initial time.
initial_state: Floating Tensor representing the initial state.
solution_times: 1D floating Tensor of solution times.
It must return a Tensor of shape [solution_times$shape, initial_state$shape]
representing the state values evaluated at solution_times. In addition,
ode_solve_fn must support nested structures. For more details see the
interface of tfp$math$ode$Solver$solve().
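The argument order and return shape described above can be sketched in base R. This is only an illustration under simplifying assumptions: it works on plain numeric vectors rather than Tensors, does not handle nested structures, assumes solution_times is increasing and not earlier than initial_time, and uses a fixed-step Euler scheme. The name euler_ode_solve_fn and the n_steps parameter are hypothetical; a solver passed to tfb_ffjord() would need to operate on Tensors.

```r
# Hypothetical fixed-step Euler solver with the ode_solve_fn argument order.
# Returns one row per solution time and one column per state element.
euler_ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times,
                               n_steps = 200) {
  out <- matrix(NA_real_, nrow = length(solution_times), ncol = length(initial_state))
  state <- initial_state
  time <- initial_time
  for (k in seq_along(solution_times)) {
    # Integrate from the current time to the k-th requested solution time.
    dt <- (solution_times[k] - time) / n_steps
    for (i in seq_len(n_steps)) {
      state <- state + dt * ode_fn(time, state)
      time <- time + dt
    }
    time <- solution_times[k]
    out[k, ] <- state
  }
  out
}

# d/dt state = -state has the exact solution state(t) = state(0) * exp(-t).
decay_fn <- function(time, state) -state
solution <- euler_ode_solve_fn(decay_fn, 0, c(1, 2), solution_times = c(0.5, 1))
```

Here solution has shape [length(solution_times), length(initial_state)], mirroring the [solution_times$shape, initial_state$shape] layout required of a custom solver.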
The trace estimate is computed simultaneously with state_time_derivative
using an augmented_state_time_derivative_fn that is generated by
trace_augmentation_fn. trace_augmentation_fn takes
state_time_derivative_fn, state$shape and state$dtype as arguments and
returns an augmented_state_time_derivative_fn callable that computes both
state_time_derivative and the unreduced trace_estimation.
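To make the idea of a stochastic trace estimate concrete, here is a base R sketch of a Hutchinson-style estimator, which relies on the identity Tr(J) = E[eps' J eps] for a random vector eps with zero mean and identity covariance. It is not the implementation behind tfp$bijectors$ffjord$trace_jacobian_hutchinson: it uses plain numeric vectors, approximates the Jacobian-vector product with a central finite difference instead of automatic differentiation, and treats "unreduced" as the elementwise product eps * (J eps), which is an assumption. The names hutchinson_augmentation_fn, linear_fn and the step size h are hypothetical.

```r
# Hypothetical helper in the spirit of trace_augmentation_fn.
# state_dtype is accepted only to mirror the documented signature; it is unused.
hutchinson_augmentation_fn <- function(ode_fn, state_shape, state_dtype = "double",
                                       h = 1e-5) {
  # One Rademacher probe vector (entries +1 or -1), drawn once and reused.
  eps <- sample(c(-1, 1), size = prod(state_shape), replace = TRUE)
  function(time, state) {
    # Central finite difference approximating the Jacobian-vector product J %*% eps.
    jvp <- (ode_fn(time, state + h * eps) - ode_fn(time, state - h * eps)) / (2 * h)
    # eps * jvp is the unreduced estimate; its sum estimates the trace of J.
    list(ode_fn(time, state), eps * jvp)
  }
}

# Linear dynamics, so the Jacobian is A and the exact trace is known.
A <- matrix(c(-0.5, 0.2,
               0.1, 0.3), nrow = 2, byrow = TRUE)
linear_fn <- function(time, state) as.vector(A %*% state)

# Average many independent single-probe estimates.
set.seed(1)
estimates <- replicate(2000, {
  augmented_fn <- hutchinson_augmentation_fn(linear_fn, state_shape = 2)
  sum(augmented_fn(0, c(1, 2))[[2]])
})
trace_estimate <- mean(estimates)
exact_trace <- sum(diag(A))
```

Each single-probe estimate is noisy, but the average trace_estimate approaches exact_trace as more probes are used, without ever forming the full Jacobian.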
Custom ode_solve_fn and trace_augmentation_fn examples:
# custom_solver_fn: `function(f, t_initial, t_solutions, y_initial, ...)`
# ... : placeholder for any additional arguments to pass to custom_solver_fn.
ode_solve_fn <- function(ode_fn, initial_time, initial_state, solution_times) {
  custom_solver_fn(ode_fn, initial_time, solution_times, initial_state, ...)
}
ffjord <- tfb_ffjord(state_time_derivative_fn, ode_solve_fn = ode_solve_fn)

# state_time_derivative_fn: `function(time, state)`
# trace_jac_fn: `function(time, state)` unreduced Jacobian trace function
trace_augmentation_fn <- function(ode_fn, state_shape, state_dtype) {
  augmented_ode_fn <- function(time, state) {
    list(ode_fn(time, state), trace_jac_fn(time, state))
  }
  augmented_ode_fn
}
ffjord <- tfb_ffjord(state_time_derivative_fn, trace_augmentation_fn = trace_augmentation_fn)
For more details on FFJORD and continuous normalizing flows see Chen et al. (2018), Grathwohl et al. (2018).
Chen, T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. K. (2018). Neural ordinary differential equations. In Advances in Neural Information Processing Systems (pp. 6571-6583).
For usage examples see tfb_forward(), tfb_inverse(), tfb_inverse_log_det_jacobian().
Other bijectors:
tfb_absolute_value(),
tfb_affine_linear_operator(),
tfb_affine_scalar(),
tfb_affine(),
tfb_ascending(),
tfb_batch_normalization(),
tfb_blockwise(),
tfb_chain(),
tfb_cholesky_outer_product(),
tfb_cholesky_to_inv_cholesky(),
tfb_correlation_cholesky(),
tfb_cumsum(),
tfb_discrete_cosine_transform(),
tfb_expm1(),
tfb_exp(),
tfb_fill_scale_tri_l(),
tfb_fill_triangular(),
tfb_glow(),
tfb_gompertz_cdf(),
tfb_gumbel_cdf(),
tfb_gumbel(),
tfb_identity(),
tfb_inline(),
tfb_invert(),
tfb_iterated_sigmoid_centered(),
tfb_kumaraswamy_cdf(),
tfb_kumaraswamy(),
tfb_lambert_w_tail(),
tfb_masked_autoregressive_default_template(),
tfb_masked_autoregressive_flow(),
tfb_masked_dense(),
tfb_matrix_inverse_tri_l(),
tfb_matvec_lu(),
tfb_normal_cdf(),
tfb_ordered(),
tfb_pad(),
tfb_permute(),
tfb_power_transform(),
tfb_rational_quadratic_spline(),
tfb_rayleigh_cdf(),
tfb_real_nvp_default_template(),
tfb_real_nvp(),
tfb_reciprocal(),
tfb_reshape(),
tfb_scale_matvec_diag(),
tfb_scale_matvec_linear_operator(),
tfb_scale_matvec_lu(),
tfb_scale_matvec_tri_l(),
tfb_scale_tri_l(),
tfb_scale(),
tfb_shifted_gompertz_cdf(),
tfb_shift(),
tfb_sigmoid(),
tfb_sinh_arcsinh(),
tfb_sinh(),
tfb_softmax_centered(),
tfb_softplus(),
tfb_softsign(),
tfb_split(),
tfb_square(),
tfb_tanh(),
tfb_transform_diagonal(),
tfb_transpose(),
tfb_weibull_cdf(),
tfb_weibull()