From 31340457399d218c27a7a74770eb9fa03e6ae92b Mon Sep 17 00:00:00 2001 From: James Chesterman Date: Mon, 6 Jan 2025 10:51:47 +0000 Subject: [PATCH] [AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input (#120207) Add codegen for when the input type has 4 times as many elements as the output type and the input to the partial reduction does not have a binary operation performed on it. --- .../Target/AArch64/AArch64ISelLowering.cpp | 44 ++-- .../neon-partial-reduce-dot-product.ll | 248 ++++++++++++++++++ .../AArch64/sve-partial-reduce-dot-product.ll | 208 +++++++++++++++ 3 files changed, 483 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index c965659b0fef111..ef00b092fe5e060 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -21987,21 +21987,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, SDLoc DL(N); SDValue Op2 = N->getOperand(2); - if (Op2->getOpcode() != ISD::MUL || - !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) || - !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode())) - return SDValue(); + unsigned Op2Opcode = Op2->getOpcode(); + SDValue MulOpLHS, MulOpRHS; + bool MulOpLHSIsSigned, MulOpRHSIsSigned; + if (ISD::isExtOpcode(Op2Opcode)) { + MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND); + MulOpLHS = Op2->getOperand(0); + MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType()); + } else if (Op2Opcode == ISD::MUL) { + SDValue ExtMulOpLHS = Op2->getOperand(0); + SDValue ExtMulOpRHS = Op2->getOperand(1); + + unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode(); + unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode(); + if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) || + !ISD::isExtOpcode(ExtMulOpRHSOpcode)) + return SDValue(); - SDValue Acc = N->getOperand(1); - SDValue Mul = N->getOperand(2); - SDValue ExtMulOpLHS = Mul->getOperand(0); - SDValue ExtMulOpRHS = Mul->getOperand(1); + MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND; + MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND; - SDValue MulOpLHS = ExtMulOpLHS->getOperand(0); - SDValue MulOpRHS = ExtMulOpRHS->getOperand(0); - if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) + MulOpLHS = ExtMulOpLHS->getOperand(0); + MulOpRHS = ExtMulOpRHS->getOperand(0); + + if (MulOpLHS.getValueType() != MulOpRHS.getValueType()) + return SDValue(); + } else return SDValue(); + SDValue Acc = N->getOperand(1); EVT ReducedVT = N->getValueType(0); EVT MulSrcVT = MulOpLHS.getValueType(); @@ -22015,8 +22029,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) return SDValue(); - bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND; - bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND; // If the extensions are mixed, we should lower it to a usdot instead unsigned Opcode = 0; if (MulOpLHSIsSigned != MulOpRHSIsSigned) { @@ -22032,10 +22044,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, // USDOT expects the signed operand to be last if (!MulOpRHSIsSigned) std::swap(MulOpLHS, MulOpRHS); - } else if (MulOpLHSIsSigned) - Opcode = AArch64ISD::SDOT; - else - Opcode = AArch64ISD::UDOT; + } else + Opcode = MulOpLHSIsSigned ? 
AArch64ISD::SDOT : AArch64ISD::UDOT; // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot // product followed by a zero / sign extension diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll index c1b9a4c9dbb7978..9ece9edb843439d 100644 --- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll +++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll @@ -367,6 +367,166 @@ entry: ret <4 x i64> %partial.reduce } +define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){ +; CHECK-DOT-LABEL: udot_no_bin_op: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.16b, #1 +; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b +; CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: udot_no_bin_op: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0 +; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0 +; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h +; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h +; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h +; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s +; CHECK-NODOT-NEXT: ret + %a.wide = zext <16 x i8> %a to <16 x i32> + %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide) + ret <4 x i32> %partial.reduce +} + +define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){ +; CHECK-DOT-LABEL: sdot_no_bin_op: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.16b, #1 +; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b +; CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: sdot_no_bin_op: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0 +; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0 +; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0 +; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h +; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h +; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h +; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s +; CHECK-NODOT-NEXT: ret + %a.wide = sext <16 x i8> %a to <16 x i32> + %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide) + ret <4 x i32> %partial.reduce +} + +define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){ +; CHECK-DOT-LABEL: udot_no_bin_op_narrow: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.8b, #1 +; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b +; CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: udot_no_bin_op_narrow: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0 +; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0 +; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0 +; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8 +; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h +; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8 +; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h +; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NODOT-NEXT: ret + %a.wide = zext <8 x i8> %a to <8 x i32> + %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide) + ret <2 x i32> %partial.reduce +} + +define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){ +; CHECK-DOT-LABEL: sdot_no_bin_op_narrow: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v2.8b, #1 +; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b +; 
CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0 +; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0 +; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0 +; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0 +; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8 +; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h +; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8 +; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8 +; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s +; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h +; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s +; CHECK-NODOT-NEXT: ret + %a.wide = sext <8 x i8> %a to <8 x i32> + %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide) + ret <2 x i32> %partial.reduce +} + +define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){ +; CHECK-DOT-LABEL: udot_no_bin_op_8to64: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v3.16b, #1 +; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000 +; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b +; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s +; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s +; CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: udot_no_bin_op_8to64: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0 +; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0 +; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0 +; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0 +; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0 +; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0 +; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s +; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s +; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s +; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s +; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s +; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s +; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d +; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d +; CHECK-NODOT-NEXT: ret + %a.wide = zext <16 x i8> %a to <16 x i64> + %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide) + ret <4 x i64> %partial.reduce +} + +define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){ +; CHECK-DOT-LABEL: sdot_no_bin_op_8to64: +; CHECK-DOT: // %bb.0: +; CHECK-DOT-NEXT: movi v3.16b, #1 +; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000 +; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b +; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s +; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s +; CHECK-DOT-NEXT: ret +; +; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64: +; CHECK-NODOT: // %bb.0: +; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0 +; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0 +; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0 +; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0 +; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0 +; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0 +; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s +; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s +; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s +; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s +; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s +; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s +; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d +; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d +; CHECK-NODOT-NEXT: ret + %a.wide = sext <16 x i8> %a to <16 x i64> + %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide) + ret <4 x i64> %partial.reduce +} + define <4 x i32> 
@not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{ ; CHECK-LABEL: not_udot: ; CHECK: // %bb.0: @@ -398,3 +558,91 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) { %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult) ret <2 x i32> %partial.reduce } + +define <2 x i64> @udot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){ +; CHECK-LABEL: udot_different_types: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: ushll v2.8h, v2.8b, #0 +; CHECK-NEXT: ushll v3.4s, v1.4h, #0 +; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0 +; CHECK-NEXT: ushll v4.4s, v2.4h, #0 +; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0 +; CHECK-NEXT: umull v5.2d, v1.2s, v2.2s +; CHECK-NEXT: umlal v0.2d, v3.2s, v4.2s +; CHECK-NEXT: umlal2 v0.2d, v1.4s, v2.4s +; CHECK-NEXT: umlal2 v5.2d, v3.4s, v4.4s +; CHECK-NEXT: add v0.2d, v5.2d, v0.2d +; CHECK-NEXT: ret +entry: + %a.wide = zext <8 x i16> %a to <8 x i64> + %b.wide = zext <8 x i8> %b to <8 x i64> + %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide + %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult) + ret <2 x i64> %partial.reduce +} + +define <2 x i64> @sdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){ +; CHECK-LABEL: sdot_different_types: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: sshll v2.8h, v2.8b, #0 +; CHECK-NEXT: sshll v3.4s, v1.4h, #0 +; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0 +; CHECK-NEXT: sshll v4.4s, v2.4h, #0 +; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0 +; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s +; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s +; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s +; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s +; CHECK-NEXT: add v0.2d, v5.2d, v0.2d +; CHECK-NEXT: ret +entry: + %a.wide = sext <8 x i16> %a to <8 x i64> + %b.wide = sext <8 x i8> %b to <8 x i64> + %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide + %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult) + ret <2 x i64> %partial.reduce +} + +define <2 x i64> @usdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){ +; CHECK-LABEL: usdot_different_types: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: sshll v2.8h, v2.8b, #0 +; CHECK-NEXT: ushll v3.4s, v1.4h, #0 +; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0 +; CHECK-NEXT: sshll v4.4s, v2.4h, #0 +; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0 +; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s +; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s +; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s +; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s +; CHECK-NEXT: add v0.2d, v5.2d, v0.2d +; CHECK-NEXT: ret +entry: + %a.wide = zext <8 x i16> %a to <8 x i64> + %b.wide = sext <8 x i8> %b to <8 x i64> + %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide + %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult) + ret <2 x i64> %partial.reduce +} + +define <2 x i64> @sudot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){ +; CHECK-LABEL: sudot_different_types: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: ushll v2.8h, v2.8b, #0 +; CHECK-NEXT: sshll v3.4s, v1.4h, #0 +; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0 +; CHECK-NEXT: ushll v4.4s, v2.4h, #0 +; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0 +; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s +; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s +; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s +; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s +; CHECK-NEXT: add v0.2d, v5.2d, v0.2d +; 
CHECK-NEXT: ret
+entry:
+ %a.wide = sext <8 x i16> %a to <8 x i64>
+ %b.wide = zext <8 x i8> %b to <8 x i64>
+ %mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
+ ret <2 x i64> %partial.reduce
+}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66d6e0388bbf94e..66f83c658ff4f25 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -316,6 +316,84 @@ entry:
 ret <vscale x 4 x i64> %partial.reduce
 }

+define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: udot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: sdot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
 define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
 ; CHECK-LABEL: not_udot:
 ; CHECK: // %bb.0: // %entry
@@ -419,3 +497,133 @@ entry:
 %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
 ret <vscale x 2 x i64> %partial.reduce
 }
+
+define <vscale x 2 x i64> @udot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: udot_different_types:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: sdot_different_types:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @usdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: usdot_different_types:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.h
+; CHECK-NEXT: uunpklo z3.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z5.d, z3.s
+; CHECK-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEXT: uunpklo z7.d, z1.s
+; CHECK-NEXT: sunpklo z4.s, z2.h
+; CHECK-NEXT: sunpkhi z2.s, z2.h
+; CHECK-NEXT: uunpkhi z1.d, z1.s
+; CHECK-NEXT: sunpklo z6.d, z4.s
+; CHECK-NEXT: sunpkhi z4.d, z4.s
+; CHECK-NEXT: sunpklo z24.d, z2.s
+; CHECK-NEXT: sunpkhi z2.d, z2.s
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sudot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
+; CHECK-LABEL: sudot_different_types:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: and z2.h, z2.h, #0xff
+; CHECK-NEXT: sunpklo z3.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: uunpklo z4.s, z2.h
+; CHECK-NEXT: uunpkhi z2.s, z2.h
+; CHECK-NEXT: sunpklo z5.d, z3.s
+; CHECK-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEXT: sunpklo z7.d, z1.s
+; CHECK-NEXT: sunpkhi z1.d, z1.s
+; CHECK-NEXT: uunpklo z6.d, z4.s
+; CHECK-NEXT: uunpkhi z4.d, z4.s
+; CHECK-NEXT: uunpklo z24.d, z2.s
+; CHECK-NEXT: uunpkhi z2.d, z2.s
+; CHECK-NEXT: mul z3.d, z3.d, z4.d
+; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
+; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT: movprfx z1, z3
+; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
+; CHECK-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
+ %mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
+ ret <vscale x 2 x i64> %partial.reduce
+}
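
In short: when the input to the partial reduction is a bare zero- or sign-extend rather than a multiply of two extends, the lowering now treats it as a multiply by a constant splat of 1 and reuses the existing dot-product path. A minimal before/after sketch, taken from the udot_no_bin_op NEON test above (all IR and assembly come from the test file; nothing new is assumed):

  ; Input IR
  %a.wide = zext <16 x i8> %a to <16 x i32>
  %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)

  ; Code generated on targets with the dot-product feature
  movi v2.16b, #1
  udot v0.4s, v1.16b, v2.16b

Without the feature, the ushll/uaddw expansion shown under the CHECK-NODOT prefixes is still produced.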