Skip to content

Commit ad994e5

Browse files
authored
[mono][jit] Adding support for Vector128::ExtractMostSignificantBits intrinsic on ARM64 with miniJIT (#84345)
Contributes to #76025
1 parent a0e23f4 commit ad994e5

File tree

5 files changed

+152
-4
lines changed

5 files changed

+152
-4
lines changed

src/mono/mono/arch/arm64/arm64-codegen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,6 +1804,7 @@ arm_encode_arith_imm (int imm, guint32 *shift)
18041804
#define arm_neon_cmhi(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00110, (rd), (rn), (rm))
18051805
#define arm_neon_cmhs(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b00111, (rd), (rn), (rm))
18061806
#define arm_neon_addp(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b0, (type), 0b10111, (rd), (rn), (rm))
1807+
#define arm_neon_ushl(p, width, type, rd, rn, rm) arm_neon_3svec_opcode ((p), (width), 0b1, (type), 0b01000, (rd), (rn), (rm))
18071808

18081809
// Generalized macros for float ops:
18091810
// width - determines if full register or its lower half is used one of {VREG_LOW, VREG_FULL}

src/mono/mono/mini/cpu-arm64.mdesc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,8 @@ create_scalar_unsafe_int: dest:x src1:i len:4
533533
create_scalar_unsafe_float: dest:x src1:f len:4
534534
arm64_bic: dest:x src1:x src2:x len:4
535535
bitwise_select: dest:x src1:x src2:x src3:x len:12
536+
arm64_ushl: dest:x src1:x src2:x len:4
537+
arm64_ext_imm: dest:x src1:x src2:x len:4
536538

537539
generic_class_init: src1:a len:44 clob:c
538540
gc_safe_point: src1:i len:12 clob:c

src/mono/mono/mini/mini-arm64.c

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3920,8 +3920,18 @@ mono_arch_output_basic_block (MonoCompile *cfg, MonoBasicBlock *bb)
39203920
// case OP_XCONCAT:
39213921
// arm_neon_ext_16b(code, dreg, sreg1, sreg2, 8);
39223922
// break;
3923-
3924-
/* BRANCH */
3923+
case OP_ARM64_USHL: {
3924+
arm_neon_ushl (code, get_vector_size_macro (ins), get_type_size_macro (ins->inst_c1), dreg, sreg1, sreg2);
3925+
break;
3926+
}
3927+
case OP_ARM64_EXT_IMM: {
3928+
if (get_vector_size_macro (ins) == VREG_LOW)
3929+
arm_neon_ext_8b (code, dreg, sreg1, sreg2, ins->inst_c0);
3930+
else
3931+
arm_neon_ext_16b (code, dreg, sreg1, sreg2, ins->inst_c0);
3932+
break;
3933+
}
3934+
/* BRANCH */
39253935
case OP_BR:
39263936
mono_add_patch_info_rel (cfg, offset, MONO_PATCH_INFO_BB, ins->inst_target_bb, MONO_R_ARM64_B);
39273937
arm_b (code, code);

src/mono/mono/mini/mini-ops.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,7 @@ MINI_OP(OP_ARM64_ABSCOMPARE, "arm64_abscompare", XREG, XREG, XREG)
17591759
MINI_OP(OP_ARM64_XNARROW_SCALAR, "arm64_xnarrow_scalar", XREG, XREG, NONE)
17601760

17611761
MINI_OP3(OP_ARM64_EXT, "arm64_ext", XREG, XREG, XREG, IREG)
1762+
MINI_OP(OP_ARM64_EXT_IMM, "arm64_ext_imm", XREG, XREG, XREG)
17621763

17631764
MINI_OP3(OP_ARM64_SQRDMLAH, "arm64_sqrdmlah", XREG, XREG, XREG, XREG)
17641765
MINI_OP3(OP_ARM64_SQRDMLAH_BYSCALAR, "arm64_sqrdmlah_byscalar", XREG, XREG, XREG, XREG)
@@ -1775,6 +1776,8 @@ MINI_OP3(OP_ARM64_SQRDMLSH_SCALAR, "arm64_sqrdmlsh_scalar", XREG, XREG, XREG, XR
17751776
MINI_OP(OP_ARM64_TBL_INDIRECT, "arm64_tbl_indirect", XREG, IREG, XREG)
17761777
MINI_OP3(OP_ARM64_TBX_INDIRECT, "arm64_tbx_indirect", XREG, IREG, XREG, XREG)
17771778

1779+
MINI_OP(OP_ARM64_USHL, "arm64_ushl", XREG, XREG, XREG)
1780+
17781781
#endif // TARGET_ARM64
17791782

17801783
MINI_OP(OP_SIMD_FCVTL, "simd_convert_to_higher_precision", XREG, XREG, NONE)

src/mono/mono/mini/simd-intrinsics.c

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,97 @@ is_element_type_primitive (MonoType *vector_type)
12031203
}
12041204
}
12051205

