Skip to content

Commit 05fa881

Browse files
authored
JIT: Avoid mask<->vector optimization for masks used in unhandled ways (#110307)
When a local is used as a return buffer it is not address exposed, so the address-exposure check was not sufficient. Add checks for `LCL_ADDR`, `LCL_FLD` and `STORE_LCL_FLD` to make sure any use of a mask local that is not converted disqualifies it from participating in the optimization. Also avoid doing some work for locals that are not SIMD/mask typed (common case). Previously we would do some unnecessary hash table lookups and other things in these cases.
1 parent 12afded commit 05fa881

File tree

3 files changed

+130
-35
lines changed

3 files changed

+130
-35
lines changed

src/coreclr/jit/optimizemaskconversions.cpp

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,9 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
118118
public:
119119
enum
120120
{
121-
DoPostOrder = true,
122-
UseExecutionOrder = true
121+
DoPreOrder = true,
122+
UseExecutionOrder = true,
123+
DoLclVarsOnly = true,
123124
};
124125

125126
MaskConversionsCheckVisitor(Compiler* compiler, weight_t bbWeight, MaskConversionsWeightTable* weightsTable)
@@ -129,16 +130,29 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
129130
{
130131
}
131132

132-
Compiler::fgWalkResult PostOrderVisit(GenTree** use, GenTree* user)
133+
Compiler::fgWalkResult PreOrderVisit(GenTree** use, GenTree* user)
133134
{
135+
GenTreeLclVarCommon* lclOp = (*use)->AsLclVarCommon();
136+
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp);
137+
138+
if (!varTypeIsSIMDOrMask(varDsc))
139+
{
140+
return fgWalkResult::WALK_CONTINUE;
141+
}
142+
143+
// Get the existing weighting (if any).
144+
MaskConversionsWeight* weight = weightsTable->LookupPointerOrAdd(lclOp->GetLclNum(), MaskConversionsWeight());
145+
146+
JITDUMP("%s V%02d at [%06u] ", GenTree::OpName(lclOp->gtOper), lclOp->GetLclNum(),
147+
m_compiler->dspTreeID(lclOp));
148+
134149
GenTreeHWIntrinsic* convertOp = nullptr;
135150

136151
bool isLocalStore = false;
137152
bool isLocalUse = false;
138-
bool isInvalid = false;
139153
bool hasConversion = false;
140154

141-
switch ((*use)->OperGet())
155+
switch (lclOp->OperGet())
142156
{
143157
case GT_STORE_LCL_VAR:
144158
{
@@ -147,9 +161,9 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
147161
// Look for:
148162
// use:STORE_LCL_VAR(ConvertMaskToVector(x))
149163

150-
if ((*use)->AsLclVar()->Data()->OperIsConvertMaskToVector())
164+
if (lclOp->Data()->OperIsConvertMaskToVector())
151165
{
152-
convertOp = (*use)->AsLclVar()->Data()->AsHWIntrinsic();
166+
convertOp = lclOp->Data()->AsHWIntrinsic();
153167
hasConversion = true;
154168
}
155169
break;
@@ -164,7 +178,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
164178
// -or-
165179
// user: ConditionalSelect(use:LCL_VAR(x), y, z)
166180

167-
if (user->OperIsHWIntrinsic())
181+
if ((user != nullptr) && user->OperIsHWIntrinsic())
168182
{
169183
GenTreeHWIntrinsic* hwintrin = user->AsHWIntrinsic();
170184
NamedIntrinsic ni = hwintrin->GetHWIntrinsicId();
@@ -186,7 +200,7 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
186200
// emit `vblendmps zmm1 {k1}, zmm2, zmm3` instead of containing the CndSel
187201
// as part of something like `vaddps zmm1 {k1}, zmm2, zmm3`
188202

189-
if (hwintrin->Op(1) == (*use))
203+
if (hwintrin->Op(1) == lclOp)
190204
{
191205
convertOp = user->AsHWIntrinsic();
192206
hasConversion = true;
@@ -197,25 +211,19 @@ class MaskConversionsCheckVisitor final : public GenTreeVisitor<MaskConversionsC
197211
}
198212

199213
default:
200-
break;
214+
// LCL_ADDR (can show up unexposed due to retbufs), or partial
215+
// use/store. We do not handle these.
216+
weight->InvalidateWeight();
217+
JITDUMP("is unhandled. ");
218+
return fgWalkResult::WALK_CONTINUE;
201219
}
202220

203221
if (isLocalStore || isLocalUse)
204222
{
205-
GenTreeLclVarCommon* lclOp = (*use)->AsLclVarCommon();
206-
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());
207-
208-
// Get the existing weighting (if any).
209-
MaskConversionsWeight defaultWeight;
210-
MaskConversionsWeight* weight = weightsTable->LookupPointerOrAdd(lclOp->GetLclNum(), defaultWeight);
211-
212-
JITDUMP("Local %s V%02d at [%06u] ", isLocalStore ? "store" : "use", lclOp->GetLclNum(),
213-
m_compiler->dspTreeID(lclOp));
214-
215223
// Cannot convert any locals with an exposed address.
216224
if (varDsc->IsAddressExposed())
217225
{
218-
JITDUMP("is address exposed elsewhere. ");
226+
JITDUMP("is address exposed. ");
219227
weight->InvalidateWeight();
220228
return fgWalkResult::WALK_CONTINUE;
221229
}
@@ -345,29 +353,34 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
345353
assert(lclOp != nullptr);
346354

347355
// Get the existing weighting.
348-
MaskConversionsWeight weight;
349-
bool found = weightsTable->Lookup(lclOp->GetLclNum(), &weight);
350-
assert(found);
356+
MaskConversionsWeight* weight = weightsTable->LookupPointer(lclOp->GetLclNum());
357+
358+
if (weight == nullptr)
359+
{
360+
return fgWalkResult::WALK_CONTINUE;
361+
}
351362

352363
// Quit if the cost of changing is higher or is invalid.
353-
if (weight.currentCost <= weight.switchCost || weight.invalid)
364+
if (weight->currentCost <= weight->switchCost || weight->invalid)
354365
{
355366
JITDUMP("Local %s V%02d at [%06u] will not be converted. ", isLocalStore ? "store" : "use",
356367
lclOp->GetLclNum(), Compiler::dspTreeID(lclOp));
357-
weight.DumpTotalWeight();
368+
weight->DumpTotalWeight();
358369
return fgWalkResult::WALK_CONTINUE;
359370
}
360371

361372
JITDUMP("Local %s V%02d at [%06u] will be converted. ", isLocalStore ? "store" : "use", lclOp->GetLclNum(),
362373
Compiler::dspTreeID(lclOp));
363-
weight.DumpTotalWeight();
374+
weight->DumpTotalWeight();
364375

365376
// Fix up the type of the lcl and the lclvar.
366377
assert(lclOp->gtType != TYP_MASK);
367378
var_types lclOrigType = lclOp->gtType;
368379
lclOp->gtType = TYP_MASK;
369-
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());
370-
varDsc->lvType = TYP_MASK;
380+
381+
LclVarDsc* varDsc = m_compiler->lvaGetDesc(lclOp->GetLclNum());
382+
assert(varTypeIsSIMDOrMask(varDsc));
383+
varDsc->lvType = TYP_MASK;
371384

372385
// Add or remove a conversion
373386

@@ -390,9 +403,9 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
390403

391404
// There is not enough information in the lcl to get simd types. Instead reuse the cached
392405
// simd types from the removed convert nodes.
393-
assert(weight.simdBaseJitType != CORINFO_TYPE_UNDEF);
394-
lclOp->Data() = m_compiler->gtNewSimdCvtVectorToMaskNode(TYP_MASK, lclOp->Data(), weight.simdBaseJitType,
395-
weight.simdSize);
406+
assert(weight->simdBaseJitType != CORINFO_TYPE_UNDEF);
407+
lclOp->Data() = m_compiler->gtNewSimdCvtVectorToMaskNode(TYP_MASK, lclOp->Data(), weight->simdBaseJitType,
408+
weight->simdSize);
396409
}
397410

398411
else if (isLocalUse && removeConversion)
@@ -414,9 +427,9 @@ class MaskConversionsUpdateVisitor final : public GenTreeVisitor<MaskConversions
414427

415428
// There is not enough information in the lcl to get simd types. Instead reuse the cached simd
416429
// types from the removed convert nodes.
417-
assert(weight.simdBaseJitType != CORINFO_TYPE_UNDEF);
430+
assert(weight->simdBaseJitType != CORINFO_TYPE_UNDEF);
418431
*use =
419-
m_compiler->gtNewSimdCvtMaskToVectorNode(lclOrigType, lclOp, weight.simdBaseJitType, weight.simdSize);
432+
m_compiler->gtNewSimdCvtMaskToVectorNode(lclOrigType, lclOp, weight->simdBaseJitType, weight->simdSize);
420433
}
421434

