[LV] Fix runtime-VF logic when generating RT-checks #130118

Status: Closed (wants to merge 1 commit)
26 changes: 11 additions & 15 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1924,21 +1924,17 @@ class GeneratedRTChecks {
"vector.memcheck");

auto DiffChecks = RtPtrChecking.getDiffChecks();
-    if (DiffChecks) {
-      Value *RuntimeVF = nullptr;
-      MemRuntimeCheckCond = addDiffRuntimeChecks(
-          MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
-          [VF, &RuntimeVF](IRBuilderBase &B, unsigned Bits) {
-            if (!RuntimeVF)
-              RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
-            return RuntimeVF;
-          },
-          IC);
-    } else {
-      MemRuntimeCheckCond = addRuntimeChecks(
-          MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),
-          MemCheckExp, VectorizerParams::HoistRuntimeChecks);
-    }
+    MemRuntimeCheckCond =
+        DiffChecks
+            ? addDiffRuntimeChecks(
+                  MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
+                  [VF](IRBuilderBase &B, unsigned Bits) {

Contributor:
I am not sure how this fixes the issue described in the commit message. The lambda should always be called with the same VF and arguments, unless I am missing something?

Contributor Author:
The issue is two-fold: first, the runtime check is generated as an llvm.vscale expression, that is, llvm.vscale multiplied by a constant. Next, addDiffRuntimeChecks creates a mul of this result and an SCEV expansion of SinkStart - SrcStart, then looks up this pair in a map to eliminate redundant compares. Now, if my understanding is correct, it is not safe to cache an llvm.vscale call.
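
For context, a rough sketch of the shape of one such emitted check. This is an illustration pieced together from the snippets quoted below, not the exact addDiffRuntimeChecks code; ChkBuilder, Expander, SE and Loc are assumed names:

  // The runtime VF is an llvm.vscale-based value: vscale times a constant.
  Value *RuntimeVF = GetVF(ChkBuilder, /*Bits=*/64);
  // VF * IC * AccessSize, folded into a single multiply.
  Value *VFTimesUFTimesSize = ChkBuilder.CreateMul(
      RuntimeVF, ConstantInt::get(RuntimeVF->getType(), IC * AccessSize));
  // SCEV expansion of SinkStart - SrcStart.
  Value *Diff = Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart),
                                       RuntimeVF->getType(), Loc);
  // This {Diff, VFTimesUFTimesSize} pair is the key used to skip compares
  // that have already been created.
  Value *IsConflict = ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize);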

Contributor:
@fhahn I thought that at first too, but when I looked into this it does look like a bug to me, though perhaps I'm missing something. The issue is this code in addDiffRuntimeChecks:

  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
    Type *Ty = SinkStart->getType();
    // Compute VF * IC * AccessSize.
    auto *VFTimesUFTimesSize =
        ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
                             ConstantInt::get(Ty, IC * AccessSize));

We are potentially passing in different values for the Bits argument, i.e. for each value of Bits we should be creating a different runtime VF; see the call getRuntimeVF(B, B.getIntNTy(Bits), VF) in the lambda. So if, on the first call, Bits=32, we create a 32-bit runtime VF and cache it. Then on a later call we might pass in Bits=64 and still return the 32-bit runtime VF. I don't even know how the ChkBuilder.CreateMul worked given the mismatched types.
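
To make the failure mode concrete, a minimal sketch of the stale cache; the 32-then-64 call sequence and the names ChkBuilder and I64Ty are hypothetical:

  Value *RuntimeVF = nullptr;
  auto GetVF = [&RuntimeVF, VF](IRBuilderBase &B, unsigned Bits) {
    // The first call fixes the type of the cached value for good.
    if (!RuntimeVF)
      RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
    return RuntimeVF;
  };
  Value *V32 = GetVF(ChkBuilder, 32); // materializes an i32 runtime VF
  Value *V64 = GetVF(ChkBuilder, 64); // returns the same i32 value, not an i64
  // ChkBuilder.CreateMul(V64, ConstantInt::get(I64Ty, IC * AccessSize)) would
  // now mix i32 and i64 operands, which mul does not allow.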

Contributor:
Then it gets worse with this code:

    // Check if the same compare has already been created earlier. In that case,
    // there is no need to check it again.
    Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize});

because even though SrcStart, SinkStart, AccessSize and NeedsFreeze may all be different, we still perform the lookup based on the first runtime VF, potentially getting the value of IsConflict wrong.
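
Spelled out, a sketch of the caching pattern the lookup relies on (not the complete addDiffRuntimeChecks code; the NeedsFreeze handling is omitted):

  Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize});
  if (!IsConflict) {
    // Only build a fresh compare when this exact pair of values is new.
    IsConflict = ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
    SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict});
  }

Pointer equality of the operands is what makes two checks "the same", so a stale runtime VF feeding VFTimesUFTimesSize undermines the whole dedup.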

Contributor (@david-arm, Mar 7, 2025):
So any time you have different types for SinkStart in your list you potentially end up with incorrect runtime diff checks.

Contributor (@david-arm, Mar 7, 2025):
Perhaps, now that I think about it, you explicitly do want to cache the runtime VF because the code in addDiffRuntimeChecks expects that; however, there are no comments to that effect. And also, in that case we shouldn't be caching a single value across all values of Bits.

Contributor:
How about changing Value *RuntimeVF to a map such as DenseMap<unsigned, Value*> RuntimeVFs?
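
A sketch of that suggestion, reusing the names from the diff above; whether addDiffRuntimeChecks actually relies on getting the identical Value back for repeated queries is the open question from the previous comments:

  DenseMap<unsigned, Value *> RuntimeVFs;
  MemRuntimeCheckCond = addDiffRuntimeChecks(
      MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
      [VF, &RuntimeVFs](IRBuilderBase &B, unsigned Bits) {
        // Cache one runtime-VF value per requested bit width, so repeated
        // queries of the same width get the same Value back, but a query of
        // a new width never sees a value of the wrong type.
        Value *&RuntimeVF = RuntimeVFs[Bits];
        if (!RuntimeVF)
          RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
        return RuntimeVF;
      },
      IC);

This would preserve the per-width pointer identity that the SeenCompares dedup depends on, while never handing a value of the wrong width to CreateMul.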

+                    return getRuntimeVF(B, B.getIntNTy(Bits), VF);
+                  },
+                  IC)
+            : addRuntimeChecks(MemCheckBlock->getTerminator(), L,
+                               RtPtrChecking.getChecks(), MemCheckExp,
+                               VectorizerParams::HoistRuntimeChecks);
assert(MemRuntimeCheckCond &&
"no RT checks generated although RtPtrChecking "
"claimed checks are required");
@@ -5,7 +5,6 @@ target triple = "aarch64-unknown-linux-gnu"

; Test case where the minimum profitable trip count due to runtime checks
; exceeds VF.getKnownMinValue() * UF.
-; FIXME: The code currently incorrectly is missing a umax(VF * UF, 28).
Contributor:
I think this FIXME is wrong: we already generate the umax considering UF * VF. This is for the minimum iteration check, and UF * VF is needed for correctness.

The constant is the minimum iteration count needed to be profitable, not needed for correctness.

AFAICT the only reason the constant changes is that the runtime checks are now more expensive, due to the duplicated runtime-VF computations?

The actual memory checks seem to use the same runtime VF as originally?

Contributor Author (@artagnon, Mar 6, 2025):
> AFAICT the only reason the constant gets changed is due to the runtime checks being more expensive now, due to duplicated RT VF computations?

If my understanding is correct, this constant is generated from the SCEVExpander call in addDiffRuntimeChecks, and has nothing to do with cost. The constant is MinProfTC in the creation of the binary intrinsic in emitIterationCountCheck?

> This is for the minimum iteration check, and UF * VF is needed for correctness.

Yes, it's the profitability check that seems to be broken.
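
For reference, a rough sketch of the guard under discussion, assuming the shape described above. Builder, CountTy and TripCount are placeholder names, 28 is the constant from the test below, and this is not the exact emitIterationCountCheck code:

  // Correctness requires at least VF * UF iterations; profitability raises
  // the bar to the constant MinProfTC.
  Value *VFxUF = getRuntimeVF(Builder, CountTy, VF * UF);
  Value *MinProfTC = ConstantInt::get(CountTy, 28);
  Value *Bound = Builder.CreateBinaryIntrinsic(Intrinsic::umax, MinProfTC, VFxUF);
  // Branch to the scalar loop when the trip count is below the bound.
  Value *MinItersCheck =
      Builder.CreateICmpULT(TripCount, Bound, "min.iters.check");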

define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr %src.1, ptr %src.2, i64 %n) {
; CHECK-LABEL: @min_trip_count_due_to_runtime_checks_1(
; CHECK-NEXT: entry:
@@ -16,7 +15,7 @@ define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr
; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[N:%.*]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[TMP0]], 4
-; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.umax.i64(i64 20, i64 [[TMP1]])
+; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.umax.i64(i64 28, i64 [[TMP1]])
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[UMAX]], [[TMP2]]
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_MEMCHECK:%.*]]
; CHECK: vector.memcheck:
@@ -25,21 +24,29 @@ define void @min_trip_count_due_to_runtime_checks_1(ptr %dst.1, ptr %dst.2, ptr
; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 16
; CHECK-NEXT: [[TMP6:%.*]] = sub i64 [[DST_21]], [[DST_12]]
; CHECK-NEXT: [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP6]], [[TMP5]]
-; CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP18:%.*]] = mul i64 [[TMP7]], 2
+; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP18]], 16
; CHECK-NEXT: [[TMP8:%.*]] = sub i64 [[DST_12]], [[SRC_13]]
-; CHECK-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP8]], [[TMP7]]
+; CHECK-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP8]], [[TMP9]]
; CHECK-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[DIFF_CHECK]], [[DIFF_CHECK4]]
-; CHECK-NEXT: [[TMP9:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: [[TMP11:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP22:%.*]] = mul i64 [[TMP11]], 2
+; CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP22]], 16
; CHECK-NEXT: [[TMP10:%.*]] = sub i64 [[DST_12]], [[SRC_25]]
-; CHECK-NEXT: [[DIFF_CHECK6:%.*]] = icmp ult i64 [[TMP10]], [[TMP9]]
+; CHECK-NEXT: [[DIFF_CHECK6:%.*]] = icmp ult i64 [[TMP10]], [[TMP13]]
; CHECK-NEXT: [[CONFLICT_RDX7:%.*]] = or i1 [[CONFLICT_RDX]], [[DIFF_CHECK6]]
-; CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: [[TMP24:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP26:%.*]] = mul i64 [[TMP24]], 2
+; CHECK-NEXT: [[TMP38:%.*]] = mul i64 [[TMP26]], 16
; CHECK-NEXT: [[TMP12:%.*]] = sub i64 [[DST_21]], [[SRC_13]]
-; CHECK-NEXT: [[DIFF_CHECK8:%.*]] = icmp ult i64 [[TMP12]], [[TMP11]]
+; CHECK-NEXT: [[DIFF_CHECK8:%.*]] = icmp ult i64 [[TMP12]], [[TMP38]]
; CHECK-NEXT: [[CONFLICT_RDX9:%.*]] = or i1 [[CONFLICT_RDX7]], [[DIFF_CHECK8]]
-; CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP4]], 16
+; CHECK-NEXT: [[TMP19:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NEXT: [[TMP20:%.*]] = mul i64 [[TMP19]], 2
+; CHECK-NEXT: [[TMP21:%.*]] = mul i64 [[TMP20]], 16
; CHECK-NEXT: [[TMP14:%.*]] = sub i64 [[DST_21]], [[SRC_25]]
-; CHECK-NEXT: [[DIFF_CHECK10:%.*]] = icmp ult i64 [[TMP14]], [[TMP13]]
+; CHECK-NEXT: [[DIFF_CHECK10:%.*]] = icmp ult i64 [[TMP14]], [[TMP21]]
; CHECK-NEXT: [[CONFLICT_RDX11:%.*]] = or i1 [[CONFLICT_RDX9]], [[DIFF_CHECK10]]
; CHECK-NEXT: br i1 [[CONFLICT_RDX11]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
@@ -19,7 +19,7 @@ define void @vp_smax(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP0:%.*]] = sub i64 -1, [[N]]
; IF-EVL-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; IF-EVL-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
-; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 13, i64 [[TMP2]])
+; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 16, i64 [[TMP2]])
; IF-EVL-NEXT: [[TMP22:%.*]] = icmp ult i64 [[TMP0]], [[TMP3]]
; IF-EVL-NEXT: br i1 [[TMP22]], label %[[SCALAR_PH:.*]], label %[[VECTOR_MEMCHECK:.*]]
; IF-EVL: [[VECTOR_MEMCHECK]]:
@@ -28,9 +28,11 @@ define void @vp_smax(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP23:%.*]] = mul i64 [[TMP5]], 4
; IF-EVL-NEXT: [[TMP24:%.*]] = sub i64 [[A1]], [[B2]]
; IF-EVL-NEXT: [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP24]], [[TMP23]]
-; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP5]], 4
+; IF-EVL-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP15]], 4
+; IF-EVL-NEXT: [[TMP30:%.*]] = mul i64 [[TMP25]], 4
; IF-EVL-NEXT: [[TMP26:%.*]] = sub i64 [[A1]], [[C3]]
-; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP25]]
+; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP30]]
; IF-EVL-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[DIFF_CHECK]], [[DIFF_CHECK4]]
; IF-EVL-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]]
; IF-EVL: [[VECTOR_PH]]:
@@ -134,7 +136,7 @@ define void @vp_smin(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP0:%.*]] = sub i64 -1, [[N]]
; IF-EVL-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; IF-EVL-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
-; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 13, i64 [[TMP2]])
+; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 16, i64 [[TMP2]])
; IF-EVL-NEXT: [[TMP22:%.*]] = icmp ult i64 [[TMP0]], [[TMP3]]
; IF-EVL-NEXT: br i1 [[TMP22]], label %[[SCALAR_PH:.*]], label %[[VECTOR_MEMCHECK:.*]]
; IF-EVL: [[VECTOR_MEMCHECK]]:
@@ -143,9 +145,11 @@ define void @vp_smin(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP23:%.*]] = mul i64 [[TMP5]], 4
; IF-EVL-NEXT: [[TMP24:%.*]] = sub i64 [[A1]], [[B2]]
; IF-EVL-NEXT: [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP24]], [[TMP23]]
-; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP5]], 4
+; IF-EVL-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP15]], 4
+; IF-EVL-NEXT: [[TMP30:%.*]] = mul i64 [[TMP25]], 4
; IF-EVL-NEXT: [[TMP26:%.*]] = sub i64 [[A1]], [[C3]]
-; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP25]]
+; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP30]]
; IF-EVL-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[DIFF_CHECK]], [[DIFF_CHECK4]]
; IF-EVL-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]]
; IF-EVL: [[VECTOR_PH]]:
@@ -249,7 +253,7 @@ define void @vp_umax(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP0:%.*]] = sub i64 -1, [[N]]
; IF-EVL-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; IF-EVL-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
-; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 13, i64 [[TMP2]])
+; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 16, i64 [[TMP2]])
; IF-EVL-NEXT: [[TMP22:%.*]] = icmp ult i64 [[TMP0]], [[TMP3]]
; IF-EVL-NEXT: br i1 [[TMP22]], label %[[SCALAR_PH:.*]], label %[[VECTOR_MEMCHECK:.*]]
; IF-EVL: [[VECTOR_MEMCHECK]]:
@@ -258,9 +262,11 @@ define void @vp_umax(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP23:%.*]] = mul i64 [[TMP5]], 4
; IF-EVL-NEXT: [[TMP24:%.*]] = sub i64 [[A1]], [[B2]]
; IF-EVL-NEXT: [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP24]], [[TMP23]]
-; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP5]], 4
+; IF-EVL-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP15]], 4
+; IF-EVL-NEXT: [[TMP30:%.*]] = mul i64 [[TMP25]], 4
; IF-EVL-NEXT: [[TMP26:%.*]] = sub i64 [[A1]], [[C3]]
-; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP25]]
+; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP30]]
; IF-EVL-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[DIFF_CHECK]], [[DIFF_CHECK4]]
; IF-EVL-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]]
; IF-EVL: [[VECTOR_PH]]:
@@ -364,7 +370,7 @@ define void @vp_umin(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP0:%.*]] = sub i64 -1, [[N]]
; IF-EVL-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; IF-EVL-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 4
-; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 13, i64 [[TMP2]])
+; IF-EVL-NEXT: [[TMP3:%.*]] = call i64 @llvm.umax.i64(i64 16, i64 [[TMP2]])
; IF-EVL-NEXT: [[TMP22:%.*]] = icmp ult i64 [[TMP0]], [[TMP3]]
; IF-EVL-NEXT: br i1 [[TMP22]], label %[[SCALAR_PH:.*]], label %[[VECTOR_MEMCHECK:.*]]
; IF-EVL: [[VECTOR_MEMCHECK]]:
@@ -373,9 +379,11 @@ define void @vp_umin(ptr %a, ptr %b, ptr %c, i64 %N) {
; IF-EVL-NEXT: [[TMP23:%.*]] = mul i64 [[TMP5]], 4
; IF-EVL-NEXT: [[TMP24:%.*]] = sub i64 [[A1]], [[B2]]
; IF-EVL-NEXT: [[DIFF_CHECK:%.*]] = icmp ult i64 [[TMP24]], [[TMP23]]
-; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP5]], 4
+; IF-EVL-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; IF-EVL-NEXT: [[TMP25:%.*]] = mul i64 [[TMP15]], 4
+; IF-EVL-NEXT: [[TMP30:%.*]] = mul i64 [[TMP25]], 4
; IF-EVL-NEXT: [[TMP26:%.*]] = sub i64 [[A1]], [[C3]]
-; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP25]]
+; IF-EVL-NEXT: [[DIFF_CHECK4:%.*]] = icmp ult i64 [[TMP26]], [[TMP30]]
; IF-EVL-NEXT: [[CONFLICT_RDX:%.*]] = or i1 [[DIFF_CHECK]], [[DIFF_CHECK4]]
; IF-EVL-NEXT: br i1 [[CONFLICT_RDX]], label %[[SCALAR_PH]], label %[[VECTOR_PH:.*]]
; IF-EVL: [[VECTOR_PH]]: