Skip to content

Commit

Permalink
[AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (llvm#120207)
Browse files Browse the repository at this point in the history

Add codegen for when the input type has 4 times as many elements as the
output type and the input to the partial reduction does not have a
binary operation performed on it.
  • Loading branch information
JamesChesterman authored Jan 6, 2025
1 parent 1feeeb4 commit 3134045
Show file tree
Hide file tree
Showing 3 changed files with 483 additions and 17 deletions.
44 changes: 27 additions & 17 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21987,21 +21987,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);

SDValue Op2 = N->getOperand(2);
if (Op2->getOpcode() != ISD::MUL ||
!ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
!ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
return SDValue();
unsigned Op2Opcode = Op2->getOpcode();
SDValue MulOpLHS, MulOpRHS;
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
if (ISD::isExtOpcode(Op2Opcode)) {
MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
MulOpLHS = Op2->getOperand(0);
MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
} else if (Op2Opcode == ISD::MUL) {
SDValue ExtMulOpLHS = Op2->getOperand(0);
SDValue ExtMulOpRHS = Op2->getOperand(1);

unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();

SDValue Acc = N->getOperand(1);
SDValue Mul = N->getOperand(2);
SDValue ExtMulOpLHS = Mul->getOperand(0);
SDValue ExtMulOpRHS = Mul->getOperand(1);
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;

SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
MulOpLHS = ExtMulOpLHS->getOperand(0);
MulOpRHS = ExtMulOpRHS->getOperand(0);

if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
return SDValue();
} else
return SDValue();

SDValue Acc = N->getOperand(1);
EVT ReducedVT = N->getValueType(0);
EVT MulSrcVT = MulOpLHS.getValueType();

Expand All @@ -22015,8 +22029,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();

bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
Expand All @@ -22032,10 +22044,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// USDOT expects the signed operand to be last
if (!MulOpRHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);
} else if (MulOpLHSIsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;
} else
Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;

// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
Expand Down
248 changes: 248 additions & 0 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,166 @@ entry:
ret <4 x i64> %partial.reduce
}

; Partial reduction of a zero-extended <16 x i8> into a <4 x i32> accumulator,
; with no binary operation on the extended input. With the dot-product feature
; this should lower to a single UDOT against a splat of 1; without it, to a
; chain of widening adds (ushll/uaddw).
define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #1
; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
}

; Signed variant of udot_no_bin_op: a sign-extended <16 x i8> reduced into a
; <4 x i32> accumulator with no multiply on the input. Expects a single SDOT
; against a splat of 1 when the dot-product feature is present, and a
; sshll/saddw widening-add chain otherwise.
define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #1
; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
}

; 64-bit-vector form of udot_no_bin_op: <8 x i8> zero-extended and reduced
; into a <2 x i32> accumulator. Expects udot on the .2s/.8b register forms
; with a splat-of-1 multiplicand when the dot-product feature is present.
define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.8b, #1
; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%a.wide = zext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
}

; Signed 64-bit-vector form: <8 x i8> sign-extended and reduced into a
; <2 x i32> accumulator. Expects sdot v0.2s against a splat of 1 when the
; dot-product feature is present.
define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.8b, #1
; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%a.wide = sext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
}

; i8-to-i64 partial reduction with no multiply on the input. The DOT path
; performs an i32 udot (splat-of-1 operand) into a zeroed <4 x i32>
; accumulator, then widens the i32 partial sums into the <4 x i64>
; accumulator with saddw/saddw2 (matching the "i32 dot product followed by an
; extension" lowering described in the ISel change above).
define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v3.16b, #1
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
}

; Signed i8-to-i64 variant: i32 sdot against a splat of 1 into a zeroed
; accumulator, then saddw/saddw2 to widen the i32 partial sums into the
; <4 x i64> accumulator.
define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v3.16b, #1
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
}

define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
Expand Down Expand Up @@ -398,3 +558,91 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
ret <2 x i32> %partial.reduce
}

; Negative test: the two multiply operands are extended from different source
; widths (<8 x i16> vs <8 x i8>), so no dot-product instruction may be
; formed; expect a widening multiply-accumulate (umull/umlal) sequence
; instead.
define <2 x i64> @udot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: udot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NEXT: umull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: umlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: umlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: umlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
%a.wide = zext <8 x i16> %a to <8 x i64>
%b.wide = zext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test, signed/signed: mismatched extension source widths
; (<8 x i16> vs <8 x i8>) prevent dot-product formation; expect
; smull/smlal widening multiply-accumulate.
define <2 x i64> @sdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: sdot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
%a.wide = sext <8 x i16> %a to <8 x i64>
%b.wide = sext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test, mixed extensions (zext LHS, sext RHS) with mismatched source
; widths: no usdot may be formed; expect a widening multiply-accumulate
; sequence.
define <2 x i64> @usdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: usdot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
%a.wide = zext <8 x i16> %a to <8 x i64>
%b.wide = sext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test, mixed extensions the other way (sext LHS, zext RHS) with
; mismatched source widths: no dot-product instruction may be formed; expect
; a widening multiply-accumulate sequence.
define <2 x i64> @sudot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: sudot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
%a.wide = sext <8 x i16> %a to <8 x i64>
%b.wide = zext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}
Loading

0 comments on commit 3134045

Please sign in to comment.