Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SVE/SVE2 support for ARMv8/ARMv9 #5510

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion scripts/get_native_properties.sh
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ case $uname_s in
'aarch64')
file_os='android'
true_arch='armv8'
if check_flags 'asimddp'; then
if check_flags 'sve2'; then
true_arch="armv9"
elif check_flags 'sve' 'asimddp'; then
true_arch="$true_arch-sve"
elif check_flags 'asimddp'; then
true_arch="$true_arch-dotprod"
fi
;;
Expand Down
10 changes: 10 additions & 0 deletions scripts/sve-vl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/sh

### TODO: What is the Windows path for this? (msys2)
MinetaS marked this conversation as resolved.
Show resolved Hide resolved

if [ ! -f /proc/sys/abi/sve_default_vector_length ]; then
return 1
fi

vl=$(cat /proc/sys/abi/sve_default_vector_length)
echo "$vl * 8" | bc
61 changes: 53 additions & 8 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ VPATH = syzygy:nnue:nnue/features
# vnni512 = yes/no --- -mavx512vnni --- Use Intel Vector Neural Network Instructions 512
# neon = yes/no --- -DUSE_NEON --- Use ARM SIMD architecture
# dotprod = yes/no --- -DUSE_NEON_DOTPROD --- Use ARM advanced SIMD Int8 dot product instructions
# sve = yes/no --- -DUSE_SVE=1000 --- Use ARM scalable vector extensions
# sve2 = yes/no --- -DUSE_SVE=2000 --- Use ARM scalable vector extensions 2.0
#
# Note that Makefile is space sensitive, so when adding new architectures
# or modifying existing flags, you have to make sure there are no extra spaces
Expand All @@ -125,7 +127,8 @@ ifeq ($(ARCH), $(filter $(ARCH), \
x86-64-vnni512 x86-64-vnni256 x86-64-avx512 x86-64-avxvnni x86-64-bmi2 \
x86-64-avx2 x86-64-sse41-popcnt x86-64-modern x86-64-ssse3 x86-64-sse3-popcnt \
x86-64 x86-32-sse41-popcnt x86-32-sse2 x86-32 ppc-64 ppc-32 e2k \
armv7 armv7-neon armv8 armv8-dotprod apple-silicon general-64 general-32 riscv64 loongarch64))
armv7 armv7-neon armv8 armv8-dotprod armv8-sve armv9 apple-silicon \
general-64 general-32 riscv64 loongarch64))
SUPPORTED_ARCH=true
else
SUPPORTED_ARCH=false
Expand All @@ -150,6 +153,8 @@ vnni256 = no
vnni512 = no
neon = no
dotprod = no
sve = no
sve2 = no
arm_version = 0
STRIP = strip

Expand Down Expand Up @@ -335,6 +340,27 @@ ifeq ($(ARCH),armv8-dotprod)
arm_version = 8
endif

ifeq ($(ARCH),armv9)
arch = armv9
prefetch = yes
popcnt = yes
neon = yes
dotprod = yes
sve = yes
sve2 = yes
arm_version = 9
endif

ifeq ($(ARCH),armv8-sve)
arch = armv8
prefetch = yes
popcnt = yes
neon = yes
dotprod = yes
sve = yes
arm_version = 8
endif

ifeq ($(ARCH),apple-silicon)
arch = arm64
prefetch = yes
Expand Down Expand Up @@ -400,7 +426,7 @@ ifeq ($(COMP),gcc)
CXX=g++
CXXFLAGS += -pedantic -Wextra -Wshadow -Wmissing-declarations

ifeq ($(arch),$(filter $(arch),armv7 armv8 riscv64))
ifeq ($(arch),$(filter $(arch),armv7 armv8 armv9 riscv64))
ifeq ($(OS),Android)
CXXFLAGS += -m$(bits)
LDFLAGS += -m$(bits)
Expand Down Expand Up @@ -472,7 +498,7 @@ ifeq ($(COMP),clang)
endif
endif

