Skip to content

Commit

Permalink
Add sparse_core_embedding_config to TPUEmbeddingLayer and add device …
Browse files Browse the repository at this point in the history
…assignment to VF in TPU distribute strategy.

PiperOrigin-RevId: 633971690
  • Loading branch information
ZhaoyueCheng authored and TensorFlow Recommenders Authors committed Jun 18, 2024
1 parent 3aa525a commit f4c6da5
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@
TPUEmbeddingType = (
TPUEmbeddingType | tf.tpu.experimental.embedding.TPUEmbeddingV2
)
if hasattr(tf.tpu.experimental.embedding, "SparseCoreEmbeddingConfig"):
SparseCoreEmbeddingConfig = (
tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig
)
else:
SparseCoreEmbeddingConfig = None # pylint: disable=invalid-name


def _normalize_and_prepare_optimizer(optimizer):
Expand Down Expand Up @@ -597,7 +603,8 @@ def __init__(
tf.tpu.experimental.embedding.FTRL]],
pipeline_execution_with_tensor_core: bool = False,
batch_size: Optional[int] = None,
embedding_feature: Optional[EmbeddingFeature] = None):
embedding_feature: Optional[EmbeddingFeature] = None,
sparse_core_embedding_config: Optional[SparseCoreEmbeddingConfig] = None):
"""A Keras layer for accelerated embedding lookups on TPU.
Args:
Expand All @@ -617,6 +624,8 @@ def __init__(
compatibility.
embedding_feature: EmbeddingFeature enum, inidicating which version of TPU
hardware the layer should run on.
sparse_core_embedding_config: SparseCoreEmbeddingConfig, inidicating
configuration for sparse core embedding when using TPUEmbedding V2
"""
super().__init__()
self._feature_config, self._table_config_map = (
Expand Down Expand Up @@ -654,6 +663,7 @@ def __init__(
self._using_tpu,
self._embedding_feature,
pipeline_execution_with_tensor_core,
sparse_core_embedding_config
)
self.batch_size = batch_size
self._tpu_call_id = 0
Expand All @@ -663,6 +673,7 @@ def _create_tpu_embedding_mid_level_api(
using_tpu: bool,
embedding_feature: Optional[EmbeddingFeature],
pipeline_execution_with_tensor_core: bool,
sparse_core_embedding_config: Optional[SparseCoreEmbeddingConfig],
) -> TPUEmbeddingType:
"""Creates TPU Embedding mid level API instance based on settings.
Expand All @@ -674,6 +685,8 @@ def _create_tpu_embedding_mid_level_api(
computations will overlap with the TensorCore computations (and hence
will be one step old with potential correctness drawbacks). Only used
when the embedding feature is set to be v1.
sparse_core_embedding_config: SparseCoreEmbeddingConfig used by TPU
` Embedding V2
Returns:
Instance of the TPUEmbedding mid level API.
Expand All @@ -699,6 +712,7 @@ def _create_tpu_embedding_mid_level_api(
self._feature_config,
self._optimizer,
pipeline_execution_with_tensor_core,
sparse_core_embedding_config,
)
else:
raise ValueError("TPUEmbeddingV2 is not supported in TF.")
Expand Down

0 comments on commit f4c6da5

Please sign in to comment.