1206+
static MonoInst*
1207+
emit_msb_vector_mask (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
1208+
{
1209+
guint64 msb_mask_value[2];
1210+
1211+
switch (arg_type) {
1212+
case MONO_TYPE_I1:
1213+
case MONO_TYPE_U1:
1214+
msb_mask_value[0] = 0x8080808080808080;
1215+
msb_mask_value[1] = 0x8080808080808080;
1216+
break;
1217+
case MONO_TYPE_I2:
1218+
case MONO_TYPE_U2:
1219+
msb_mask_value[0] = 0x8000800080008000;
1220+
msb_mask_value[1] = 0x8000800080008000;
1221+
break;
1222+
#if TARGET_SIZEOF_VOID_P == 4
1223+
case MONO_TYPE_I:
1224+
case MONO_TYPE_U:
1225+
#endif
1226+
case MONO_TYPE_I4:
1227+
case MONO_TYPE_U4:
1228+
case MONO_TYPE_R4:
1229+
msb_mask_value[0] = 0x8000000080000000;
1230+
msb_mask_value[1] = 0x8000000080000000;
1231+
break;
1232+
#if TARGET_SIZEOF_VOID_P == 8
1233+
case MONO_TYPE_I:
1234+
case MONO_TYPE_U:
1235+
#endif
1236+
case MONO_TYPE_I8:
1237+
case MONO_TYPE_U8:
1238+
case MONO_TYPE_R8:
1239+
msb_mask_value[0] = 0x8000000000000000;
1240+
msb_mask_value[1] = 0x8000000000000000;
1241+
break;
1242+
default:
1243+
g_assert_not_reached ();
1244+
}
1245+
1246+
MonoInst* msb_mask_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)msb_mask_value);
1247+
msb_mask_vec->klass = arg_class;
1248+
return msb_mask_vec;
1249+
}
1250+
1251+
static MonoInst*
1252+
emit_msb_shift_vector_constant (MonoCompile *cfg, MonoClass *arg_class, MonoTypeEnum arg_type)
1253+
{
1254+
guint64 msb_shift_value[2];
1255+
1256+
// NOTE: On ARM64 ushl shifts a vector left or right depending on the sign of the shift constant
1257+
switch (arg_type) {
1258+
case MONO_TYPE_I1:
1259+
case MONO_TYPE_U1:
1260+
msb_shift_value[0] = 0x00FFFEFDFCFBFAF9;
1261+
msb_shift_value[1] = 0x00FFFEFDFCFBFAF9;
1262+
break;
1263+
case MONO_TYPE_I2:
1264+
case MONO_TYPE_U2:
1265+
msb_shift_value[0] = 0xFFF4FFF3FFF2FFF1;
1266+
msb_shift_value[1] = 0xFFF8FFF7FFF6FFF5;
1267+
break;
1268+
#if TARGET_SIZEOF_VOID_P == 4
1269+
case MONO_TYPE_I:
1270+
case MONO_TYPE_U:
1271+
#endif
1272+
case MONO_TYPE_I4:
1273+
case MONO_TYPE_U4:
1274+
case MONO_TYPE_R4:
1275+
msb_shift_value[0] = 0xFFFFFFE2FFFFFFE1;
1276+
msb_shift_value[1] = 0xFFFFFFE4FFFFFFE3;
1277+
break;
1278+
#if TARGET_SIZEOF_VOID_P == 8
1279+
case MONO_TYPE_I:
1280+
case MONO_TYPE_U:
1281+
#endif
1282+
case MONO_TYPE_I8:
1283+
case MONO_TYPE_U8:
1284+
case MONO_TYPE_R8:
1285+
msb_shift_value[0] = 0xFFFFFFFFFFFFFFC1;
1286+
msb_shift_value[1] = 0xFFFFFFFFFFFFFFC2;
1287+
break;
1288+
default:
1289+
g_assert_not_reached ();
1290+
}
1291+
1292+
MonoInst* msb_shift_vec = emit_xconst_v128 (cfg, arg_class, (guint8*)msb_shift_value);
1293+
msb_shift_vec->klass = arg_class;
1294+
return msb_shift_vec;
1295+
}
1296+
12061297
static MonoInst*
12071298
emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsig, MonoInst **args)
12081299
{
@@ -1234,7 +1325,6 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
12341325
case SN_ConvertToUInt64:
12351326
case SN_Create:
12361327
case SN_Dot:
1237-
case SN_ExtractMostSignificantBits:
12381328
case SN_GetElement:
12391329
case SN_GetLower:
12401330
case SN_GetUpper:
@@ -1542,7 +1632,49 @@ emit_sri_vector (MonoCompile *cfg, MonoMethod *cmethod, MonoMethodSignature *fsi
15421632
return NULL;
15431633
#ifdef TARGET_WASM
15441634
return emit_simd_ins_for_sig (cfg, klass, OP_WASM_SIMD_BITMASK, -1, -1, fsig, args);
1545-
#else
1635+
#elif defined(TARGET_ARM64)
1636+
if (COMPILE_LLVM (cfg))
1637+
return NULL;
1638+
1639+
MonoInst* result_ins = NULL;
1640+
MonoClass* arg_class = mono_class_from_mono_type_internal (fsig->params [0]);
1641+
int size = mono_class_value_size (arg_class, NULL);
1642+
if (size != 16)
1643+
return NULL;
1644+
1645+
MonoInst* msb_mask_vec = emit_msb_vector_mask (cfg, arg_class, arg0_type);
1646+
MonoInst* and_res_vec = emit_simd_ins_for_binary_op (cfg, arg_class, fsig, args, arg0_type, SN_BitwiseAnd);
1647+
and_res_vec->sreg2 = msb_mask_vec->dreg;
1648+
1649+
MonoInst* msb_shift_vec = emit_msb_shift_vector_constant (cfg, arg_class, arg0_type);
1650+
MonoInst* shift_res_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_USHL, and_res_vec->dreg, msb_shift_vec->dreg);
1651+
shift_res_vec->inst_c1 = arg0_type;
1652+
1653+
if (arg0_type == MONO_TYPE_I1 || arg0_type == MONO_TYPE_U1) {
1654+
// Always perform unsigned operations, as vector sum and extract operations could sign-extend the result into the GP register
1655+
// making the final result invalid. This is not needed for wider types, as the maximum sum of extracted MSBs cannot be larger than 8 bits
1656+
arg0_type = MONO_TYPE_U1;
1657+
1658+
// In order to sum the high and low 64 bits of the shifted vector separately, we use a zeroed vector and the extract operation
1659+
MonoInst* zero_vec = emit_xzero(cfg, arg_class);
1660+
1661+
MonoInst* ext_low_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, zero_vec->dreg, shift_res_vec->dreg);
1662+
ext_low_vec->inst_c0 = 8;
1663+
ext_low_vec->inst_c1 = arg0_type;
1664+
MonoInst* sum_low_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_low_vec);
1665+
1666+
MonoInst* ext_high_vec = emit_simd_ins (cfg, arg_class, OP_ARM64_EXT_IMM, shift_res_vec->dreg, zero_vec->dreg);
1667+
ext_high_vec->inst_c0 = 8;
1668+
ext_high_vec->inst_c1 = arg0_type;
1669+
MonoInst* sum_high_vec = emit_sum_vector (cfg, fsig->params [0], arg0_type, ext_high_vec);
1670+
1671+
MONO_EMIT_NEW_BIALU_IMM (cfg, OP_SHL_IMM, sum_high_vec->dreg, sum_high_vec->dreg, 8);
1672+
EMIT_NEW_BIALU (cfg, result_ins, OP_IOR, sum_high_vec->dreg, sum_high_vec->dreg, sum_low_vec->dreg);
1673+
} else {
1674+
result_ins = emit_sum_vector (cfg, fsig->params [0], arg0_type, shift_res_vec);
1675+
}
1676+
return result_ins;
1677+
#elif defined(TARGET_AMD64)
15461678
return NULL;
15471679
#endif
15481680
}

0 commit comments

Comments
 (0)