
RBERT (version 0.1.11)

get_assignment_map_from_checkpoint: Compute the intersection of the current variables and checkpoint variables

Description

Returns the intersection (not the union, as the Python docs say -JDB) of the sets of variable names from the current graph and the checkpoint.
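As a rough illustration of what "intersection" means here, the sketch below uses plain character vectors as hypothetical stand-ins for the graph and checkpoint variable names. It does not call TensorFlow and is not RBERT's implementation. Names on graph variables carry a ":0" output-index suffix that checkpoint names lack, so the sketch strips that suffix before intersecting.

```r
# Hypothetical stand-ins: names as they would appear on graph variables
# (with a ":0" suffix) and in a checkpoint (without one).
graph_var_names <- c(
  "bert/encoder/layer_9/output/dense/bias:0",
  "bert/encoder/layer_9/output/dense/kernel:0",
  "cls/new_head/weights:0"
)
ckpt_var_names <- c(
  "bert/embeddings/word_embeddings",
  "bert/encoder/layer_9/output/dense/bias",
  "bert/encoder/layer_9/output/dense/kernel"
)

# Drop the ":<output index>" suffix to recover the "base" variable name.
strip_output_index <- function(x) sub(":[0-9]+$", "", x)

# Keep only names present in both sets (an intersection, not a union).
shared <- intersect(strip_output_index(graph_var_names), ckpt_var_names)
print(shared)
```

Here shared holds only the two layer_9 names; "cls/new_head/weights" and "bert/embeddings/word_embeddings" each appear on one side only and are dropped.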

Usage

get_assignment_map_from_checkpoint(tvars, init_checkpoint)

Arguments

tvars

List of training variables in the current model.

init_checkpoint

Character; path to the checkpoint directory plus the checkpoint name stub (e.g. "bert_model.ckpt"). The path must be absolute and explicit, starting with "/".
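The path requirement can be checked before calling the function. The helper check_checkpoint_path() below is hypothetical and not part of RBERT, and the "/home/user" location is made up for illustration. The helper only tests the documented rule that the path is absolute and starts with "/"; it does not check that any checkpoint files exist.

```r
# Hypothetical helper (not part of RBERT): enforce the documented rule that
# init_checkpoint is a single absolute path starting with "/".
check_checkpoint_path <- function(init_checkpoint) {
  if (!is.character(init_checkpoint) || length(init_checkpoint) != 1L) {
    stop("init_checkpoint must be a single character string")
  }
  if (!startsWith(init_checkpoint, "/")) {
    stop("init_checkpoint must be an absolute path starting with \"/\"")
  }
  invisible(init_checkpoint)
}

# Directory plus checkpoint name stub (hypothetical location).
good_path <- file.path(
  "/home/user", "BERT_checkpoints", "uncased_L-12_H-768_A-12", "bert_model.ckpt"
)
check_checkpoint_path(good_path)

# A relative path is rejected; capture the error message instead of failing.
rel_result <- tryCatch(
  check_checkpoint_path("BERT_checkpoints/bert_model.ckpt"),
  error = function(e) conditionMessage(e)
)
print(rel_result)
```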

Value

List with two elements: the assignment map and the initialized variable names. The assignment map is a list of the "base" variable names that are in both the current computational graph and the checkpoint. The initialized variable names list contains both the base names and the base names with ":0" appended (in TensorFlow, ":0" denotes the first output tensor of an operation). (This seems redundant to me. I assume it will make sense later. -JDB)
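A minimal sketch of the documented return shape follows, built from hypothetical shared names. The element names assignment_map and initialized_variable_names are illustrative assumptions, not taken from RBERT. One plausible reason for storing both forms (an assumption, not stated above) is that a graph variable's full name includes the ":0" suffix, so either the full name or the base name can be looked up directly without string manipulation.

```r
# Hypothetical base names found in both the graph and the checkpoint.
shared <- c(
  "bert/encoder/layer_9/output/dense/bias",
  "bert/encoder/layer_9/output/dense/kernel"
)

# Mimic the documented two-element return value. The element names here
# are illustrative assumptions, not RBERT's.
result <- list(
  assignment_map = as.list(shared),
  initialized_variable_names = c(shared, paste0(shared, ":0"))
)

# A graph variable's full name (with ":0") and its base name both match.
full_name <- "bert/encoder/layer_9/output/dense/bias:0"
print(full_name %in% result$initialized_variable_names)
print(sub(":0$", "", full_name) %in% result$initialized_variable_names)
```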

Details

Note that a TensorFlow checkpoint is not the same as a saved model. A saved model contains a complete description of the computational graph and is sufficient to reconstruct the entire model, while a checkpoint contains only the parameter values (and variable names), and so requires a specification of the original model structure to reconstruct the computational graph. -JDB

Examples

# NOT RUN {
# Just for illustration: create a "model" with a couple variables
# that overlap some variable names in the BERT checkpoint.
with(tensorflow::tf$variable_scope("bert",
  reuse = tensorflow::tf$AUTO_REUSE
), {
  test_ten1 <- tensorflow::tf$get_variable(
    "encoder/layer_9/output/dense/bias",
    shape = c(1L, 2L, 3L)
  )
  test_ten2 <- tensorflow::tf$get_variable(
    "encoder/layer_9/output/dense/kernel",
    shape = c(1L, 2L, 3L)
  )
})
tvars <- tensorflow::tf$get_collection(
  tensorflow::tf$GraphKeys$GLOBAL_VARIABLES
)
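temp_dir <- tempdir()
# The BERT checkpoint files are assumed to have been downloaded and unzipped
# into this directory beforehand; this example does not fetch them.
init_checkpoint <- file.path(
  temp_dir,
  "BERT_checkpoints",
  "uncased_L-12_H-768_A-12",
  "bert_model.ckpt"
)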

amap <- get_assignment_map_from_checkpoint(tvars, init_checkpoint)
# }
