Skip to content

Commit 9f531d9

Browse files
Simplify bounds checks for multi-dimensional array accesses
This simplifies the array bounds check code that is emitted for multi-dimensional array accesses that use "regular" indexing, i.e. accesses of the form `A[i1,i2,...,iN]` to some `N`-dimensional array `A`. For example, with this change, the access `A[i,j,k]` to an array with the three dimensions `m`, `n` and `o` now leads to bounds checks that correspond to the following pseudo code: ``` if (i >= m) out_of_bounds_error(); else if (j >= n) out_of_bounds_error(); else if (k >= o) out_of_bounds_error(); ``` So far, the following more complicated bounds checks would have been emitted: ``` if (i >= m) out_of_bounds_error(); else if (j >= n) out_of_bounds_error(); else if (((k * n + j) * m + i) >= m * n * o) out_of_bounds_error(); ``` Julia also allows one-dimensional and "partial" linear indexing (see #14770), i.e. the number of indices used to access an array does not have to match the actual number of dimensions of the accessed array. For this case we still have use this old scheme. One motivation for this change was the following: expressions like `((k * n + j) * m + i)` are non-affine and Polly would not be able to analyze them. This change therefore also facilitates Polly's bounds check elimination logic, which would hoist such checks out of loops or may remove them entirely where possible.
1 parent e4b5233 commit 9f531d9

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

src/cgutils.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,14 @@ static void assign_arrayvar(jl_arrayvar_t &av, const jl_cgval_t &ainfo, jl_codec
12471247
builder.CreateStore(emit_arraysize(ainfo, i+1, ctx), av.sizes[i]);
12481248
}
12491249

1250+
// Returns the size of the array represented by `tinfo` for the given dimension `dim` if
1251+
// `dim` is a valid dimension, otherwise returns constant one.
1252+
static Value *emit_arraysize_for_unsafe_dim(const jl_cgval_t &tinfo, jl_value_t *ex, size_t dim,
1253+
size_t nd, jl_codectx_t *ctx)
1254+
{
1255+
return dim > nd ? ConstantInt::get(T_size, 1) : emit_arraysize(tinfo, ex, dim, ctx);
1256+
}
1257+
12501258
static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_t nd, jl_value_t **args,
12511259
size_t nidxs, jl_codectx_t *ctx)
12521260
{
@@ -1267,12 +1275,12 @@ static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_
12671275
for(size_t k=0; k < nidxs; k++) {
12681276
idxs[k] = emit_unbox(T_size, emit_expr(args[k], ctx), NULL);
12691277
}
1278+
Value *ii;
12701279
for(size_t k=0; k < nidxs; k++) {
1271-
Value *ii = builder.CreateSub(idxs[k], ConstantInt::get(T_size, 1));
1280+
ii = builder.CreateSub(idxs[k], ConstantInt::get(T_size, 1));
12721281
i = builder.CreateAdd(i, builder.CreateMul(ii, stride));
12731282
if (k < nidxs-1) {
1274-
Value *d =
1275-
k >= nd ? ConstantInt::get(T_size, 1) : emit_arraysize(ainfo, ex, k+1, ctx);
1283+
Value *d = emit_arraysize_for_unsafe_dim(ainfo, ex, k+1, nd, ctx);
12761284
#if CHECK_BOUNDS==1
12771285
if (bc) {
12781286
BasicBlock *okBB = BasicBlock::Create(jl_LLVMContext, "ib");
@@ -1287,9 +1295,21 @@ static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_
12871295
}
12881296
#if CHECK_BOUNDS==1
12891297
if (bc) {
1290-
Value *alen = emit_arraylen(ainfo, ex, ctx);
1291-
// if !(i < alen) goto error
1292-
builder.CreateCondBr(builder.CreateICmpULT(i, alen), endBB, failBB);
1298+
// We have already emitted a bounds check for each index except for
1299+
// the last one which we therefore have to do here.
1300+
bool linear_indexing = nidxs < nd;
1301+
if (linear_indexing) {
1302+
// Compare the linearized index `i` against the linearized size of
1303+
// the accessed array, i.e. `if !(i < alen) goto error`.
1304+
Value *alen = emit_arraylen(ainfo, ex, ctx);
1305+
builder.CreateCondBr(builder.CreateICmpULT(i, alen), endBB, failBB);
1306+
} else {
1307+
// Compare the last index of the access against the last dimension of
1308+
// the accessed array, i.e. `if !(last_index < last_dimension) goto error`.
1309+
Value *last_index = ii;
1310+
Value *last_dimension = emit_arraysize_for_unsafe_dim(ainfo, ex, nidxs, nd, ctx);
1311+
builder.CreateCondBr(builder.CreateICmpULT(last_index, last_dimension), endBB, failBB);
1312+
}
12931313

12941314
ctx->f->getBasicBlockList().push_back(failBB);
12951315
builder.SetInsertPoint(failBB);

0 commit comments

Comments
 (0)