Skip to content

Commit

Permalink
Remove special casing by sparse core as it did not work as intended
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702977131
  • Loading branch information
TensorFlow Recommenders Authors committed Dec 5, 2024
1 parent 151a970 commit c60e42e
Showing 1 changed file with 3 additions and 18 deletions.
21 changes: 3 additions & 18 deletions tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,8 @@
"""Keras interface for TPU Embeddings in TF2."""

from typing import Any, Dict, Iterable, Optional, Union
import tensorflow.compat.v2 as tf


# From tensorflow/python/layers/sparse_core_util.py to avoid circular dependency
# and avoid creating another separate file.
def has_sparse_core() -> bool:
"""Check to see if SparseCore is available."""
strategy = tf.distribute.get_strategy()
if not isinstance(
strategy,
(tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy),
):
return False
return (
strategy.extended.tpu_hardware_feature.embedding_feature
== tf.tpu.experimental.HardwareFeature.EmbeddingFeature.V2
)
import tensorflow.compat.v2 as tf

_SLOT_NAME_MAPPING = {
# Slot names in Keras optimizer v2 are different compared to the slot names
Expand Down Expand Up @@ -636,7 +621,7 @@ def __init__(
will be one step old with potential correctness drawbacks). Set to True
for improved performance.
batch_size: Batch size of the input feature. Deprecated, support backward
compatibility. Set None for sparse core for proper shape inference.
compatibility.
embedding_feature: EmbeddingFeature enum, inidicating which version of TPU
hardware the layer should run on.
sparse_core_embedding_config: SparseCoreEmbeddingConfig, inidicating
Expand Down Expand Up @@ -680,7 +665,7 @@ def __init__(
pipeline_execution_with_tensor_core,
sparse_core_embedding_config
)
self.batch_size = None if has_sparse_core() else batch_size
self.batch_size = batch_size
self._tpu_call_id = 0

def _create_tpu_embedding_mid_level_api(
Expand Down

0 comments on commit c60e42e

Please sign in to comment.