Skip to content

Commit

Permalink
Add check to ensure that there is enough room in permuted_indices (py…
Browse files Browse the repository at this point in the history
…torch#3403)

Summary:

In scope of the MTIA IG Sprint we encountered a crash (T208229934) that turned out to be an incorrectly provided argument for `permuted_lengths_sum` argument.


This argument suppose to speed the operation, but actually, the true value it suppose to substitute is computed every time anyway. 

One radical solution would be to drop the argument entirely, but it requires more thoughtful analysis. This diff just prevent clearly faulty case of allocated `permuted_indices` being less than required by the function logic.

Exactly this case was spotted in case of T208229934.

Differential Revision: D66274522
  • Loading branch information
Sergey Zimin authored and facebook-github-bot committed Nov 22, 2024
1 parent b94be33 commit 27fea8d
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ std::tuple<Tensor, Tensor, std::optional<Tensor>> permute_2D_sparse_data_cpu(
int64_t permuted_indices_size = 0;
if (permuted_lengths_sum.has_value()) {
permuted_indices_size = permuted_lengths_sum.value();

// Ensure there is enough space.
TORCH_CHECK(
permuted_indices_size >=
output_offsets_per_thread_cumsum[num_threads * FALSE_SHARING_PAD]);
} else {
permuted_indices_size =
output_offsets_per_thread_cumsum[num_threads * FALSE_SHARING_PAD];
Expand Down Expand Up @@ -842,6 +847,11 @@ std::tuple<Tensor, Tensor, std::optional<Tensor>> permute_1D_sparse_data_cpu(
int64_t permuted_indices_size = 0;
if (permuted_lengths_sum.has_value()) {
permuted_indices_size = permuted_lengths_sum.value();

// Ensure there is enough space.
TORCH_CHECK(
permuted_indices_size >=
output_offsets[permuted_lengths_size].item<int64_t>());
} else {
permuted_indices_size =
output_offsets[permuted_lengths_size].item<int64_t>();
Expand Down

0 comments on commit 27fea8d

Please sign in to comment.