422435
JITDUMP("Updated %s V%02d at [%06u] to mask (%s conversion)\n", isLocalStore ? "store" : "use",
@@ -521,7 +534,7 @@ PhaseStatus Compiler::fgOptimizeMaskConversions()
521534
// Only check statements where there is a local of type TYP_SIMD/TYP_MASK.
522535
for (GenTreeLclVarCommon* lcl : stmt->LocalsTreeList())
523536
{
524-
if (varTypeIsSIMDOrMask(lcl))
537+
if (varTypeIsSIMDOrMask(lvaGetDesc(lcl)))
525538
{
526539
// Parse the entire statement.
527540
MaskConversionsCheckVisitor ev(this, block->getBBWeight(this), &weightsTable);
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
// Generated by Fuzzlyn v2.4 on 2024-12-01 16:32:26
5+
// Run on X64 Linux
6+
// Seed: 7861295224295601455-vectort,vector128,vector256,x86aes,x86avx,x86avx2,x86avx512bw,x86avx512bwvl,x86avx512cd,x86avx512cdvl,x86avx512dq,x86avx512dqvl,x86avx512f,x86avx512fvl,x86avx512fx64,x86bmi1,x86bmi1x64,x86bmi2,x86bmi2x64,x86fma,x86lzcnt,x86lzcntx64,x86pclmulqdq,x86popcnt,x86popcntx64,x86sse,x86ssex64,x86sse2,x86sse2x64,x86sse3,x86sse41,x86sse41x64,x86sse42,x86sse42x64,x86ssse3,x86x86base
7+
// Reduced from 115.8 KiB to 0.9 KiB in 00:02:27
8+
// Hits JIT assert in Release:
9+
// Assertion failed 'newLclValue.BothDefined()' in 'Program:Main(Fuzzlyn.ExecutionServer.IRuntime)' during 'Do value numbering' (IL size 61; hash 0xade6b36b; FullOpts)
10+
//
11+
// File: /__w/1/s/src/coreclr/jit/valuenum.cpp Line: 6138
12+
//
13+
using System;
14+
using System.Numerics;
15+
using System.Runtime.CompilerServices;
16+
using System.Runtime.Intrinsics;
17+
using System.Runtime.Intrinsics.X86;
18+
using Xunit;
19+
20+
public class C0
21+
{
22+
public uint F0;
23+
}
24+
25+
public struct S0
26+
{
27+
public C0 F2;
28+
}
29+
30+
public class C3
31+
{
32+
public byte F0;
33+
}
34+
35+
public class Runtime_110306
36+
{
37+
public static S0 s_3;
38+
39+
[Fact]
40+
public static void TestEntryPoint()
41+
{
42+
if (!Avx512F.VL.IsSupported)
43+
{
44+
return;
45+
}
46+
47+
try
48+
{
49+
TestMain();
50+
}
51+
catch
52+
{
53+
}
54+
}
55+
56+
private static void TestMain()
57+
{
58+
var vr5 = Vector256.Create(1, 0, 0, 0);
59+
Vector256<long> vr15 = Vector256.Create<long>(0);
60+
Vector256<long> vr8 = Avx512F.VL.CompareNotEqual(vr5, vr15);
61+
long vr9 = 0;
62+
var vr10 = new C3();
63+
vr8 = M3();
64+
long vr11 = vr9;
65+
var vr12 = s_3.F2.F0;
66+
vr8 = vr8;
67+
}
68+
69+
[MethodImpl(MethodImplOptions.NoInlining)]
70+
public static Vector256<long> M3()
71+
{
72+
return default;
73+
}
74+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
<PropertyGroup>
3+
<Optimize>True</Optimize>
4+
</PropertyGroup>
5+
<ItemGroup>
6+
<Compile Include="$(MSBuildProjectName).cs" />
7+
</ItemGroup>
8+
</Project>

0 commit comments

Comments
 (0)