ifeq ($(arch),$(filter $(arch),armv7 armv8 riscv64))
ifeq ($(arch),$(filter $(arch),armv7 armv8 armv9 riscv64))
ifeq ($(OS),Android)
CXXFLAGS += -m$(bits)
LDFLAGS += -m$(bits)
Expand Down Expand Up @@ -513,7 +539,7 @@ ifeq ($(COMP),ndk)
STRIP=llvm-strip
endif
endif
ifeq ($(arch),armv8)
ifeq ($(arch),$(filter $(arch),armv8 armv9))
CXX=aarch64-linux-android21-clang++
ifneq ($(shell which aarch64-linux-android-strip 2>/dev/null),)
STRIP=aarch64-linux-android-strip
Expand Down Expand Up @@ -634,7 +660,7 @@ else
endif

ifeq ($(popcnt),yes)
ifeq ($(arch),$(filter $(arch),ppc64 armv7 armv8 arm64))
ifeq ($(arch),$(filter $(arch),ppc64 armv7 armv8 armv9 arm64))
CXXFLAGS += -DUSE_POPCNT
else
CXXFLAGS += -msse3 -mpopcnt -DUSE_POPCNT
Expand Down Expand Up @@ -708,15 +734,26 @@ ifeq ($(neon),yes)
CXXFLAGS += -DUSE_NEON=$(arm_version)
ifeq ($(KERNEL),Linux)
ifneq ($(COMP),ndk)
ifneq ($(arch),armv8)
ifeq ($(arch),armv7)
CXXFLAGS += -mfpu=neon
endif
endif
endif
endif

ifeq ($(dotprod),yes)
CXXFLAGS += -march=armv8.2-a+dotprod -DUSE_NEON_DOTPROD
ifeq ($(sve),yes)
SVE_VL := $(shell $(SHELL) ../scripts/sve-vl.sh)
CXXFLAGS += -msve-vector-bits=$(SVE_VL)

ifeq ($(sve2),yes)
CXXFLAGS += -march=armv9-a -DUSE_NEON_DOTPROD -DUSE_SVE=2000
else
CXXFLAGS += -march=armv8.2-a+dotprod+sve -DUSE_NEON_DOTPROD -DUSE_SVE=1000
endif
else
CXXFLAGS += -march=armv8.2-a+dotprod -DUSE_NEON_DOTPROD
endif
endif

### 3.7 pext
Expand Down Expand Up @@ -829,6 +866,8 @@ help:
@echo "armv7-neon > ARMv7 32-bit with popcnt and neon"
@echo "armv8 > ARMv8 64-bit with popcnt and neon"
@echo "armv8-dotprod > ARMv8 64-bit with popcnt, neon and dot product support"
@echo "armv8-sve > ARMv8 64-bit with popcnt, neon, dot product, and sve support"
@echo "armv9 > ARMv9 64-bit"
@echo "e2k > Elbrus 2000"
@echo "apple-silicon > Apple silicon ARM64"
@echo "general-64 > unspecified 64-bit"
Expand Down Expand Up @@ -1009,6 +1048,8 @@ config-sanity: net
@echo "vnni512: '$(vnni512)'"
@echo "neon: '$(neon)'"
@echo "dotprod: '$(dotprod)'"
@echo "sve: '$(sve)'"
@echo "sve2: '$(sve2)'"
@echo "arm_version: '$(arm_version)'"
@echo "target_windows: '$(target_windows)'"
@echo ""
Expand All @@ -1024,7 +1065,8 @@ config-sanity: net
@test "$(SUPPORTED_ARCH)" = "true"
@test "$(arch)" = "any" || test "$(arch)" = "x86_64" || test "$(arch)" = "i386" || \
test "$(arch)" = "ppc64" || test "$(arch)" = "ppc" || test "$(arch)" = "e2k" || \
test "$(arch)" = "armv7" || test "$(arch)" = "armv8" || test "$(arch)" = "arm64" || test "$(arch)" = "riscv64" || test "$(arch)" = "loongarch64"
test "$(arch)" = "armv7" || test "$(arch)" = "armv8" || test "$(arch)" = "armv9" || test "$(arch)" = "arm64" || \
test "$(arch)" = "riscv64" || test "$(arch)" = "loongarch64"
@test "$(bits)" = "32" || test "$(bits)" = "64"
@test "$(prefetch)" = "yes" || test "$(prefetch)" = "no"
@test "$(popcnt)" = "yes" || test "$(popcnt)" = "no"
Expand All @@ -1039,6 +1081,9 @@ config-sanity: net
@test "$(vnni256)" = "yes" || test "$(vnni256)" = "no"
@test "$(vnni512)" = "yes" || test "$(vnni512)" = "no"
@test "$(neon)" = "yes" || test "$(neon)" = "no"
@test "$(dotprod)" = "yes" || test "$(dotprod)" = "no"
@test "$(sve)" = "yes" && test ! -z "$(SVE_VL)" || test "$(sve)" = "no"
@test "$(sve2)" = "yes" || test "$(sve2)" = "no"
@test "$(comp)" = "gcc" || test "$(comp)" = "icx" || test "$(comp)" = "mingw" || test "$(comp)" = "clang" \
|| test "$(comp)" = "armv7a-linux-androideabi16-clang" || test "$(comp)" = "aarch64-linux-android21-clang"

