[CIR][CIRGen][Builtin][Neon] Lower neon_vshll_n (#1010)

ghehg authored Oct 28, 2024
1 parent d7de21f commit c76c138
Showing 2 changed files with 113 additions and 53 deletions.
28 changes: 25 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
@@ -2241,6 +2241,19 @@ static mlir::Value buildNeonShiftVector(CIRGenBuilderTy &builder,
return builder.create<mlir::cir::ConstantOp>(loc, vecTy, constVecAttr);
}

/// Build a ShiftOp of vector type whose shift amount is a vector built
/// from a constant integer by the `buildNeonShiftVector` helper.
static mlir::Value buildCommonNeonShift(CIRGenBuilderTy &builder,
mlir::Location loc,
mlir::cir::VectorType resTy,
mlir::Value shiftTgt,
mlir::Value shiftAmt, bool shiftLeft,
bool negAmt = false) {
shiftAmt = buildNeonShiftVector(builder, shiftAmt, resTy, loc, negAmt);
return builder.create<mlir::cir::ShiftOp>(
loc, resTy, builder.createBitcast(shiftTgt, resTy), shiftAmt, shiftLeft);
}

mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic,
const char *nameHint, unsigned modifier, const CallExpr *e,
@@ -2328,9 +2341,18 @@ mlir::Value CIRGenFunction::buildCommonNeonBuiltinExpr(
case NEON::BI__builtin_neon_vshl_n_v:
case NEON::BI__builtin_neon_vshlq_n_v: {
mlir::Location loc = getLoc(e->getExprLoc());
ops[1] = buildNeonShiftVector(builder, ops[1], vTy, loc, false);
return builder.create<mlir::cir::ShiftOp>(
loc, vTy, builder.createBitcast(ops[0], vTy), ops[1], true);
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
}
case NEON::BI__builtin_neon_vshll_n_v: {
mlir::Location loc = getLoc(e->getExprLoc());
mlir::cir::VectorType srcTy =
builder.getExtendedOrTruncatedElementVectorType(
vTy, false /* truncate */,
mlir::cast<mlir::cir::IntType>(vTy.getEltType()).isSigned());
ops[0] = builder.createBitcast(ops[0], srcTy);
// The following cast will be lowered to SExt or ZExt in LLVM.
ops[0] = builder.createIntCast(ops[0], vTy);
return buildCommonNeonShift(builder, loc, vTy, ops[0], ops[1], true);
}
}

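For context on the lowering above: the new BI__builtin_neon_vshll_n_v case bitcasts the operand to the narrower source vector type, widens each lane to the destination element width (an integral cast that becomes sext for signed and zext for unsigned element types when lowered to LLVM), and then shifts left by a constant splat vector via buildCommonNeonShift. A minimal scalar sketch of that per-lane behaviour, not part of the commit, using plain C arrays instead of the NEON vector types and illustrative helper names (vshll_n_s8_ref, vshll_n_u8_ref):

#include <stdint.h>

// Signed variant: widen by sign extension, then shift left by the immediate,
// mirroring the "sext + shl" sequence checked in the LLVM output below.
static void vshll_n_s8_ref(const int8_t a[8], int n, int16_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    uint16_t widened = (uint16_t)(int16_t)a[i]; // sign-extend to 16 bits
    out[i] = (int16_t)(uint16_t)(widened << n); // shift as a 16-bit lane
  }
}

// Unsigned variant: the widening step is a zero extension instead ("zext + shl").
static void vshll_n_u8_ref(const uint8_t a[8], int n, uint16_t out[8]) {
  for (int i = 0; i < 8; ++i)
    out[i] = (uint16_t)((uint16_t)a[i] << n); // zero-extend to 16 bits, then shift
}

With n = 3, these produce the same lane values as vshll_n_s8(a, 3) and vshll_n_u8(a, 3) exercised by the tests in neon.c below.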
138 changes: 88 additions & 50 deletions clang/test/CIR/CodeGen/AArch64/neon.c
@@ -6587,61 +6587,99 @@ uint32x2_t test_vqrshrun_n_s64(int64x2_t a) {
// return vqrshrn_high_n_u64(a, b, 19);
// }

// NYI-LABEL: @test_vshll_n_s8(
// NYI: [[TMP0:%.*]] = sext <8 x i8> %a to <8 x i16>
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
// NYI: ret <8 x i16> [[VSHLL_N]]
// int16x8_t test_vshll_n_s8(int8x8_t a) {
// return vshll_n_s8(a, 3);
// }
int16x8_t test_vshll_n_s8(int8x8_t a) {
return vshll_n_s8(a, 3);

// CIR-LABEL: vshll_n_s8
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s8i x 8>), !cir.vector<!s16i x 8>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i,
// CIR-SAME: #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i, #cir.int<3> : !s16i]> : !cir.vector<!s16i x 8>
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s16i x 8>, [[SHIFT_AMT]] : !cir.vector<!s16i x 8>) -> !cir.vector<!s16i x 8>

// LLVM: {{.*}}@test_vshll_n_s8(<8 x i8>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = sext <8 x i8> [[A]] to <8 x i16>
// LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
// LLVM: ret <8 x i16> [[VSHLL_N]]
}

// NYI-LABEL: @test_vshll_n_s16(
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
// NYI: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
// NYI: ret <4 x i32> [[VSHLL_N]]
// int32x4_t test_vshll_n_s16(int16x4_t a) {
// return vshll_n_s16(a, 9);
// }
int32x4_t test_vshll_n_s16(int16x4_t a) {
return vshll_n_s16(a, 9);

// NYI-LABEL: @test_vshll_n_s32(
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
// NYI: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
// NYI: ret <2 x i64> [[VSHLL_N]]
// int64x2_t test_vshll_n_s32(int32x2_t a) {
// return vshll_n_s32(a, 19);
// }
// CIR-LABEL: vshll_n_s16
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s16i x 4>), !cir.vector<!s32i x 4>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !s32i, #cir.int<9> : !s32i, #cir.int<9> :
// CIR-SAME: !s32i, #cir.int<9> : !s32i]> : !cir.vector<!s32i x 4>
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s32i x 4>, [[SHIFT_AMT]] : !cir.vector<!s32i x 4>) -> !cir.vector<!s32i x 4>

// NYI-LABEL: @test_vshll_n_u8(
// NYI: [[TMP0:%.*]] = zext <8 x i8> %a to <8 x i16>
// NYI: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
// NYI: ret <8 x i16> [[VSHLL_N]]
// uint16x8_t test_vshll_n_u8(uint8x8_t a) {
// return vshll_n_u8(a, 3);
// }
// LLVM: {{.*}}@test_vshll_n_s16(<4 x i16>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
// LLVM: [[TMP2:%.*]] = sext <4 x i16> [[TMP1]] to <4 x i32>
// LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
// LLVM: ret <4 x i32> [[VSHLL_N]]
}

// NYI-LABEL: @test_vshll_n_u16(
// NYI: [[TMP0:%.*]] = bitcast <4 x i16> %a to <8 x i8>
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
// NYI: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
// NYI: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
// NYI: ret <4 x i32> [[VSHLL_N]]
// uint32x4_t test_vshll_n_u16(uint16x4_t a) {
// return vshll_n_u16(a, 9);
// }
int64x2_t test_vshll_n_s32(int32x2_t a) {
return vshll_n_s32(a, 19);

// NYI-LABEL: @test_vshll_n_u32(
// NYI: [[TMP0:%.*]] = bitcast <2 x i32> %a to <8 x i8>
// NYI: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
// NYI: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
// NYI: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
// NYI: ret <2 x i64> [[VSHLL_N]]
// uint64x2_t test_vshll_n_u32(uint32x2_t a) {
// return vshll_n_u32(a, 19);
// }
// CIR-LABEL: vshll_n_s32
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!s32i x 2>), !cir.vector<!s64i x 2>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !s64i, #cir.int<19> : !s64i]> : !cir.vector<!s64i x 2>
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!s64i x 2>, [[SHIFT_AMT]] : !cir.vector<!s64i x 2>)

// LLVM: {{.*}}@test_vshll_n_s32(<2 x i32>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
// LLVM: [[TMP2:%.*]] = sext <2 x i32> [[TMP1]] to <2 x i64>
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
// LLVM: ret <2 x i64> [[VSHLL_N]]
}

uint16x8_t test_vshll_n_u8(uint8x8_t a) {
return vshll_n_u8(a, 3);

// CIR-LABEL: vshll_n_u8
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u8i x 8>), !cir.vector<!u16i x 8>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i,
// CIR-SAME: #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i, #cir.int<3> : !u16i]> : !cir.vector<!u16i x 8>
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u16i x 8>, [[SHIFT_AMT]] : !cir.vector<!u16i x 8>)

// LLVM: {{.*}}@test_vshll_n_u8(<8 x i8>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = zext <8 x i8> [[A]] to <8 x i16>
// LLVM: [[VSHLL_N:%.*]] = shl <8 x i16> [[TMP0]], <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
}

uint32x4_t test_vshll_n_u16(uint16x4_t a) {
return vshll_n_u16(a, 9);

// CIR-LABEL: vshll_n_u16
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u16i x 4>), !cir.vector<!u32i x 4>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<9> : !u32i, #cir.int<9> : !u32i,
// CIR-SAME: #cir.int<9> : !u32i, #cir.int<9> : !u32i]> : !cir.vector<!u32i x 4>

// LLVM: {{.*}}@test_vshll_n_u16(<4 x i16>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = bitcast <4 x i16> [[A]] to <8 x i8>
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <4 x i16>
// LLVM: [[TMP2:%.*]] = zext <4 x i16> [[TMP1]] to <4 x i32>
// LLVM: [[VSHLL_N:%.*]] = shl <4 x i32> [[TMP2]], <i32 9, i32 9, i32 9, i32 9>
// LLVM: ret <4 x i32> [[VSHLL_N]]
}

uint64x2_t test_vshll_n_u32(uint32x2_t a) {
return vshll_n_u32(a, 19);

// CIR-LABEL: vshll_n_u32
// CIR: [[SHIFT_TGT:%.*]] = cir.cast(integral, {{%.*}} : !cir.vector<!u32i x 2>), !cir.vector<!u64i x 2>
// CIR: [[SHIFT_AMT:%.*]] = cir.const #cir.const_vector<[#cir.int<19> : !u64i, #cir.int<19> : !u64i]> : !cir.vector<!u64i x 2>
// CIR: {{%.*}} = cir.shift(left, [[SHIFT_TGT]] : !cir.vector<!u64i x 2>, [[SHIFT_AMT]] : !cir.vector<!u64i x 2>)

// LLVM: {{.*}}@test_vshll_n_u32(<2 x i32>{{.*}}[[A:%.*]])
// LLVM: [[TMP0:%.*]] = bitcast <2 x i32> [[A]] to <8 x i8>
// LLVM: [[TMP1:%.*]] = bitcast <8 x i8> [[TMP0]] to <2 x i32>
// LLVM: [[TMP2:%.*]] = zext <2 x i32> [[TMP1]] to <2 x i64>
// LLVM: [[VSHLL_N:%.*]] = shl <2 x i64> [[TMP2]], <i64 19, i64 19>
// LLVM: ret <2 x i64> [[VSHLL_N]]
}

// NYI-LABEL: @test_vshll_high_n_s8(
// NYI: [[SHUFFLE_I:%.*]] = shufflevector <16 x i8> %a, <16 x i8> %a, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
