Skip to content

Commit

Permalink
Update table_stacking_test and embedding_test to work for tpuv6.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718924888
  • Loading branch information
cantonios authored and Google-ML-Automation committed Jan 23, 2025
1 parent d44d459 commit 8bc0241
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 52 deletions.
1 change: 1 addition & 0 deletions jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ pytype_strict_contrib_test(
"requires-tpu",
],
deps = [
":test_utils",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
"//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",
Expand Down
95 changes: 62 additions & 33 deletions jax_tpu_embedding/sparsecore/lib/nn/tests/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jax import sharding
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from jax_tpu_embedding.sparsecore.lib.nn.tests import test_utils
from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2
import numpy as np

Expand Down Expand Up @@ -72,9 +73,13 @@ def test_get_valid_table_specs(self):
feature_specs,
global_device_count=jax.device_count(),
)
total_sc_in_test = test_utils.NUM_SC_PER_DEVICE * jax.device_count()
padded_vocab_size = test_utils.round_up_to_multiple(
table_spec.vocabulary_size, 8 * total_sc_in_test
)
expected_stacked_table_spec = embedding_spec.StackedTableSpec(
stack_name=table_spec.name,
stack_vocab_size=128,
stack_vocab_size=padded_vocab_size,
stack_embedding_dim=16,
combiner=table_spec.combiner,
optimizer=table_spec.optimizer,
Expand All @@ -97,7 +102,7 @@ def test_get_valid_table_specs(self):
_setting_in_stack=embedding_spec.TableSettingInStack(
stack_name=table_spec.name,
padded_embedding_dim=table_spec.embedding_dim,
padded_vocab_size=128,
padded_vocab_size=padded_vocab_size,
row_offset_in_shard=0,
shard_rotation=0,
),
Expand Down Expand Up @@ -238,9 +243,13 @@ def test_prepare_features_for_training_with_feature_stacking(self):
(feature_spec_a, feature_spec_b),
global_device_count=jax.device_count(),
)
total_sc_in_test = test_utils.NUM_SC_PER_DEVICE * jax.device_count()
padded_vocab_size = test_utils.round_up_to_multiple(
table_spec.vocabulary_size, 8 * total_sc_in_test
)
expected_stacked_table_spec = embedding_spec.StackedTableSpec(
stack_name="table",
stack_vocab_size=128, # Vocab round up
stack_vocab_size=padded_vocab_size,
stack_embedding_dim=16, # Dim round up
combiner="sum",
optimizer=embedding_spec.SGDOptimizerSpec(),
Expand Down Expand Up @@ -481,19 +490,25 @@ def test_init_embedding_variables_for_pmap(self, optimizer_spec):
feature_specs = [feature_a_spec, feature_b_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)
# Assert on the preconditions.
self.assertLen(feature_specs, 2)
padded_vocab_size_a = vocab_size_a + 96 # extra padding rows = 96.
total_sc_in_test = test_utils.NUM_SC_PER_DEVICE * jax.device_count()
padded_vocab_size_a = test_utils.round_up_to_multiple(
vocab_size_a, 8 * total_sc_in_test
)
self.assertEqual(
feature_a_spec.table_spec.setting_in_stack.padded_vocab_size,
padded_vocab_size_a,
)
padded_vocab_size_b = test_utils.round_up_to_multiple(
vocab_size_b, 8 * total_sc_in_test
)
self.assertEqual(
feature_b_spec.table_spec.setting_in_stack.padded_vocab_size,
vocab_size_b,
padded_vocab_size_b,
)
updated_table_specs = [f.table_spec for f in feature_specs]

Expand All @@ -520,8 +535,8 @@ def test_init_embedding_variables_for_pmap(self, optimizer_spec):
self.assertEqual(
variable.shape,
(
4,
(padded_vocab_size_a + vocab_size_b) // len(devices),
jax.device_count(),
(padded_vocab_size_a + padded_vocab_size_b) // len(devices),
16,
),
)
Expand All @@ -534,7 +549,7 @@ def test_init_embedding_variables_for_pmap(self, optimizer_spec):
embedding_variables["table_a_table_b"].table.addressable_shards,
(
1,
(padded_vocab_size_a + vocab_size_b) // len(devices),
(padded_vocab_size_a + padded_vocab_size_b) // len(devices),
16,
),
)
Expand Down Expand Up @@ -595,19 +610,25 @@ def test_init_embedding_variables_stacking_for_jit(self, optimizer_spec):
feature_specs = [feature_a_spec, feature_b_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)
# Assert on the preconditions.
self.assertLen(feature_specs, 2)
padded_vocab_size_a = vocab_size_a + 96 # extra padding rows = 96.
total_sc_in_test = test_utils.NUM_SC_PER_DEVICE * jax.device_count()
padded_vocab_size_a = test_utils.round_up_to_multiple(
vocab_size_a, 8 * total_sc_in_test
)
self.assertEqual(
feature_a_spec.table_spec.setting_in_stack.padded_vocab_size,
padded_vocab_size_a,
)
padded_vocab_size_b = test_utils.round_up_to_multiple(
vocab_size_b, 8 * total_sc_in_test
)
self.assertEqual(
feature_b_spec.table_spec.setting_in_stack.padded_vocab_size,
vocab_size_b,
padded_vocab_size_b,
)
updated_table_specs = [f.table_spec for f in feature_specs]

