[VectorCombine] foldShuffleOfCastops - extend shuffle(bitcast(x),bitcast(y)) -> bitcast(shuffle(x,y)) support

Handle shuffle mask scaling for cases where the bitcast src/dst element counts are different.
RKSimon committed Apr 11, 2024
1 parent 402f15e commit 6fd2fdc
Showing 5 changed files with 43 additions and 24 deletions.
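
For context on the mask scaling described above, here is a small standalone C++ sketch of the narrowing direction, where the bitcast source type has more (narrower) elements than its destination type and each old mask index expands into ScaleFactor consecutive indices. This is illustrative only: the helper name narrowMaskElts and the driver are invented for this sketch, while the patch itself calls LLVM's narrowShuffleMaskElts.

// Standalone sketch (not LLVM code): narrowing a shuffle mask by ScaleFactor.
// Each old mask index selecting one wide lane becomes ScaleFactor consecutive
// narrow-lane indices; poison lanes (-1) stay -1.
#include <cstdio>
#include <vector>

static std::vector<int> narrowMaskElts(unsigned ScaleFactor,
                                       const std::vector<int> &OldMask) {
  std::vector<int> NewMask;
  for (int M : OldMask)
    for (unsigned I = 0; I != ScaleFactor; ++I)
      NewMask.push_back(M < 0 ? -1 : int(M * ScaleFactor + I));
  return NewMask;
}

int main() {
  // Mirrors the concat_bitcast_v8i16_v4f64 test below: a <4 x i32> concat mask
  // over <2 x double> operands becomes a 16-wide mask over <8 x i16> operands.
  for (int M : narrowMaskElts(4, {0, 1, 2, 3}))
    std::printf("%d ", M); // prints 0 1 2 ... 15
  std::printf("\n");
}

This narrowing direction always succeeds, which is why the new code applies it unconditionally when NumSrcElts >= NumDstElts.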
3 changes: 1 addition & 2 deletions clang/test/CodeGen/X86/avx-shuffle-builtins.c
@@ -61,8 +61,7 @@ __m256 test_mm256_permute2f128_ps(__m256 a, __m256 b) {

__m256i test_mm256_permute2f128_si256(__m256i a, __m256i b) {
// CHECK-LABEL: test_mm256_permute2f128_si256
// X64: shufflevector{{.*}}<i32 0, i32 1, i32 4, i32 5>
// X86: shufflevector{{.*}}<i32 0, i32 1, i32 2, i32 3, i32 8, i32 9, i32 10, i32 11>
// CHECK: shufflevector{{.*}}<i32 0, i32 1, i32 4, i32 5>
return _mm256_permute2f128_si256(a, b, 0x20);
}

40 changes: 30 additions & 10 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1448,9 +1448,9 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
/// into "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
Value *V0, *V1;
ArrayRef<int> Mask;
ArrayRef<int> OldMask;
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
m_Mask(Mask))))
m_Mask(OldMask))))
return false;

auto *C0 = dyn_cast<CastInst>(V0);
@@ -1473,12 +1473,32 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
if (!ShuffleDstTy || !CastDstTy || !CastSrcTy ||
CastDstTy->getElementCount() != CastSrcTy->getElementCount())
if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
return false;

unsigned NumSrcElts = CastSrcTy->getNumElements();
unsigned NumDstElts = CastDstTy->getNumElements();
assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
"Only bitcasts expected to alter src/dst element counts");

SmallVector<int, 16> NewMask;
if (NumSrcElts >= NumDstElts) {
// The bitcast is from wide to narrow/equal elements. The shuffle mask can
// always be expanded to the equivalent form choosing narrower elements.
assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
unsigned ScaleFactor = NumSrcElts / NumDstElts;
narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
} else {
// The bitcast is from narrow elements to wide elements. The shuffle mask
// must choose consecutive elements to allow casting first.
assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
unsigned ScaleFactor = NumDstElts / NumSrcElts;
if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
return false;
}

auto *NewShuffleDstTy =
FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());
FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());

// Try to replace a castop with a shuffle if the shuffle is not costly.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
@@ -1489,11 +1509,11 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind);
OldCost +=
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy, Mask,
CostKind, 0, nullptr, std::nullopt, &I);
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, CastDstTy,
OldMask, CostKind, 0, nullptr, std::nullopt, &I);

InstructionCost NewCost = TTI.getShuffleCost(
TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, Mask, CostKind);
TargetTransformInfo::SK_PermuteTwoSrc, CastSrcTy, NewMask, CostKind);
NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
TTI::CastContextHint::None, CostKind);

@@ -1503,8 +1523,8 @@ bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
if (NewCost > OldCost)
return false;

Value *Shuf =
Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
C1->getOperand(0), NewMask);
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);

// Intersect flags from the old casts.
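
A complementary standalone sketch of the widening direction handled in the hunk above. This is illustrative only: widenMaskElts is an invented, simplified stand-in for LLVM's widenShuffleMaskElts (it does not, for example, accept all-poison groups). The point it shows is that the fold can only go ahead when every group of ScaleFactor old-mask indices is consecutive and aligned; otherwise it bails out, matching the return-false path in the new code.

// Standalone sketch (not LLVM code): widening a shuffle mask by ScaleFactor.
// Each aligned group of ScaleFactor consecutive indices collapses to one wide
// index; any other pattern has no wide-element equivalent, so widening fails.
// (Simplified: LLVM's widenShuffleMaskElts also accepts all-poison groups.)
#include <cstdio>
#include <vector>

static bool widenMaskElts(unsigned ScaleFactor, const std::vector<int> &OldMask,
                          std::vector<int> &NewMask) {
  if (OldMask.size() % ScaleFactor != 0)
    return false;
  for (size_t I = 0; I != OldMask.size(); I += ScaleFactor) {
    int Lead = OldMask[I];
    if (Lead < 0 || Lead % int(ScaleFactor) != 0)
      return false;
    for (unsigned J = 1; J != ScaleFactor; ++J)
      if (OldMask[I + J] != Lead + int(J))
        return false;
    NewMask.push_back(Lead / int(ScaleFactor));
  }
  return true;
}

int main() {
  // Mirrors concat_bitcast_v4i32_v16i16 below: the 16-wide i16 concat mask
  // <0..15> collapses to the 8-wide i32 mask <0..7>, so the shuffle can be
  // performed before the bitcast.
  std::vector<int> Concat, NewMask;
  for (int I = 0; I != 16; ++I)
    Concat.push_back(I);
  std::printf("%d\n", int(widenMaskElts(2, Concat, NewMask))); // prints 1

  // An interleaving mask such as <0,2,1,3> cannot be widened, so the fold
  // would not fire for it.
  NewMask.clear();
  std::printf("%d\n", int(widenMaskElts(2, {0, 2, 1, 3}, NewMask))); // prints 0
}

The updated pr67803.ll and shuffle.ll checks below show the same effect at the IR level: the concat shuffle is now performed on the original element type and a single bitcast follows.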
5 changes: 2 additions & 3 deletions llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
@@ -17,16 +17,15 @@ define <4 x i64> @PR67803(<4 x i64> %x, <4 x i64> %y, <4 x i64> %a, <4 x i64> %b
; CHECK-NEXT: [[TMP9:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <32 x i8> [[TMP9]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP11:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP6]], <16 x i8> [[TMP8]], <16 x i8> [[TMP10]])
; CHECK-NEXT: [[TMP12:%.*]] = bitcast <16 x i8> [[TMP11]] to <2 x i64>
; CHECK-NEXT: [[TMP13:%.*]] = bitcast <4 x i64> [[A]] to <32 x i8>
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <32 x i8> [[TMP13]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: [[TMP15:%.*]] = bitcast <4 x i64> [[B]] to <32 x i8>
; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <32 x i8> [[TMP15]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: [[TMP17:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
; CHECK-NEXT: [[TMP18:%.*]] = shufflevector <32 x i8> [[TMP17]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: [[TMP19:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[TMP14]], <16 x i8> [[TMP16]], <16 x i8> [[TMP18]])
; CHECK-NEXT: [[TMP20:%.*]] = bitcast <16 x i8> [[TMP19]] to <2 x i64>
; CHECK-NEXT: [[SHUFFLE_I23:%.*]] = shufflevector <2 x i64> [[TMP12]], <2 x i64> [[TMP20]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <16 x i8> [[TMP11]], <16 x i8> [[TMP19]], <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
; CHECK-NEXT: [[SHUFFLE_I23:%.*]] = bitcast <32 x i8> [[TMP20]] to <4 x i64>
; CHECK-NEXT: ret <4 x i64> [[SHUFFLE_I23]]
;
entry:
14 changes: 6 additions & 8 deletions llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
@@ -179,13 +179,12 @@ define <8 x float> @concat_bitcast_v4i32_v8f32(<4 x i32> %a0, <4 x i32> %a1) {
ret <8 x float> %r
}

; TODO - bitcasts (lower element count)
; bitcasts (lower element count)

define <4 x double> @concat_bitcast_v8i16_v4f64(<8 x i16> %a0, <8 x i16> %a1) {
; CHECK-LABEL: @concat_bitcast_v8i16_v4f64(
; CHECK-NEXT: [[X0:%.*]] = bitcast <8 x i16> [[A0:%.*]] to <2 x double>
; CHECK-NEXT: [[X1:%.*]] = bitcast <8 x i16> [[A1:%.*]] to <2 x double>
; CHECK-NEXT: [[R:%.*]] = shufflevector <2 x double> [[X0]], <2 x double> [[X1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[R:%.*]] = bitcast <16 x i16> [[TMP1]] to <4 x double>
; CHECK-NEXT: ret <4 x double> [[R]]
;
%x0 = bitcast <8 x i16> %a0 to <2 x double>
@@ -194,13 +193,12 @@ define <4 x double> @concat_bitcast_v8i16_v4f64(<8 x i16> %a0, <8 x i16> %a1) {
ret <4 x double> %r
}

; TODO - bitcasts (higher element count)
; bitcasts (higher element count)

define <16 x i16> @concat_bitcast_v4i32_v16i16(<4 x i32> %a0, <4 x i32> %a1) {
; CHECK-LABEL: @concat_bitcast_v4i32_v16i16(
; CHECK-NEXT: [[X0:%.*]] = bitcast <4 x i32> [[A0:%.*]] to <8 x i16>
; CHECK-NEXT: [[X1:%.*]] = bitcast <4 x i32> [[A1:%.*]] to <8 x i16>
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i16> [[X0]], <8 x i16> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[R:%.*]] = bitcast <8 x i32> [[TMP1]] to <16 x i16>
; CHECK-NEXT: ret <16 x i16> [[R]]
;
%x0 = bitcast <4 x i32> %a0 to <8 x i16>
5 changes: 4 additions & 1 deletion llvm/test/Transforms/VectorCombine/X86/shuffle.ll
@@ -122,11 +122,14 @@ define <16 x i8> @bitcast_shuf_uses(<4 x i32> %v) {
}

; shuffle of 2 operands removes bitcasts
; TODO - can we remove the empty bitcast(bitcast()) ?

define <4 x i64> @bitcast_shuf_remove_bitcasts(<2 x i64> %a0, <2 x i64> %a1) {
; CHECK-LABEL: @bitcast_shuf_remove_bitcasts(
; CHECK-NEXT: [[R:%.*]] = shufflevector <2 x i64> [[A0:%.*]], <2 x i64> [[A1:%.*]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: ret <4 x i64> [[R]]
; CHECK-NEXT: [[SHUF:%.*]] = bitcast <4 x i64> [[R]] to <8 x i32>
; CHECK-NEXT: [[R1:%.*]] = bitcast <8 x i32> [[SHUF]] to <4 x i64>
; CHECK-NEXT: ret <4 x i64> [[R1]]
;
%bc0 = bitcast <2 x i64> %a0 to <4 x i32>
%bc1 = bitcast <2 x i64> %a1 to <4 x i32>
