Skip to content

Commit d474ca3

Browse files
authored
Merge pull request #18064 from MatthiasJReisinger/mjr/polly/boundchecks
Simplify bounds checks for multi-dimensional array accesses
2 parents 2826047 + 9f531d9 commit d474ca3

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

src/cgutils.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,14 @@ static void assign_arrayvar(jl_arrayvar_t &av, const jl_cgval_t &ainfo, jl_codec
12621262
builder.CreateStore(emit_arraysize(ainfo, i+1, ctx), av.sizes[i]);
12631263
}
12641264

1265+
// Returns the size of the array represented by `tinfo` for the given dimension `dim` if
1266+
// `dim` is a valid dimension, otherwise returns constant one.
1267+
static Value *emit_arraysize_for_unsafe_dim(const jl_cgval_t &tinfo, jl_value_t *ex, size_t dim,
1268+
size_t nd, jl_codectx_t *ctx)
1269+
{
1270+
return dim > nd ? ConstantInt::get(T_size, 1) : emit_arraysize(tinfo, ex, dim, ctx);
1271+
}
1272+
12651273
static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_t nd, jl_value_t **args,
12661274
size_t nidxs, jl_codectx_t *ctx)
12671275
{
@@ -1282,12 +1290,12 @@ static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_
12821290
for(size_t k=0; k < nidxs; k++) {
12831291
idxs[k] = emit_unbox(T_size, emit_expr(args[k], ctx), NULL);
12841292
}
1293+
Value *ii;
12851294
for(size_t k=0; k < nidxs; k++) {
1286-
Value *ii = builder.CreateSub(idxs[k], ConstantInt::get(T_size, 1));
1295+
ii = builder.CreateSub(idxs[k], ConstantInt::get(T_size, 1));
12871296
i = builder.CreateAdd(i, builder.CreateMul(ii, stride));
12881297
if (k < nidxs-1) {
1289-
Value *d =
1290-
k >= nd ? ConstantInt::get(T_size, 1) : emit_arraysize(ainfo, ex, k+1, ctx);
1298+
Value *d = emit_arraysize_for_unsafe_dim(ainfo, ex, k+1, nd, ctx);
12911299
#if CHECK_BOUNDS==1
12921300
if (bc) {
12931301
BasicBlock *okBB = BasicBlock::Create(jl_LLVMContext, "ib");
@@ -1302,9 +1310,21 @@ static Value *emit_array_nd_index(const jl_cgval_t &ainfo, jl_value_t *ex, size_
13021310
}
13031311
#if CHECK_BOUNDS==1
13041312
if (bc) {
1305-
Value *alen = emit_arraylen(ainfo, ex, ctx);
1306-
// if !(i < alen) goto error
1307-
builder.CreateCondBr(builder.CreateICmpULT(i, alen), endBB, failBB);
1313+
// We have already emitted a bounds check for each index except for
1314+
// the last one which we therefore have to do here.
1315+
bool linear_indexing = nidxs < nd;
1316+
if (linear_indexing) {
1317+
// Compare the linearized index `i` against the linearized size of
1318+
// the accessed array, i.e. `if !(i < alen) goto error`.
1319+
Value *alen = emit_arraylen(ainfo, ex, ctx);
1320+
builder.CreateCondBr(builder.CreateICmpULT(i, alen), endBB, failBB);
1321+
} else {
1322+
// Compare the last index of the access against the last dimension of
1323+
// the accessed array, i.e. `if !(last_index < last_dimension) goto error`.
1324+
Value *last_index = ii;
1325+
Value *last_dimension = emit_arraysize_for_unsafe_dim(ainfo, ex, nidxs, nd, ctx);
1326+
builder.CreateCondBr(builder.CreateICmpULT(last_index, last_dimension), endBB, failBB);
1327+
}
13081328

13091329
ctx->f->getBasicBlockList().push_back(failBB);
13101330
builder.SetInsertPoint(failBB);

0 commit comments

Comments
 (0)