Skip to content

Commit 22cb90f

Browse files
ghehglanza
authored andcommitted
[CIR][CIRGen][Builtin][Neon] Lower neon_vshll_n (#1010)
1 parent cf18966 commit 22cb90f

File tree

2 files changed

+113
-53
lines changed

2 files changed

+113
-53
lines changed

clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp

+25-3
Original file line numberDiff line numberDiff line change
@@ -2239,6 +2239,19 @@ static mlir::Value buildNeonShiftVector(CIRGenBuilderTy &builder,
22392239
return builder.create<mlir::cir::ConstantOp>(loc, vecTy, constVecAttr);
22402240
}
22412241

2242+
/// Build ShiftOp of vector type whose shift amount is a vector built
2243+
/// from a constant integer using `buildNeonShiftVector` function
2244+
static mlir::Value buildCommonNeonShift(CIRGenBuilderTy &builder,
2245+
mlir::Location loc,
2246+
mlir::cir::VectorType resTy,
2247+
mlir::Value shifTgt,
2248+
mlir::Value shiftAmt, bool shiftLeft,
2249+
bool negAmt = false) {
2250+
shiftAmt = buildNeonShiftVector(builder, shiftAmt, resTy, loc, negAmt);
2251+
return builder.create<mlir::cir::ShiftOp>(
2252+
loc, resTy, builder.createBitcast(shifTgt, resTy), shiftAmt, shiftLeft);
2253+
}
2254+
22422255
mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
22432256
unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic,
22442257
const char *nameHint, unsigned modifier, const CallExpr *e,
@@ -2326,9 +2339,18 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
23262339
case NEON::BI__builtin_neon_vshl_n_v:
23272340
case NEON::BI__builtin_neon_vshlq_n_v: {
23282341
mlir::Location loc = getLoc(e->getExprLoc());
2329-
ops[1] = buildNeonShiftVector(builder, ops[1], vTy, loc, false);
2330-
return builder.create<mlir::cir::ShiftOp>(
2331-
loc, vTy, builder.createBitcast(ops[0], vTy), ops[1], true);
2342+
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
2343+
}
2344+
case NEON::BI__builtin_neon_vshll_n_v: {
2345+
mlir::Location loc = getLoc(e->getExprLoc());
2346+
mlir::cir::VectorType srcTy =
2347+
builder.getExtendedOrTruncatedElementVectorType(
2348+
vTy, false /* truncate */,
2349+
mlir::cast<mlir::cir::IntType>(vTy.getEltType()).isSigned());
2350+
ops[0] = builder.createBitcast(ops[0], srcTy);
2351+
// The following cast will be lowered to SExt or ZExt in LLVM.
2352+
ops[0] = builder.createIntCast(ops[0], vTy);
2353+
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
23322354
}
23332355
}
23342356

clang/test/CIR/CodeGen/AArch64/neon.c

+88-50
Original file line numberDiff line numberDiff line change
@@ -6586,61 +6586,99 @@ uint32x2_t test_vqrshrun_n_s64(int64x2_t a) {
65866586
// return vqrshrn_high_n_u64(a, b, 19);
65876587
// }
65886588

6589-
// NYI-LABEL: @test_vshll_n_s8(
6590-
// NYI: [[TMP0:%.*]] = sext <8 x i8> %a to <8 x i16>
6591-
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
6592-
// NYI: ret <8 x i16> [[VSHLL_N]]
6593-
// int16x8_t test_vshll_n_s8(int8x8_t a) {
6594-
// return vshll_n_s8(a, 3);
6595-
// }
6589+
int16x8_t test_vshll_n_s8(int8x8_t a) {
6590+
return vshll_n_s8(a, 3);
6591+
6592+
// CIR-LABEL: vshll_n_s8
6593+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s8i x 8>), !cir.vector<!s16i x 8>
6594+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i,
6595+
// CIR-SAME: #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i]> : !cir.vector<!s16i x 8>
6596+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s16i x 8>, [[SHIFT_AMT]] : !cir.vector<!s16i x 8>) -> !cir.vector<!s16i x 8>
6597+
6598+
// LLVM: {{.*}}@test_vshll_n_s8(<8 x i8>{{.*}}[[A:%.*]])
6599+
// LLVM: [[TMP0:%.*]] = sext <8 x i8> [[A]] to <8 x i16>
6600+
// LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], splat (i16 3)
6601+
// LLVM: ret <8 x i16> [[VSHLL_N]]
6602+
}
65966603

6597-
// NYI-LABEL: @test_vshll_n_s16(
6598-
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
6599-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6600-
// NYI: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
6601-
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
6602-
// NYI: ret <4 x i32> [[VSHLL_N]]
6603-
// int32x4_t test_vshll_n_s16(int16x4_t a) {
6604-
// return vshll_n_s16(a, 9);
6605-
// }
6604+
int32x4_t test_vshll_n_s16(int16x4_t a) {
6605+
return vshll_n_s16(a, 9);
66066606

6607-
// NYI-LABEL: @test_vshll_n_s32(
6608-
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
6609-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6610-
// NYI: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
6611-
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6612-
// NYI: ret <2 x i64> [[VSHLL_N]]
6613-
// int64x2_t test_vshll_n_s32(int32x2_t a) {
6614-
// return vshll_n_s32(a, 19);
6615-
// }
6607+
// CIR-LABEL: vshll_n_s16
6608+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s16i x 4>), !cir.vector<!s32i x 4>
6609+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !s32i, #cir.int<9> : !s32i, #cir.int<9> :
6610+
// CIR-SAME: !s32i, #cir.int<9> : !s32i]> : !cir.vector<!s32i x 4>
6611+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s32i x 4>, [[SHIFT_AMT]] : !cir.vector<!s32i x 4>) -> !cir.vector<!s32i x 4>
66166612

