From 7134466ab5e72321135c720562e99f4bd8610de0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 21 Feb 2024 22:50:38 -0800 Subject: [PATCH] Adds Checkpoint adapter for TPU Embeddings. PiperOrigin-RevId: 609249665 --- tf_keras/optimizers/legacy/optimizer_v2.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tf_keras/optimizers/legacy/optimizer_v2.py b/tf_keras/optimizers/legacy/optimizer_v2.py index 8ebc2e593..78905c92e 100644 --- a/tf_keras/optimizers/legacy/optimizer_v2.py +++ b/tf_keras/optimizers/legacy/optimizer_v2.py @@ -1540,7 +1540,8 @@ def _restore_slot_variable(self, slot_name, variable, slot_variable): checkpoint_position.restore(slot_variable) def _create_or_restore_slot_variable( - self, slot_variable_position, slot_name, variable + self, slot_variable_position, slot_name, variable, + slot_variable_shape=None ): """Returns the slot variable that should have a value restored into it. @@ -1563,6 +1564,8 @@ def _create_or_restore_slot_variable( indicating the slot variable `Trackable` object to be restored. slot_name: The name of this `Optimizer`'s slot to restore into. variable: The variable object this slot is being created for. + slot_variable_shape: (Optional) Shape of the required slot variable. + When not provided, the shape is same as the value in checkpoint. Returns: A slot variable that should have a value restored into it, or None if @@ -1602,7 +1605,7 @@ def _create_or_restore_slot_variable( var=variable, initializer=initializer, slot_name=slot_name, - shape=slot_variable_position.value_shape(), + shape=slot_variable_position.value_shape() if slot_variable_shape is None else slot_variable_shape, ) # Slot variables are not owned by any one object (because we don't # want to save the slot variable if the optimizer is saved without