diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
index fe5c6f4e1a..720d0a2612 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
@@ -380,18 +380,18 @@ namespace internal {
 
 namespace {
 
-template <typename scalar_t, bool IS_VALUE_PAIR>
+template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
 void csr2csc_template_(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
   csc.num_non_zero_columns = 0;
-  int64_t nnz = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto nnz = csr_offsets[table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
   if (nnz == 0) {
     return;
@@ -407,7 +407,7 @@ void csr2csc_template_(
   [[maybe_unused]] int column_ptr_curr = 0;
   bool is_shared_table =
       table_to_feature_offset[1] > table_to_feature_offset[0] + 1;
-  auto NS = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto NS = csr_offsets[(size_t)table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
 
   using pair_t = std::pair<int, scalar_t>;
@@ -432,9 +432,9 @@ void csr2csc_template_(
 #pragma omp parallel for
     for (int b = 0; b < B; ++b) {
       const auto FBb = feature * B + b;
-      int64_t pool_begin = csr_offsets[FBb];
-      int64_t pool_end = csr_offsets[FBb + 1];
-      int64_t L = pool_end - pool_begin;
+      const auto pool_begin = csr_offsets[FBb];
+      const auto pool_end = csr_offsets[FBb + 1];
+      const auto L = pool_end - pool_begin;
       // MEAN pooling will not work with indice_weights!
       double scale_factor =
           (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
@@ -581,39 +581,40 @@ void csr2csc_template_(
   assert(column_ptr_curr == nnz);
 }
 
-#define INSTANTIATE_BATCHED_CSR2CSC(SCALAR_T)                  \
-  template void csr2csc_template_<SCALAR_T, true>(            \
-      HyperCompressedSparseColumn & csc,                       \
-      int B,                                                   \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,       \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,       \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights,      \
-      int64_t pooling_mode,                                    \
-      const int* table_to_feature_offset,                      \
-      int64_t num_embeddings);                                 \
-                                                               \
-  template void csr2csc_template_<SCALAR_T, false>(            \
-      HyperCompressedSparseColumn & csc,                       \
-      int B,                                                   \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,       \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,       \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights,      \
-      int64_t pooling_mode,                                    \
-      const int* table_to_feature_offset,                      \
+#define INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, is_value_pair) \
+  template void csr2csc_template_<index_t, scalar_t, is_value_pair>(    \
+      HyperCompressedSparseColumn & csc,                                 \
+      int B,                                                             \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,                 \
+      const at::TensorAccessor<index_t, 1>& csr_indices,                 \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights,                \
+      int64_t pooling_mode,                                              \
+      const int* table_to_feature_offset,                                \
       int64_t num_embeddings);
 
-INSTANTIATE_BATCHED_CSR2CSC(float)
-INSTANTIATE_BATCHED_CSR2CSC(double)
-#undef INSTANTIATE_BATCHED_CSR2CSC
+#define INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, scalar_t)   \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, true);  \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, false);
+
+#define INSTANTIATE_CSR2CSC_TEMPLATE_2(index_t)    \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, float);  \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, double);
+
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int32_t);
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_2
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_1
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_0
 
 } // namespace
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
@@ -621,7 +622,7 @@ void csr2csc(
   bool has_weights = csr_weights.data() != nullptr;
   if (has_weights ||
       static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/true>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/true>(
         csc,
         B,
         csr_offsets,
@@ -631,7 +632,7 @@ void csr2csc(
         table_to_feature_offset,
         num_embeddings);
   } else {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/false>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/false>(
         csc,
         B,
         csr_offsets,
@@ -643,25 +644,26 @@ void csr2csc(
   }
 }
 
-template void csr2csc<float>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<float, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)              \
+  template void csr2csc<index_t, scalar_t>(                   \
+      HyperCompressedSparseColumn & csc,                      \
+      int B,                                                  \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,      \
+      const at::TensorAccessor<index_t, 1>& csr_indices,      \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights,     \
+      int64_t pooling_mode,                                   \
+      const int* table_to_feature_offset,                     \
+      int64_t num_embeddings);
 
-template void csr2csc<double>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<double, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_1(index_t)     \
+  INSTANTIATE_CSR2CSC_0(index_t, float);  \
+  INSTANTIATE_CSR2CSC_0(index_t, double);
+
+INSTANTIATE_CSR2CSC_1(int32_t);
+INSTANTIATE_CSR2CSC_1(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_1
+#undef INSTANTIATE_CSR2CSC_0
 
 } // namespace internal
 
diff --git a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
index 908fe9fb11..2025f9d7fb 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h
@@ -116,12 +116,12 @@ struct HyperCompressedSparseColumn {
   }
 };
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
diff --git a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
index fbc5dc8b2c..8a94d1b370 100644
--- a/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
+++ b/fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp
@@ -15,11 +15,14 @@
 #include "fbgemm_gpu/embedding_forward_split_cpu.h"
 #include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu
 
-TEST(CpuKernelTest, csr2csc_test) {
+template <c10::ScalarType DType, typename index_t>
+void test_csr2csc() {
   internal::HyperCompressedSparseColumn csc;
   int B = 2;
-  at::Tensor offsets = torch::tensor({0, 4, 8});
-  at::Tensor indices = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9});
+  at::Tensor offsets =
+      torch::tensor({0, 4, 8}, torch::TensorOptions().dtype(DType));
+  at::Tensor indices = torch::tensor(
+      {1, 2, 4, 5, 4, 3, 2, 9}, torch::TensorOptions().dtype(DType));
   int64_t pooling_mode = (int64_t)fbgemm_gpu::PoolingMode::SUM;
   int table_to_feature_offset[2] = {0, 1};
   int num_embeddings = 10;
@@ -27,8 +30,8 @@ TEST(CpuKernelTest, csr2csc_test) {
   ::internal::csr2csc(
       csc,
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
+      offsets.accessor<index_t, 1>(),
+      indices.accessor<index_t, 1>(),
       at::TensorAccessor<at::acc_type<float, true>, 1>(
           nullptr, nullptr, nullptr), // no weights
       pooling_mode,
       table_to_feature_offset,
@@ -61,8 +64,8 @@ TEST(CpuKernelTest, csr2csc_test) {
   ::internal::csr2csc(
       csc_weighted,
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
+      offsets.accessor<index_t, 1>(),
+      indices.accessor<index_t, 1>(),
       indice_weights.accessor<at::acc_type<float, true>, 1>(),
       pooling_mode,
       table_to_feature_offset,
@@ -88,3 +91,11 @@ TEST(CpuKernelTest, csr2csc_test) {
     EXPECT_EQ(expect_weights[i], csc_weighted.weights[i]);
   }
 }
+
+TEST(CpuKernelTest, csr2csc_test_int32) {
+  test_csr2csc<torch::kInt32, int32_t>();
+}
+
+TEST(CpuKernelTest, csr2csc_test_int64) {
+  test_csr2csc<torch::kInt64, int64_t>();
+}
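Usage sketch (not part of this patch): with csr2csc now templated on index_t, a caller holding offsets/indices tensors of either int32 or int64 dtype could dispatch on the runtime dtype before invoking it. The wrapper name csr2csc_dispatch and the unweighted float call below are illustrative assumptions; only the csr2csc signature and its int32_t/int64_t x float/double explicit instantiations come from the diff above.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

#include "fbgemm_gpu/embedding_forward_split_cpu.h"

// Hypothetical wrapper: picks index_t from the dtype of `offsets`
// (kInt -> int32_t, kLong -> int64_t) and calls the templated csr2csc.
void csr2csc_dispatch(
    internal::HyperCompressedSparseColumn& csc,
    int B,
    const at::Tensor& offsets, // int32 or int64, same dtype as indices
    const at::Tensor& indices,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "csr2csc_dispatch", [&] {
    // Inside this lambda, index_t is int32_t or int64_t, matching the
    // explicit instantiations added by the diff.
    ::internal::csr2csc<index_t, float>(
        csc,
        B,
        offsets.accessor<index_t, 1>(),
        indices.accessor<index_t, 1>(),
        at::TensorAccessor<float, 1>(nullptr, nullptr, nullptr), // no weights
        pooling_mode,
        table_to_feature_offset,
        num_embeddings);
  });
}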