6617-
// NYI-LABEL: @test_vshll_n_u8(
6618-
// NYI: [[TMP0:%.*]] = zext <8 x i8> %a to <8 x i16>
6619-
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
6620-
// NYI: ret <8 x i16> [[VSHLL_N]]
6621-
// uint16x8_t test_vshll_n_u8(uint8x8_t a) {
6622-
// return vshll_n_u8(a, 3);
6623-
// }
6613+
// LLVM: {{.*}}@test_vshll_n_s16(<4 x i16>{{.*}}[[A:%.*]])
6614+
// LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
6615+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6616+
// LLVM: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
6617+
// LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], splat (i32 9)
6618+
// LLVM: ret <4 x i32> [[VSHLL_N]]
6619+
}
66246620

6625-
// NYI-LABEL: @test_vshll_n_u16(
6626-
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
6627-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6628-
// NYI: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
6629-
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
6630-
// NYI: ret <4 x i32> [[VSHLL_N]]
6631-
// uint32x4_t test_vshll_n_u16(uint16x4_t a) {
6632-
// return vshll_n_u16(a, 9);
6633-
// }
6621+
int64x2_t test_vshll_n_s32(int32x2_t a) {
6622+
return vshll_n_s32(a, 19);
66346623

6635-
// NYI-LABEL: @test_vshll_n_u32(
6636-
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
6637-
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6638-
// NYI: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
6639-
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
6640-
// NYI: ret <2 x i64> [[VSHLL_N]]
6641-
// uint64x2_t test_vshll_n_u32(uint32x2_t a) {
6642-
// return vshll_n_u32(a, 19);
6643-
// }
6624+
// CIR-LABEL: vshll_n_s32
6625+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s32i x 2>), !cir.vector<!s64i x 2>
6626+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !s64i, #cir.int<19> : !s64i]> : !cir.vector<!s64i x 2>
6627+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s64i x 2>, [[SHIFT_AMT]] : !cir.vector<!s64i x 2>)
6628+
6629+
// LLVM: {{.*}}@test_vshll_n_s32(<2 x i32>{{.*}}[[A:%.*]])
6630+
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
6631+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6632+
// LLVM: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
6633+
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], splat (i64 19)
6634+
// LLVM: ret <2 x i64> [[VSHLL_N]]
6635+
}
6636+
6637+
uint16x8_t test_vshll_n_u8(uint8x8_t a) {
6638+
return vshll_n_u8(a, 3);
6639+
6640+
// CIR-LABEL: vshll_n_u8
6641+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u8i x 8>), !cir.vector<!u16i x 8>
6642+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i,
6643+
// CIR-SAME: #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i]> : !cir.vector<!u16i x 8>
6644+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u16i x 8>, [[SHIFT_AMT]] : !cir.vector<!u16i x 8>)
6645+
6646+
// LLVM: {{.*}}@test_vshll_n_u8(<8 x i8>{{.*}}[[A:%.*]])
6647+
// LLVM: [[TMP0:%.*]] = zext <8 x i8> [[A]] to <8 x i16>
6648+
// LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], splat (i16 3)
6649+
}
6650+
6651+
uint32x4_t test_vshll_n_u16(uint16x4_t a) {
6652+
return vshll_n_u16(a, 9);
6653+
6654+
// CIR-LABEL: vshll_n_u16
6655+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u16i x 4>), !cir.vector<!u32i x 4>
6656+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !u32i, #cir.int<9> : !u32i,
6657+
// CIR-SAME: #cir.int<9> : !u32i, #cir.int<9> : !u32i]> : !cir.vector<!u32i x 4>
6658+
6659+
// LLVM: {{.*}}@test_vshll_n_u16(<4 x i16>{{.*}}[[A:%.*]])
6660+
// LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
6661+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
6662+
// LLVM: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
6663+
// LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], splat (i32 9)
6664+
// LLVM: ret <4 x i32> [[VSHLL_N]]
6665+
}
6666+
6667+
uint64x2_t test_vshll_n_u32(uint32x2_t a) {
6668+
return vshll_n_u32(a, 19);
6669+
6670+
// CIR-LABEL: vshll_n_u32
6671+
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u32i x 2>), !cir.vector<!u64i x 2>
6672+
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !u64i, #cir.int<19> : !u64i]> : !cir.vector<!u64i x 2>
6673+
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u64i x 2>, [[SHIFT_AMT]] : !cir.vector<!u64i x 2>)
6674+
6675+
// LLVM: {{.*}}@test_vshll_n_u32(<2 x i32>{{.*}}[[A:%.*]])
6676+
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
6677+
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
6678+
// LLVM: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
6679+
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], splat (i64 19)
6680+
// LLVM: ret <2 x i64> [[VSHLL_N]]
6681+
}
66446682

66456683
// NYI-LABEL: @test_vshll_high_n_s8(
66466684
// NYI: [[SHUFFLE_I:%.*]] = shufflevector <16 x i8> %a, <16 x i8> %a, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>

0 commit comments

Comments
 (0)