Skip to content

Commit

Permalink
Add fix for scalar select
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros committed Sep 19, 2024
1 parent 501d4ed commit 3beae6d
Showing 1 changed file with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -753,12 +753,21 @@ LogicalResult PointerCanonicalizer::rewriteSelectOp(arith::SelectOp selectOp,
Value cond = selectOp.getCondition();
// If we didn't traverse both operands, simply materialize the pointer
if (!pointers.contains(trueVal) || !pointers.contains(falseVal))
return (materializeFatPointer(selectOp, curLoc, curOperand->get()));
return materializeFatPointer(selectOp, curLoc, curOperand->get());

// If both have been traversed, then we can rewrite select of pointers as a
// select of base and offset
FatPtr fatPtrT = pointers[trueVal];
FatPtr fatPtrF = pointers[falseVal];
nextPtr = selectOp.getResult();

// Simple case of a scalar select: update the base pointer
if (!isa<RankedTensorType>(selectOp.getType())) {
FatPtr fatPtr = pointers[trueVal];
pointers[nextPtr] = fatPtr.copyWithOffset(nextPtr);
nextPtr = selectOp.getResult();
return success();
}

// Rewrite `select` for base and offset
Value newBase = rewriter.create<arith::SelectOp>(
Expand All @@ -767,7 +776,6 @@ LogicalResult PointerCanonicalizer::rewriteSelectOp(arith::SelectOp selectOp,
curLoc, cond, fatPtrT.offset, fatPtrF.offset);
assert(fatPtrT.canNarrow == fatPtrF.canNarrow);

nextPtr = selectOp.getResult();
pointers[nextPtr] = fatPtrT.copy(newBase, newOffset);
opToDelete.insert(selectOp);
return success();
Expand Down

0 comments on commit 3beae6d

Please sign in to comment.