Expand Down
15 changes: 15 additions & 0 deletions src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ std::string compiler_info() {

compiler += "\nCompilation settings : ";
compiler += (Is64Bit ? "64bit" : "32bit");

// x86/AMD64 family
#if defined(USE_VNNI)
compiler += " VNNI";
#endif
Expand All @@ -251,7 +253,20 @@ std::string compiler_info() {
#if defined(USE_SSE2)
compiler += " SSE2";
#endif

compiler += (HasPopCnt ? " POPCNT" : "");

// ARM/AArch64 family
#if defined(USE_SVE)
#if USE_SVE >= 2001
compiler += " SVE2.1";
#elif USE_SVE >= 2000
compiler += " SVE2";
#else
compiler += " SVE";
#endif
#endif

#if defined(USE_NEON_DOTPROD)
compiler += " NEON_DOTPROD";
#elif defined(USE_NEON)
Expand Down
20 changes: 10 additions & 10 deletions src/nnue/layers/affine_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,20 +199,20 @@ class AffineTransform {
using vec_t = __m512i;
#define vec_setzero _mm512_setzero_si512
#define vec_set_32 _mm512_set1_epi32
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
#define vec_hadd Simd::m512_hadd
#define vec_add_dpbusd_32 SIMD::m512_add_dpbusd_epi32
#define vec_hadd SIMD::m512_hadd
#elif defined(USE_AVX2)
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_hadd Simd::m256_hadd
#define vec_add_dpbusd_32 SIMD::m256_add_dpbusd_epi32
#define vec_hadd SIMD::m256_hadd
#elif defined(USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_hadd Simd::m128_hadd
#define vec_add_dpbusd_32 SIMD::m128_add_dpbusd_epi32
#define vec_hadd SIMD::m128_hadd
#endif

static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);
Expand Down Expand Up @@ -256,14 +256,14 @@ class AffineTransform {
using vec_t = __m256i;
#define vec_setzero _mm256_setzero_si256
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_hadd Simd::m256_hadd
#define vec_add_dpbusd_32 SIMD::m256_add_dpbusd_epi32
#define vec_hadd SIMD::m256_hadd
#elif defined(USE_SSSE3)
using vec_t = __m128i;
#define vec_setzero _mm_setzero_si128
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_hadd Simd::m128_hadd
#define vec_add_dpbusd_32 SIMD::m128_add_dpbusd_epi32
#define vec_hadd SIMD::m128_hadd
#endif

const auto inputVector = reinterpret_cast<const vec_t*>(input);
Expand Down
15 changes: 10 additions & 5 deletions src/nnue/layers/affine_transform_sparse_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,32 @@ class AffineTransformSparseInput {
using invec_t = __m512i;
using outvec_t = __m512i;
#define vec_set_32 _mm512_set1_epi32
#define vec_add_dpbusd_32 Simd::m512_add_dpbusd_epi32
#define vec_add_dpbusd_32 SIMD::m512_add_dpbusd_epi32
#elif defined(USE_AVX2)
using invec_t = __m256i;
using outvec_t = __m256i;
#define vec_set_32 _mm256_set1_epi32
#define vec_add_dpbusd_32 Simd::m256_add_dpbusd_epi32
#define vec_add_dpbusd_32 SIMD::m256_add_dpbusd_epi32
#elif defined(USE_SSSE3)
using invec_t = __m128i;
using outvec_t = __m128i;
#define vec_set_32 _mm_set1_epi32
#define vec_add_dpbusd_32 Simd::m128_add_dpbusd_epi32
#define vec_add_dpbusd_32 SIMD::m128_add_dpbusd_epi32
#elif defined(USE_SVE)
using invec_t = vec_s8_t;
using outvec_t = vec_s32_t;
#define vec_set_32(n) svreinterpret_s8(svdup_n_u32(n))
#define vec_add_dpbusd_32 SIMD::sve_add_dpbusd_s32
#elif defined(USE_NEON_DOTPROD)
using invec_t = int8x16_t;
using outvec_t = int32x4_t;
#define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
#define vec_add_dpbusd_32 Simd::dotprod_m128_add_dpbusd_epi32
#define vec_add_dpbusd_32 SIMD::dotprod_m128_add_dpbusd_epi32
#elif defined(USE_NEON)
using invec_t = int8x16_t;
using outvec_t = int32x4_t;
#define vec_set_32(a) vreinterpretq_s8_u32(vdupq_n_u32(a))
#define vec_add_dpbusd_32 Simd::neon_m128_add_dpbusd_epi32
#define vec_add_dpbusd_32 SIMD::neon_m128_add_dpbusd_epi32
#endif
static constexpr IndexType OutputSimdWidth = sizeof(outvec_t) / sizeof(OutputType);

Expand Down
53 changes: 42 additions & 11 deletions src/nnue/layers/clipped_relu.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,19 +135,50 @@ class ClippedReLU {
constexpr IndexType Start = NumChunks * SimdWidth;

#elif defined(USE_NEON)
constexpr IndexType NumChunks = InputDimensions / (SimdWidth / 2);
const int8x8_t Zero = {0};
const auto in = reinterpret_cast<const int32x4_t*>(input);
const auto out = reinterpret_cast<int8x8_t*>(output);
for (IndexType i = 0; i < NumChunks; ++i)
#if defined(USE_SVE) && USE_SVE >= 2000
// Check SVE vector size, and fall back to Neon if it's too big for
// this layer.
constexpr bool UseSVE = SVERegisterSize / 8 <= InputDimensions * sizeof(InputType);
constexpr size_t ChunkSize =
UseSVE ? SVERegisterSize / (8 * sizeof(InputType)) : SimdWidth / 2;
#else
constexpr size_t ChunkSize = SimdWidth / 2;
#endif

constexpr IndexType NumChunks = InputDimensions / ChunkSize;
static_assert(NumChunks > 0);

#if defined(USE_SVE) && USE_SVE >= 2000
if constexpr (UseSVE)
{
int16x8_t shifted;
const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits);
out[i] = vmax_s8(vqmovn_s16(shifted), Zero);
const auto in = reinterpret_cast<const vec_s32_t*>(input);
const auto out = reinterpret_cast<uint8_t*>(output);

for (IndexType i = 0; i < NumChunks; ++i)
{
vec_s16_t tmp16 = svqshrnb_n_s32(in[i], WeightScaleBits);
vec_s8_t tmp8 = svmax_n_s8_z(svptrue_b8(), svqxtnb_s16(tmp16), 0);
svst1b_u32(svptrue_b32(), &out[i * SVERegisterSize / 32],
svreinterpret_u32_s8(tmp8));
}
}
else
#endif
{
const int8x8_t Zero = {0};
const auto in = reinterpret_cast<const int32x4_t*>(input);
const auto out = reinterpret_cast<int8x8_t*>(output);

for (IndexType i = 0; i < NumChunks; ++i)
{
int16x8_t shifted;
const auto pack = reinterpret_cast<int16x4_t*>(&shifted);
pack[0] = vqshrn_n_s32(in[i * 2 + 0], WeightScaleBits);
pack[1] = vqshrn_n_s32(in[i * 2 + 1], WeightScaleBits);
out[i] = vmax_s8(vqmovn_s16(shifted), Zero);
}
}
constexpr IndexType Start = NumChunks * (SimdWidth / 2);
constexpr IndexType Start = NumChunks * ChunkSize;
#else
constexpr IndexType Start = 0;
#endif
Expand Down
Loading