Expand All @@ -632,7 +653,7 @@ def test_init_embedding_variables_stacking_for_jit(self, optimizer_spec):
for variable in jax.tree.leaves(embedding_variables):
self.assertEqual(
variable.shape,
(padded_vocab_size_a + vocab_size_b, embedding_dim),
(padded_vocab_size_a + padded_vocab_size_b, embedding_dim),
)

self.assertEqual(
Expand All @@ -643,7 +664,7 @@ def test_init_embedding_variables_stacking_for_jit(self, optimizer_spec):
self.assert_all_shards_shape(
variable.addressable_shards,
(
(padded_vocab_size_a + vocab_size_b) // len(devices),
(padded_vocab_size_a + padded_vocab_size_b) // len(devices),
16,
),
)
Expand Down Expand Up @@ -692,7 +713,7 @@ def test_malformated_partition_for_init_embedding_variables(self, pspec):
feature_specs = [feature_a_spec, feature_a_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)

Expand Down Expand Up @@ -744,7 +765,7 @@ def test_non_default_device_mesh_for_init_embedding_variables(self):
feature_specs = [feature_a_spec, feature_a_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)

Expand Down Expand Up @@ -820,7 +841,7 @@ def test_muti_dimensional_mesh_for_init_embedding_variables(self):
feature_specs = [feature_a_spec, feature_b_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)

Expand Down Expand Up @@ -892,28 +913,36 @@ def test_create_proto_from_feature_specs(self):
feature_specs = [feature_a_spec, feature_b_spec]
embedding.auto_stack_tables(
feature_specs,
num_sc_per_device=4,
num_sc_per_device=test_utils.NUM_SC_PER_DEVICE,
global_device_count=jax.device_count(),
)
expected_proto = embedding_spec_pb2.EmbeddingSpecProto()
num_sparsecores = test_utils.NUM_SC_PER_DEVICE * jax.device_count()
padded_vocab_size_a = test_utils.round_up_to_multiple(
feature_a_spec.table_spec.vocabulary_size, 8 * num_sparsecores
)
padded_vocab_size_b = test_utils.round_up_to_multiple(
feature_b_spec.table_spec.vocabulary_size, 8 * num_sparsecores
)
stack_vocab_size = padded_vocab_size_a + padded_vocab_size_b
text_format.Parse(
"""stacked_table_specs {
f"""stacked_table_specs {{
stack_name: "table_a_table_b"
stack_vocab_size: 256
stack_vocab_size: {stack_vocab_size}
stack_embedding_dim: 16
total_sample_count: 32
max_ids_per_partition: 256
max_unique_ids_per_partition: 256
num_sparsecores: 16
table_specs {
num_sparsecores: {num_sparsecores}
table_specs {{
table_name: "table_a"
vocab_size: 32
embedding_dim: 14
padded_vocab_size: 128
padded_vocab_size: {padded_vocab_size_a}
padded_embedding_dim: 16
row_offset_in_shard: 0
shard_rotation: 0
feature_specs {
feature_specs {{
feature_name: "feature_a"
input_shape: 16
input_shape: 1
Expand All @@ -922,28 +951,28 @@ def test_create_proto_from_feature_specs(self):
row_offset: 0
col_offset: 0
col_shift: 0
}
}
table_specs {
}}
}}
table_specs {{
table_name: "table_b"
vocab_size: 128
embedding_dim: 14
padded_vocab_size: 128
padded_vocab_size: {padded_vocab_size_b}
padded_embedding_dim: 16
row_offset_in_shard: 8
shard_rotation: 4
feature_specs {
feature_specs {{
feature_name: "feature_b"
input_shape: 16
input_shape: 1
output_shape: 16
output_shape: 14
row_offset: 16
col_offset: 128
col_offset: {padded_vocab_size_a}
col_shift: 4
}
}
}""",
}}
}}
}}""",
expected_proto,
)
actual = embedding.create_proto_from_feature_specs(
Expand Down
Loading

0 comments on commit 8bc0241

Please sign in to comment.