From f728c94665b36c89b29ff92fc0d865e91ec920f0 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Sat, 19 Oct 2024 03:41:12 -0700 Subject: [PATCH] Partially back out "[fbgemm_gpu] Add support for int64_t indices and offsets in TBE inference [7C/N]" (#3257) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3257 X-link: https://github.com/facebookresearch/FBGEMM/pull/358 Original commit changeset: 270834722e8b Original Phabricator Diff: D63778645 The original diff D63778645 contained two features: (a) extending the remap index kernels from supporting int32-only to supporting both int32 and int64 (b) updating the frontend code to construct int64 remap index arrays Some downstream code has picked up D63778645 and has generated models with int64 remapping indices, which fail 3 downstream unit tests. Since (b) is the problematic feature that is breaking downstream, only (b) has been reverted in this diff, as (a) is now needed for those unit tests to pass. Reviewed By: jianyuh Differential Revision: D64618221 fbshipit-source-id: 6c4838dbfaf301f1204d469d7f4bf7cfe4926b2e --- .../embedding_forward_quantized_cpu_template.cpp | 2 +- .../split_table_batched_embeddings_ops_inference.py | 11 +++++------ .../test/tbe/utils/split_embeddings_utils_test.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp index 4429e580d0..ce40df0789 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp @@ -78,7 +78,7 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu( const auto* dense_indices_acc = dense_indices.data_ptr(); const auto* offsets_acc = offsets.data_ptr(); - auto hash_table_acc = hash_table.accessor<int64_t, 2>(); + auto hash_table_acc = hash_table.accessor<int32_t, 2>(); const auto hash_table_offsets_acc = 
hash_table_offsets.accessor<int64_t, 1>(); for (const auto t : c10::irange(T)) { diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 541297b90c..e3b5d51af3 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -397,14 +397,13 @@ def max_ty_D(ty: SparseType) -> int: self.assign_embedding_weights(weight_lists) # Handle index remapping for embedding pruning. - # All buffers are int64 in order to support both int32 and int64 indices. self.register_buffer( "index_remappings_array_offsets", torch.empty(0, device=self.current_device, dtype=torch.int64), ) self.register_buffer( "index_remappings_array", - torch.empty(0, device=self.current_device, dtype=torch.int64), + torch.empty(0, device=self.current_device, dtype=torch.int32), ) self.register_buffer( "index_remapping_hash_table_offsets", @@ -412,7 +411,7 @@ def max_ty_D(ty: SparseType) -> int: ) self.register_buffer( "index_remapping_hash_table", - torch.empty(0, device=self.current_device, dtype=torch.int64), + torch.empty(0, device=self.current_device, dtype=torch.int32), ) self.register_buffer( "original_rows_per_table", @@ -1529,11 +1528,11 @@ def set_index_remappings_array( index_remappings_filter_nones.append(mapping) if len(index_remappings_filter_nones) == 0: self.index_remappings_array = torch.empty( - 0, dtype=torch.int64, device=self.current_device + 0, dtype=torch.int32, device=self.current_device ) else: self.index_remappings_array = torch.cat(index_remappings_filter_nones).to( - dtype=torch.int64, device=self.current_device + self.current_device ) def set_index_remappings( @@ -1556,7 +1555,7 @@ def set_index_remappings( ] hash_table = torch.empty( (sum(capacities), 2), - dtype=torch.int64, + dtype=torch.int32, ) hash_table[:, :] = -1 hash_table_offsets = torch.tensor([0] + 
list(accumulate(capacities))).long() diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py index 5d9b3eabe6..1d475909c1 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py @@ -469,7 +469,7 @@ def test_pruning( # Initialize and insert Hashmap index remapping based data structure hash_table = torch.empty( (sum(capacities), 2), - dtype=torch.int64, + dtype=torch.int32, ) hash_table[:, :] = -1 hash_table_offsets = torch.tensor([0] + np.cumsum(capacities).tolist()).long() @@ -486,7 +486,7 @@ def test_pruning( # Initialize and insert Array index remapping based data structure index_remappings_array = torch.tensor( [-1] * original_E * T, - dtype=torch.int64, + dtype=torch.int32, device=current_device, ) index_remappings_array_offsets = torch.empty( @@ -498,7 +498,7 @@ def test_pruning( for t in range(T): indice_t = (indices.view(T, B, L))[t].long().view(-1).to(current_device) dense_indice_t = ( - (dense_indices.view(T, B, L))[t].long().view(-1).to(current_device) + (dense_indices.view(T, B, L))[t].view(-1).to(current_device) ) selected_indices = torch.add(indice_t, t * original_E)[:E] index_remappings_array[selected_indices] = dense_indice_t