Skip to content

Commit 52109df

Browse files
authored
Prevent potential oob read in matxOpTDKernel (#586)
During the last iteration when expanding the linear index into the N-D tensor index, there may be an out-of-bounds read access into the sizes array. The result of the read is unused and the iteration count is known at compile time, so the read may be removed by the compiler. Extract the last iteration into an epilogue to guarantee no out-of-bounds access. Signed-off-by: Thomas Benson <[email protected]>
1 parent 1a1ed4f commit 52109df

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

include/matx/executors/kernel.h

+13-7
Original file line numberDiff line numberDiff line change
@@ -203,20 +203,26 @@ __global__ void matxOpT4StrideKernel(Op op, index_t size0, index_t size1, index_
203203
template <class Op>
204204
__global__ void matxOpTDKernel(Op op, const std::array<index_t, Op::Rank()> sizes, index_t mult) {
205205
std::array<index_t, Op::Rank()> indices;
206+
207+
// This kernel is currently only used for ranks > 4. We assume the rank is
208+
// at least one in the following accesses into sizes and indices
209+
static_assert(Op::Rank() >= 1, "rank must exceed zero");
206210

207211
// Compute the index into the operator for this thread. N-D tensors require more computations
208212
// since we're limited to 3 dimensions in both grid and block, so we need to iterate to compute
209213
// our index.
210214
index_t x_abs = static_cast<index_t>(blockIdx.x) * blockDim.x + threadIdx.x;
211-
bool valid = x_abs < mult*sizes[0];
212-
#pragma unroll
213-
for (int r = 0; r < Op::Rank(); r++) {
214-
indices[r] = x_abs / mult;
215-
x_abs -= indices[r] * mult;
216-
mult /= sizes[r+1];
217-
}
215+
const bool valid = x_abs < mult*sizes[0];
218216

219217
if (valid) {
218+
#pragma unroll
219+
for (int r = 0; r < Op::Rank()-1; r++) {
220+
indices[r] = x_abs / mult;
221+
x_abs -= indices[r] * mult;
222+
mult /= sizes[r+1];
223+
}
224+
indices[Op::Rank()-1] = x_abs / mult;
225+
220226
if constexpr (std::is_pointer_v<Op>) {
221227
(*op)(indices);
222228
}

0 commit comments

Comments
 (0)