From aabef6c0584f06f4c0f4b61fb787d80374240619 Mon Sep 17 00:00:00 2001 From: Jenkins Date: Thu, 18 Aug 2022 12:49:09 +0000 Subject: [PATCH] Compute Library v22.08 --- Android.bp | 77 +- CONTRIBUTING.md | 11 + README.md | 44 +- SConscript | 13 +- SConstruct | 134 +- arm_compute/AclOpenClExt.h | 4 +- arm_compute/core/CL/CLDevice.h | 28 +- arm_compute/core/CL/CLHelpers.h | 17 +- arm_compute/core/CL/CLTypes.h | 7 +- arm_compute/core/CL/OpenCL.h | 9 +- arm_compute/core/CPP/CPPTypes.h | 10 + arm_compute/core/CPP/ICPPKernel.h | 4 +- arm_compute/core/GPUTarget.h | 49 +- arm_compute/core/KernelDescriptors.h | 9 + arm_compute/core/QuantizationInfo.h | 23 +- arm_compute/core/Types.h | 296 +- arm_compute/core/Window.h | 1 + arm_compute/core/experimental/OperatorGraph.h | 33 +- arm_compute/core/utils/misc/MMappedFile.h | 4 +- arm_compute/core/utils/misc/ShapeCalculator.h | 48 +- arm_compute/runtime/CL/CLTypes.h | 6 +- arm_compute/runtime/FunctionDescriptors.h | 8 +- .../NEON/functions/NEFullyConnectedLayer.h | 42 +- arm_compute/runtime/NEON/functions/NEGEMM.h | 11 +- .../NEON/functions/NEGEMMConvolutionLayer.h | 61 +- arm_compute/runtime/NEON/functions/NEGather.h | 13 +- .../functions/NEWinogradConvolutionLayer.h | 3 +- docs/03_scripts.dox | 178 - docs/ComputeLibrary.dir | 8 - docs/Doxyfile | 2 +- .../contribution_guidelines.dox | 4 +- docs/user_guide/errata.dox | 2 +- .../how_to_build_and_run_examples.dox | 263 +- docs/user_guide/introduction.dox | 13 +- .../release_version_and_change_log.dox | 35 +- examples/SConscript | 16 +- .../cl_fused_conv2d_elementwise_add.cpp | 19 +- .../cl_ref_conv2d_elementwise_add.cpp | 37 +- filedefs.json | 5 +- filelist.json | 113 +- include/libnpy/npy.hpp | 4 +- scripts/arm_compute_library_nn_driver.go | 25 + scripts/caffe_data_extractor.py | 45 - scripts/tensorflow_data_extractor.py | 51 - src/core/CL/CLCompileContext.cpp | 14 +- src/core/CL/CLHelpers.cpp | 21 +- src/core/CL/OpenCL.cpp | 40 +- .../common/gemm_reshaped_only_rhs_mmul.cl | 528 ++ src/core/CL/cl_kernels/common/gemmlowp.cl | 22 +- .../common/gemmlowp_reshaped_only_rhs_mmul.cl | 309 + src/core/CL/cl_kernels/helpers.h | 1 + .../CL/cl_kernels/nhwc/direct_convolution.cl | 16 +- .../cl_kernels/nhwc/direct_convolution3d.cl | 4 +- .../CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl | 4 +- .../nhwc/dwc_native_quantized_nhwc.cl | 4 +- .../nhwc/winograd_output_transform.cl | 70 +- src/core/CL/cl_kernels/tile_helpers.h | 95 +- ...LDepthwiseConvolutionLayerNativeKernel.cpp | 25 +- src/core/CPP/CPPTypes.cpp | 10 + src/core/GPUTarget.cpp | 45 +- src/core/NEON/kernels/NEGatherKernel.cpp | 116 +- src/core/NEON/kernels/NEGatherKernel.h | 26 +- .../depthwise/depthwise_implementation.hpp | 9 +- .../arm_conv/pooling/pooling_depthfirst.hpp | 3 +- .../pooling/pooling_depthfirst_generic.hpp | 2 + src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp | 56 +- src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 47 +- src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 89 +- .../kernels/arm_gemm/gemm_hybrid_indirect.hpp | 142 +- .../kernels/arm_gemm/gemm_implementation.hpp | 104 +- src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 2 +- src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 2 +- .../kernels/arm_gemm/gemm_interleaved.hpp | 147 +- src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp | 2 +- .../NEON/kernels/arm_gemm/gemm_quint8.cpp | 2 +- .../NEON/kernels/arm_gemm/gemm_uint16.cpp | 2 +- src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 2 +- .../kernels/arm_gemm/kernel_weight_format.hpp | 60 + .../a64_ffhybrid_bf16fp32_mmla_6x16.hpp | 109 + .../generic.cpp | 3807 
++++++++++++ .../kernels/a64_ffhybrid_fp16_mla_6x32.hpp | 108 + .../a64_ffhybrid_fp16_mla_6x32/generic.cpp | 5429 +++++++++++++++++ .../kernels/a64_ffhybrid_fp32_mla_6x16.hpp | 108 + .../a64_ffhybrid_fp32_mla_6x16/generic.cpp | 3461 +++++++++++ .../a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp | 109 + .../generic.cpp | 2561 ++++++++ .../a64_ffinterleaved_bf16fp32_dot_8x12.hpp | 101 + .../generic.cpp | 269 + .../a64_ffinterleaved_bf16fp32_mmla_8x12.hpp | 109 + .../generic.cpp | 314 + .../a64_ffinterleaved_fp16_mla_8x24.hpp | 100 + .../generic.cpp | 264 + .../a64_ffinterleaved_fp32_mla_8x12.hpp | 100 + .../generic.cpp | 332 + .../a64_interleaved_bf16fp32_mmla_8x12.hpp | 6 +- .../sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp | 109 + .../generic.cpp | 2227 +++++++ .../kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp | 116 + .../sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp | 1530 +++++ .../sve_ffhybrid_fp16_mla_6x4VL/generic.cpp | 3318 ++++++++++ .../kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp | 116 + .../sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp | 1530 +++++ .../sve_ffhybrid_fp32_mla_6x4VL/generic.cpp | 2310 +++++++ .../sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp | 109 + .../generic.cpp | 1464 +++++ .../sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp | 109 + .../generic.cpp | 319 + .../sve_ffinterleaved_fp16_mla_8x3VL.hpp | 108 + .../a64fx.cpp | 297 + .../generic.cpp | 269 + .../sve_ffinterleaved_fp32_mla_8x3VL.hpp | 108 + .../a64fx.cpp | 297 + .../generic.cpp | 273 + .../sve_interleaved_bf16fp32_mmla_8x3VL.hpp | 6 +- src/core/NEON/kernels/arm_gemm/misc.cpp | 44 +- .../NEON/kernels/arm_gemm/transform-sve.cpp | 3 + src/core/NEON/kernels/arm_gemm/utils.hpp | 2 +- src/core/NEON/kernels/assembly/winograd.hpp | 234 + .../kernels/convolution/common/padding.cpp | 2 +- .../kernels/convolution/common/padding.hpp | 28 +- .../convolution/winograd/input_transform.hpp | 384 ++ .../a64_fp16_6x6.cpp} | 33 +- .../a64_fp32_6x6.cpp} | 244 +- .../arm_fp32_1x8.cpp} | 33 +- .../arm_fp32_4x4.cpp} | 36 +- .../input_transforms/arm_fp32_6x6.cpp | 202 + .../input_transforms/sve_fp32_6x6.cpp | 361 ++ .../winograd/input_transforms_fp16.cpp | 56 + .../winograd/input_transforms_fp32.cpp | 71 + .../convolution/winograd/output_transform.hpp | 302 + .../a64_fp16_4x4_3x3.cpp} | 31 +- .../arm_fp32_1x2_1x7.cpp} | 69 +- .../arm_fp32_1x4_1x5.cpp} | 69 +- .../arm_fp32_1x6_1x3.cpp} | 72 +- .../arm_fp32_2x2_3x3.cpp} | 99 +- .../arm_fp32_2x2_5x5.cpp} | 99 +- .../arm_fp32_4x4_3x3.cpp} | 98 +- .../winograd/output_transforms_fp16.cpp | 55 + .../winograd/output_transforms_fp32.cpp | 68 + .../kernels/convolution/winograd/padding.cpp | 29 +- .../convolution/winograd/weight_transform.hpp | 145 + .../a64_fp16_4x4_3x3.cpp} | 161 +- .../weight_transforms/arm_fp32_2x2_3x3.cpp | 200 + .../weight_transforms/arm_fp32_2x2_5x5.cpp | 381 ++ .../weight_transforms/arm_fp32_4x4_3x3.cpp | 236 + .../weight_transforms/cpp_fp32_1x2_1x7.cpp | 71 + .../weight_transforms/cpp_fp32_1x4_1x5.cpp | 77 + .../weight_transforms/cpp_fp32_1x6_1x3.cpp | 71 + .../winograd/weight_transforms_fp16.cpp | 54 + .../winograd/weight_transforms_fp32.cpp | 74 + .../kernels/convolution/winograd/winograd.cpp | 182 - .../kernels/convolution/winograd/winograd.hpp | 621 -- .../convolution/winograd/winograd_fp16.cpp | 45 + .../convolution/winograd/winograd_fp32.cpp | 41 + .../winograd/winograd_implementations.hpp | 341 ++ .../convolution/winograd/winograd_layer.hpp | 207 - .../winograd/winograd_transforms/input.hpp | 268 - .../input_4x4_fp16_fp16_integers.cpp | 257 - .../winograd/winograd_transforms/kernel.hpp | 78 - 
.../winograd/winograd_transforms/output.hpp | 252 - .../weights_2_7_fp32_fp32_integers.cpp | 90 - .../weights_2x2_3x3_fp32_fp32_integers.cpp | 220 - .../weights_2x2_5x5_fp32_fp32_integers.cpp | 401 -- .../weights_4_5_fp32_fp32_integers.cpp | 90 - .../weights_4x4_3x3_fp32_fp32_integers.cpp | 257 - .../weights_6_3_fp32_fp32_integers.cpp | 90 - src/core/NEON/wrapper/intrinsics/cvt.h | 6 +- src/core/NEON/wrapper/intrinsics/svdup_n.h | 6 +- src/core/NEON/wrapper/svtraits.h | 5 +- src/core/common/Registrars.h | 6 +- .../dynamic_fusion/ClKernelBuildingAPI.cpp | 20 +- .../dynamic_fusion/ClKernelBuildingAPI.h | 10 +- .../ClKernelBuildingImpl/Common.h | 8 +- .../ClDirectConvolutionKernelComponent.cpp | 2 +- ...t.cpp => ClElementwiseKernelComponent.cpp} | 107 +- .../components/ClElementwiseKernelComponent.h | 90 + .../components/ClFloorKernelComponent.cpp | 153 + ...elComponent.h => ClFloorKernelComponent.h} | 32 +- .../components/ClKernelComponents.h | 3 +- .../components/ClStoreKernelComponents.cpp | 4 + .../dynamic_fusion/OperatorGraph.cpp | 33 +- .../WorkloadImpl/ClKernelDescriptors.h | 18 +- .../WorkloadImpl/ClKernelGraph.cpp | 63 +- .../WorkloadImpl/ClKernelGraph.h | 30 +- .../WorkloadImpl/OperatorGraphImpl.cpp | 55 +- .../WorkloadImpl/OperatorGraphImpl.h | 42 +- src/core/utils/AssemblyUtils.cpp | 242 +- src/core/utils/AssemblyUtils.h | 18 +- src/cpu/kernels/CpuActivationKernel.cpp | 40 +- src/cpu/kernels/CpuActivationKernel.h | 6 +- src/cpu/kernels/CpuAddKernel.cpp | 123 +- src/cpu/kernels/CpuAddKernel.h | 12 +- src/cpu/kernels/CpuIm2ColKernel.cpp | 10 +- src/cpu/kernels/CpuKernelSelectionTypes.h | 19 +- src/cpu/kernels/CpuWinogradConv2dKernel.cpp | 568 +- src/cpu/kernels/CpuWinogradConv2dKernel.h | 533 +- .../activation/generic/neon/qasymm8.cpp | 443 +- .../activation/generic/sve2/qasymm8.cpp | 18 +- src/cpu/kernels/activation/list.h | 3 +- src/cpu/kernels/add/generic/neon/fp16.cpp | 7 +- src/cpu/kernels/add/generic/neon/fp32.cpp | 7 +- src/cpu/kernels/add/generic/neon/impl.cpp | 40 +- src/cpu/kernels/add/generic/neon/impl.h | 5 +- src/cpu/kernels/add/generic/neon/integer.cpp | 17 +- src/cpu/kernels/add/list.h | 7 +- src/cpu/kernels/assembly/arm_gemm.hpp | 58 +- .../kernels/cast/generic/neon/bfloat16.cpp | 4 +- .../elementwise_binary/generic/sve/fp16.cpp | 4 +- .../elementwise_binary/generic/sve/fp32.cpp | 4 +- .../elementwise_binary/generic/sve/impl.cpp | 297 +- .../elementwise_binary/generic/sve/impl.h | 8 +- .../generic/sve/integer.cpp | 10 +- .../elementwise_binary/generic/sve2/impl.h | 303 +- .../generic/sve2/qasymm8.cpp | 4 +- .../generic/sve2/qasymm8_signed.cpp | 4 +- src/cpu/kernels/softmax/generic/sve/impl.cpp | 21 +- src/cpu/kernels/softmax/generic/sve2/impl.cpp | 20 +- src/cpu/operators/CpuAdd.cpp | 16 +- src/cpu/operators/CpuAdd.h | 5 +- src/cpu/operators/CpuFullyConnected.cpp | 43 +- src/cpu/operators/CpuFullyConnected.h | 55 +- src/cpu/operators/CpuGemm.cpp | 27 +- src/cpu/operators/CpuGemm.h | 20 +- src/cpu/operators/CpuGemmConv2d.cpp | 122 +- src/cpu/operators/CpuGemmConv2d.h | 28 +- src/cpu/operators/CpuGemmDirectConv2d.cpp | 29 +- src/cpu/operators/CpuWinogradConv2d.cpp | 913 +-- src/cpu/operators/CpuWinogradConv2d.h | 54 +- .../internal/CpuGemmAssemblyDispatch.cpp | 128 +- .../internal/CpuGemmAssemblyDispatch.h | 58 +- src/gpu/cl/ClKernelLibrary.cpp | 11 + src/gpu/cl/kernels/ClCastKernel.cpp | 2 +- src/gpu/cl/kernels/ClCastKernel.h | 3 +- src/gpu/cl/kernels/ClDirectConv2dKernel.cpp | 138 +- src/gpu/cl/kernels/ClDirectConv2dKernel.h | 11 +- 
src/gpu/cl/kernels/ClDirectConv3dKernel.cpp | 4 +- ...atrixMultiplyReshapedOnlyRhsMMULKernel.cpp | 480 ++ ...pMatrixMultiplyReshapedOnlyRhsMMULKernel.h | 93 + ...atrixMultiplyReshapedOnlyRhsMMULKernel.cpp | 365 ++ ...mMatrixMultiplyReshapedOnlyRhsMMULKernel.h | 89 + .../ClWinogradOutputTransformKernel.cpp | 37 +- .../ClDirectConvDefaultConfigBifrost.cpp | 192 + .../ClDirectConvDefaultConfigBifrost.h | 55 + .../ClDirectConvDefaultConfigValhall.cpp | 358 ++ .../ClDirectConvDefaultConfigValhall.h | 55 + .../direct_conv/ClDirectConvKernelConfig.h | 64 + .../direct_conv/IClDirectConvKernelConfig.h | 115 + src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp | 19 +- src/gpu/cl/kernels/gemm/ClGemmHelpers.h | 17 +- ...emmDefaultConfigReshapedRhsOnlyValhall.cpp | 42 +- ...lGemmDefaultConfigReshapedRhsOnlyValhall.h | 4 +- src/gpu/cl/operators/ClConv2d.cpp | 28 +- src/gpu/cl/operators/ClDirectConv2d.cpp | 32 +- src/gpu/cl/operators/ClGemm.cpp | 134 + src/gpu/cl/operators/ClGemm.h | 29 +- .../ClGemmLowpMatrixMultiplyCore.cpp | 129 +- .../operators/ClGemmLowpMatrixMultiplyCore.h | 38 +- src/gpu/cl/operators/ClWinogradConv2d.cpp | 3 +- src/runtime/CL/CLScheduler.cpp | 5 +- .../functions/CLDepthwiseConvolutionLayer.cpp | 3 +- src/runtime/CL/functions/CLGEMM.cpp | 3 +- .../CL/gemm/CLGEMMDefaultTypeValhall.cpp | 61 +- .../CL/gemm/CLGEMMDefaultTypeValhall.h | 4 +- src/runtime/IScheduler.cpp | 4 +- .../NEON/functions/NEFullyConnectedLayer.cpp | 13 +- src/runtime/NEON/functions/NEGEMM.cpp | 9 +- .../NEON/functions/NEGEMMConvolutionLayer.cpp | 10 +- .../functions/NEWinogradConvolutionLayer.cpp | 3 +- src/runtime/OMP/OMPScheduler.cpp | 13 +- support/Bfloat16.h | 8 +- support/ToolchainSupport.h | 6 +- tests/AssetsLibrary.h | 24 +- tests/SConscript | 19 +- tests/datasets/GatherDataset.h | 15 +- tests/framework/Asserts.h | 9 +- tests/framework/SConscript | 7 +- tests/main.cpp | 3 +- tests/validate_examples/cl_gemm.cpp | 38 +- ...MLowpMatrixMultiplyReshapedOnlyRhsMMUL.cpp | 206 + .../GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp | 231 + .../ArbitraryElementwiseFusion.cpp | 394 ++ .../UNIT/dynamic_fusion/ClCompositeKernel.cpp | 9 +- .../CL/UNIT/dynamic_fusion/Floor.cpp | 135 + ...tegration_OperatorFuseMovenetSubGraph1.cpp | 30 +- tests/validation/NEON/ActivationLayer.cpp | 2 +- tests/validation/NEON/ArithmeticAddition.cpp | 31 +- tests/validation/NEON/ConvolutionLayer.cpp | 366 +- tests/validation/NEON/DepthConvertLayer.cpp | 6 +- tests/validation/NEON/FillBorder.cpp | 12 +- tests/validation/NEON/Gather.cpp | 12 +- .../validation/NEON/UNIT/TensorAllocator.cpp | 4 +- tests/validation/UNIT/GPUTarget.cpp | 16 +- tests/validation/fixtures/ArgMinMaxFixture.h | 6 +- .../ConvertFullyConnectedWeightsFixture.h | 4 +- .../fixtures/ConvolutionLayerFixture.h | 304 +- .../fixtures/DeconvolutionLayerFixture.h | 8 +- .../fixtures/DepthConvertLayerFixture.h | 4 +- .../DepthwiseConvolutionLayerFixture.h | 6 +- .../fixtures/DirectConvolutionLayerFixture.h | 6 +- .../fixtures/FullyConnectedLayerFixture.h | 10 +- tests/validation/fixtures/GEMMFixture.h | 197 +- tests/validation/fixtures/GEMMLowpFixture.h | 379 +- .../fixtures/ReductionOperationFixture.h | 6 +- tests/validation/fixtures/ScaleFixture.h | 6 +- tests/validation/reference/Gather.cpp | 57 +- utils/TypePrinter.h | 133 +- utils/command_line/CommandLineParser.h | 3 +- 307 files changed, 46312 insertions(+), 7463 deletions(-) create mode 100644 CONTRIBUTING.md delete mode 100644 docs/03_scripts.dox delete mode 100755 scripts/caffe_data_extractor.py delete mode 100755 
scripts/tensorflow_data_extractor.py create mode 100644 src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl create mode 100644 src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl create mode 100644 src/core/NEON/kernels/arm_gemm/kernel_weight_format.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp create mode 100644 src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp create mode 100644 
src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp create mode 100644 src/core/NEON/kernels/assembly/winograd.hpp create mode 100644 src/core/NEON/kernels/convolution/winograd/input_transform.hpp rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/input_6x6_fp16_fp16_integers.cpp => input_transforms/a64_fp16_6x6.cpp} (95%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/input_6x6_fp32_fp32_integers.cpp => input_transforms/a64_fp32_6x6.cpp} (84%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/input_1x8_fp32_fp32_integers.cpp => input_transforms/arm_fp32_1x8.cpp} (91%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/input_4x4_fp32_fp32_integers.cpp => input_transforms/arm_fp32_4x4.cpp} (92%) create mode 100644 src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/output_transform.hpp rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp => output_transforms/a64_fp16_4x4_3x3.cpp} (94%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_2_7_fp32_fp32_integers.cpp => output_transforms/arm_fp32_1x2_1x7.cpp} (70%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_4_5_fp32_fp32_integers.cpp => output_transforms/arm_fp32_1x4_1x5.cpp} (75%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_6_3_fp32_fp32_integers.cpp => output_transforms/arm_fp32_1x6_1x3.cpp} (77%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp => output_transforms/arm_fp32_2x2_3x3.cpp} (70%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp => output_transforms/arm_fp32_2x2_5x5.cpp} (73%) rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp => output_transforms/arm_fp32_4x4_3x3.cpp} (78%) create mode 100644 src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transform.hpp rename src/core/NEON/kernels/convolution/winograd/{winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp => weight_transforms/a64_fp16_4x4_3x3.cpp} (67%) create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp create mode 100644 
src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd.hpp create mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp create mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_implementations.hpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/kernel.hpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp delete mode 100644 src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp rename src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/{ClElementwiseAddKernelComponent.cpp => ClElementwiseKernelComponent.cpp} (60%) create mode 100644 src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h create mode 100644 src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp rename src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/{ClElementwiseAddKernelComponent.h => ClFloorKernelComponent.h} (74%) create mode 100644 src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp create mode 100644 src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h create mode 100644 src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp create mode 100644 src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h create mode 100644 src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp create mode 100644 src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h create mode 100644 src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp create mode 100644 src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h create mode 100644 src/gpu/cl/kernels/direct_conv/ClDirectConvKernelConfig.h create mode 100644 src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h create mode 100644 tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL.cpp create mode 100644 tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp create mode 100644 tests/validation/CL/UNIT/dynamic_fusion/ArbitraryElementwiseFusion.cpp create mode 100644 tests/validation/CL/UNIT/dynamic_fusion/Floor.cpp diff --git a/Android.bp b/Android.bp index d1efc0a632..4a6ba4f3ab 100644 --- a/Android.bp +++ b/Android.bp @@ -40,8 +40,10 @@ opencl_srcs = [ 
"src/core/CL/cl_kernels/common/floor.cl", "src/core/CL/cl_kernels/common/gather.cl", "src/core/CL/cl_kernels/common/gemm.cl", + "src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl", "src/core/CL/cl_kernels/common/gemm_utils.cl", "src/core/CL/cl_kernels/common/gemmlowp.cl", + "src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl", "src/core/CL/cl_kernels/common/gemv.cl", "src/core/CL/cl_kernels/common/generate_proposals.cl", "src/core/CL/cl_kernels/common/generate_proposals_quantized.cl", @@ -158,7 +160,6 @@ arm_compute_library_defaults { "-DARM_COMPUTE_ENABLE_NEON", "-Wno-unused-parameter", "-DNO_DOT_IN_TOOLCHAIN", - "-no-integrated-as", "-Wno-implicit-fallthrough" ], rtti: true, @@ -339,27 +340,30 @@ cc_library_static { "src/core/NEON/kernels/convolution/common/qasymm8.cpp", "src/core/NEON/kernels/convolution/common/qsymm8.cpp", "src/core/NEON/kernels/convolution/common/utils.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_1x8.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_4x4.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x2_1x7.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x4_1x5.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x6_1x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_5x5.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp", "src/core/NEON/kernels/convolution/winograd/padding.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_1x8_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp", - 
"src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp", "src/core/Rounding.cpp", "src/core/Size2D.cpp", "src/core/Size3D.cpp", @@ -370,7 +374,8 @@ cc_library_static { "src/core/Version.cpp", "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp", "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp", - "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp", + "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.cpp", + "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp", "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp", "src/core/experimental/dynamic_fusion/OperatorGraph.cpp", "src/core/experimental/dynamic_fusion/WorkloadImpl/ClFusedKernelGraph.cpp", @@ -607,6 +612,7 @@ cc_library_static { "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.cpp", + "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpOffsetContributionKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpOffsetContributionOutputStageKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFixedPointKernel.cpp", @@ -616,6 +622,7 @@ cc_library_static { "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp", + "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.cpp", "src/gpu/cl/kernels/ClHeightConcatenateKernel.cpp", @@ -636,6 +643,8 @@ cc_library_static { "src/gpu/cl/kernels/ClWinogradFilterTransformKernel.cpp", "src/gpu/cl/kernels/ClWinogradInputTransformKernel.cpp", "src/gpu/cl/kernels/ClWinogradOutputTransformKernel.cpp", + "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp", + "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp", 
"src/gpu/cl/kernels/experimental/dynamic_fusion/ClCompositeKernel.cpp", "src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp", "src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp", @@ -1057,6 +1066,14 @@ cc_library_static { "src/core/NEON/kernels/arm_conv/pooling/kernels/sve_u8_nhwc_max_generic_depthfirst/generic.cpp", "src/core/NEON/kernels/arm_conv/pooling/kernels/sve_u8q_nhwc_avg_generic_depthfirst/generic.cpp", "src/core/NEON/kernels/arm_conv/pooling/kernels/sve_u8q_nhwc_max_generic_depthfirst/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_8x12/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_8x12/a55r1.cpp", @@ -1122,6 +1139,17 @@ cc_library_static { "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_6x4/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/a55.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_bf16fp32_dot_6x4VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_bf16fp32_mmla_6x4VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/sve_hybrid_fp16_mla_6x4VL/a64fx.cpp", @@ -1160,6 +1188,11 @@ cc_library_static { "src/core/NEON/kernels/arm_gemm/kernels/sve_smallK_hybrid_fp32_mla_8x1VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/sve_smallK_hybrid_s8s32_dot_8x1VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/sve_smallK_hybrid_u8u32_dot_8x1VL/generic.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp16_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp32_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp", + 
"src/core/NEON/kernels/convolution/winograd/output_transforms/a64_fp16_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/a64_fp16_4x4_3x3.cpp", ], }, diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..171d101bd1 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,11 @@ +Please read https://arm-software.github.io/ComputeLibrary/latest/contribution_guidelines.xhtml + +Here on github we only publish a snapshot of the main development branch for each release, that's the reason why we don't accept pull requests. + +Please submit your patch for review to review.mlplatform.org. + +The development is structured in the following way: + + Release repository: https://github.com/arm-software/ComputeLibrary + Development repository: https://review.mlplatform.org/#/admin/projects/ml/ComputeLibrary + Please report issues here: https://github.com/ARM-software/ComputeLibrary/issues diff --git a/README.md b/README.md index 22602a5737..e925c2feb4 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,19 @@ > **⚠ Important** -> From this release (22.05): 'master' branch has been replaced with 'main' following our inclusive language update, more information [here](https://arm-software.github.io/ComputeLibrary/latest/contribution_guidelines.xhtml#S5_0_inc_lang). +> From release 22.05: 'master' branch has been replaced with 'main' following our inclusive language update, more information [here](https://arm-software.github.io/ComputeLibrary/latest/contribution_guidelines.xhtml#S5_0_inc_lang). + +> **⚠ Important** +> From release 22.08: armv7a with Android build will no longer be tested or maintained.


-# Compute Library ![](https://img.shields.io/badge/latest_release-22.05-green) +# Compute Library ![](https://img.shields.io/badge/latest_release-22.08-green) -The Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A and Arm® Mali™ GPUs architectures.
+The Compute Library is a collection of low-level machine learning functions optimized for Arm® Cortex®-A, Arm® Neoverse® and Arm® Mali™ GPU architectures.
The library provides superior performance to other open source alternatives and immediate support for new Arm® technologies, e.g. SVE2. @@ -35,7 +38,7 @@ Key Features:
## Documentation -[![Documentation](https://img.shields.io/badge/documentation-22.05-green)](https://arm-software.github.io/ComputeLibrary/latest) +[![Documentation](https://img.shields.io/badge/documentation-22.08-green)](https://arm-software.github.io/ComputeLibrary/latest) > Note: The documentation includes the reference API, changelogs, build guide, contribution guide, errata, etc. @@ -48,23 +51,23 @@ All the binaries can be downloaded from [here](https://github.com/ARM-software/C | Platform | Operating System | Release archive (Download) | | ----------- | ----------- | ----------- | -| Raspberry Pi 4 | Linux 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-armv7a-neon.tar.gz) | -| Raspberry Pi 4 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon.tar.gz) | -| Odroid N2 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon-cl.tar.gz) | -| HiKey960 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon-cl.tar.gz) | +| Raspberry Pi 4 | Linux 32bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-armv7a-neon.tar.gz) | +| Raspberry Pi 4 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon.tar.gz) | +| Odroid N2 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon-cl.tar.gz) | +| HiKey960 | Linux 64bit | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-cl.tar.gz) 
[![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon-cl.tar.gz) |
| Architecture | Operating System | Release archive (Download) | | ----------- | ----------- | ----------- | -| armv7 | Android | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-armv7a-neon-cl.tar.gz) | -| armv7 | Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-armv7a-neon-cl.tar.gz) | -| arm64-v8a | Android | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8a-neon-cl.tar.gz) | -| arm64-v8a | Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8a-neon-cl.tar.gz) | -| arm64-v8.2-a | Android | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-android-arm64-v8.2-a-neon-cl.tar.gz) | -| arm64-v8.2-a | Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.05/arm_compute-v22.05-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) | +| armv7 | 
Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-armv7a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-armv7a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-armv7a-neon-cl.tar.gz) | +| arm64-v8a | Android | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8a-neon-cl.tar.gz) | +| arm64-v8a | Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8a-neon-cl.tar.gz) | +| arm64-v8.2-a | Android | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-android-arm64-v8.2-a-neon-cl.tar.gz) | +| arm64-v8.2-a | Linux | [![](https://img.shields.io/badge/build-neon-orange)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8.2-a-neon.tar.gz) [![](https://img.shields.io/badge/build-opencl-blue)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8.2-a-cl.tar.gz) [![](https://img.shields.io/badge/build-neon+cl-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/download/v22.08/arm_compute-v22.08-bin-linux-arm64-v8.2-a-neon-cl.tar.gz) |
+Please refer to the following link for more pre-built binaries: [![](https://img.shields.io/badge/v22.08-bins-yellowgreen)](https://github.com/ARM-software/ComputeLibrary/releases/tag/v22.08) Pre-built binaries are generated with the following security / good coding practice related flags: > -Wall, -Wextra, -Wformat=2, -Winit-self, -Wstrict-overflow=2, -Wswitch-default, -Woverloaded-virtual, -Wformat-security, -Wctor-dtor-privacy, -Wsign-promo, -Weffc++, -pedantic, -fstack-protector-strong @@ -73,6 +76,7 @@ Pre-build binaries are generated with the following security / good coding pract - Arm® CPUs: - Arm® Cortex®-A processor family using Arm® Neon™ technology + - Arm® Neoverse® processor family - Arm® Cortex®-R processor family with Armv8-R AArch64 architecture using Arm® Neon™ technology - Arm® Cortex®-X1 processor using Arm® Neon™ technology @@ -89,6 +93,7 @@ Pre-build binaries are generated with the following security / good coding pract - Android™ - Bare Metal - Linux® +- OpenBSD® - macOS® - Tizen™ @@ -126,6 +131,17 @@ https://lists.linaro.org/mailman3/lists/acl-dev.lists.linaro.org/ The software is provided under MIT license. Contributions to this project are accepted under the same license. +### Other Projects +This project contains code from other projects as listed below. The original license text is included in those source files. + +* The OpenCL header library is licensed under Apache License, Version 2.0, which is a permissive license compatible with MIT license. + +* The half library is licensed under MIT license. + +* The libnpy library is licensed under MIT license. + +* The stb image library is either licensed under MIT license or is in the public domain. It is used by this project under the terms of MIT license. +
## Trademarks and Copyrights diff --git a/SConscript b/SConscript index e1b1b2c5df..818a2bf9cb 100644 --- a/SConscript +++ b/SConscript @@ -31,8 +31,8 @@ import zlib import json import codecs -VERSION = "v22.05" -LIBRARY_VERSION_MAJOR = 27 +VERSION = "v22.08" +LIBRARY_VERSION_MAJOR = 28 LIBRARY_VERSION_MINOR = 0 LIBRARY_VERSION_PATCH = 0 SONAME_VERSION = str(LIBRARY_VERSION_MAJOR) + "." + str(LIBRARY_VERSION_MINOR) + "." + str(LIBRARY_VERSION_PATCH) @@ -369,12 +369,14 @@ if env['opencl'] and env['embed_kernels']: 'src/core/CL/cl_kernels/common/floor.cl', 'src/core/CL/cl_kernels/common/gather.cl', 'src/core/CL/cl_kernels/common/gemm.cl', + 'src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl', 'src/core/CL/cl_kernels/common/gemm_utils.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl', 'src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl', 'src/core/CL/cl_kernels/common/gemv.cl', 'src/core/CL/cl_kernels/common/gemmlowp.cl', + 'src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl', 'src/core/CL/cl_kernels/common/generate_proposals.cl', 'src/core/CL/cl_kernels/common/generate_proposals_quantized.cl', 'src/core/CL/cl_kernels/common/instance_normalization.cl', @@ -500,6 +502,10 @@ if env['experimental_dynamic_fusion']: lib_files += filelist['experimental']['dynamic_fusion'] arm_compute_env.Append(CPPDEFINES = ['ENABLE_EXPERIMENTAL_DYNAMIC_FUSION']) +# Fixed format GEMM kernels. +if env['experimental_fixed_format_kernels']: + arm_compute_env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS']) + # Logging files if env["logging"]: @@ -576,6 +582,9 @@ if env['neon']: else: attrs = get_attrs_list(env, env['data_type_support'], env['data_layout_support']) + if env['experimental_fixed_format_kernels']: + attrs.append("experimental_fixed_format_kernels") + # Setup data-type and data-layout files to include cpu_operators = custom_operators if use_custom_ops else filelist['cpu']['operators'].keys() cpu_ops_to_build = resolve_operator_dependencies(filelist, cpu_operators, 'cpu') diff --git a/SConstruct b/SConstruct index 13da02628a..9e046161ac 100644 --- a/SConstruct +++ b/SConstruct @@ -89,12 +89,12 @@ vars.AddVariables( BoolVariable("debug", "Debug", False), BoolVariable("asserts", "Enable asserts (this flag is forced to 1 for debug=1)", False), BoolVariable("logging", "Enable Logging", False), - EnumVariable("arch", "Target Architecture", "armv7a", + EnumVariable("arch", "Target Architecture. The x86_32 and x86_64 targets can only be used with neon=0 and opencl=1.", "armv7a", allowed_values=("armv7a", "armv7a-hf", "arm64-v8a", "arm64-v8.2-a", "arm64-v8.2-a-sve", "arm64-v8.2-a-sve2", "x86_32", "x86_64", "armv8a", "armv8.2-a", "armv8.2-a-sve", "armv8.6-a", "armv8.6-a-sve", "armv8.6-a-sve2", "armv8r64", "x86")), EnumVariable("estate", "Execution State", "auto", allowed_values=("auto", "32", "64")), - EnumVariable("os", "Target OS", "linux", allowed_values=("linux", "android", "tizen", "macos", "bare_metal", "openbsd","windows")), - EnumVariable("build", "Build type", "cross_compile", allowed_values=("native", "cross_compile", "embed_only")), + EnumVariable("os", "Target OS. 
With bare metal selected, only Arm® Neon™ (not OpenCL) can be used, static libraries get built and Neon™'s multi-threading support is disabled.", "linux", allowed_values=("linux", "android", "tizen", "macos", "bare_metal", "openbsd","windows")), + EnumVariable("build", "Either build directly on your device (native) or cross compile from your desktop machine (cross-compile). In both cases make sure the compiler is available in your path.", "cross_compile", allowed_values=("native", "cross_compile", "embed_only")), BoolVariable("examples", "Build example programs", True), BoolVariable("gemm_tuner", "Build gemm_tuner programs", True), BoolVariable("Werror", "Enable/disable the -Werror compilation flag", True), @@ -102,23 +102,36 @@ vars.AddVariables( BoolVariable("standalone", "Builds the tests as standalone executables, links statically with libgcc, libstdc++ and libarm_compute", False), BoolVariable("opencl", "Enable OpenCL support", True), BoolVariable("neon", "Enable Arm® Neon™ support", False), - BoolVariable("embed_kernels", "Embed OpenCL kernels and OpenGL ES compute shaders in library binary", True), - BoolVariable("compress_kernels", "Compress embedded OpenCL kernels in library binary. Note embed_kernels should be enabled", False), - BoolVariable("set_soname", "Set the library's soname and shlibversion (requires SCons 2.4 or above)", False), - BoolVariable("openmp", "Enable OpenMP backend", False), + BoolVariable("embed_kernels", "Enable if you want the OpenCL kernels to be built in the library's binaries instead of being read from separate '.cl' / '.cs' files. If embed_kernels is set to 0 then the application can set the path to the folder containing the OpenCL kernel files by calling CLKernelLibrary::init(). By default the path is set to './cl_kernels'.", True), + BoolVariable("compress_kernels", "Compress embedded OpenCL kernels in library binary using zlib. Useful for reducing the binary size. embed_kernels should be enabled", False), + BoolVariable("set_soname", "If enabled the library will contain a SONAME and SHLIBVERSION and some symlinks will automatically be created between the objects. (requires SCons 2.4 or above)", False), + BoolVariable("openmp", "Enable OpenMP backend. Only works when building with g++ and not clang++", False), BoolVariable("cppthreads", "Enable C++11 threads backend", True), PathVariable("build_dir", "Specify sub-folder for the build", ".", PathVariable.PathAccept), PathVariable("install_dir", "Specify sub-folder for the install", "", PathVariable.PathAccept), BoolVariable("exceptions", "Enable/disable C++ exception support", True), BoolVariable("high_priority", "Generate a library containing only the high priority operators", False), PathVariable("linker_script", "Use an external linker script", "", PathVariable.PathAccept), - PathVariable("external_tests_dir", "Add examples, benchmarks and tests to the tests suite", "", PathVariable.PathAccept), + PathVariable("external_tests_dir", """Add examples, benchmarks and tests to the tests suite from an external path. 
In order to use this option, the external tests directory must have the following structure:
+    EXTERNAL_TESTS_DIR:
+    └── tests
+        ├── benchmark
+        │   ├── CL
+        │   ├── datasets
+        │   ├── fixtures
+        │   └── Neon
+        └── validation
+            ├── CL
+            ├── datasets
+            ├── fixtures
+            └── Neon\n""", "", PathVariable.PathAccept),
     BoolVariable("experimental_dynamic_fusion", "Build the experimental dynamic fusion files", False),
+    BoolVariable("experimental_fixed_format_kernels", "Enable fixed format kernels for GEMM", False),
     ListVariable("custom_options", "Custom options that can be used to turn on/off features", "none", ["disable_mmla_fp"]),
     ListVariable("data_type_support", "Enable a list of data types to support", "all", ["qasymm8", "qasymm8_signed", "qsymm16", "fp16", "fp32", "integer"]),
     ListVariable("data_layout_support", "Enable a list of data layout to support", "all", ["nhwc", "nchw"]),
-    ("toolchain_prefix", "Override the toolchain prefix; used by all toolchain components: compilers, linker, assembler etc.", ""),
-    ("compiler_prefix", "Override the compiler prefix; used by just compilers (CC,CXX); further overrides toolchain_prefix for compilers; if left empty, SCons only uses toolchain_prefix; this is for when the compiler prefixes are different from that of the linkers, archivers etc.", ""),
+    ("toolchain_prefix", "Override the toolchain prefix; used by all toolchain components: compilers, linker, assembler etc. If unspecified, the default (auto) prefixes are used; if an empty string '' is passed, prefixes are disabled", "auto"),
+    ("compiler_prefix", "Override the compiler prefix; used by just compilers (CC,CXX); further overrides toolchain_prefix for compilers; this is for when the compiler prefixes are different from those of the linkers, archivers etc. If unspecified, this is the same as toolchain_prefix; if an empty string '' is passed, prefixes are disabled", "auto"),
     ("extra_cxx_flags", "Extra CXX flags to be appended to the build command", ""),
     ("extra_link_flags", "Extra LD flags to be appended to the build command", ""),
     ("compiler_cache", "Command to prefix to the C and C++ compiler (e.g ccache)", ""),
@@ -167,6 +180,9 @@ Export('install_bin')

 Help(vars.GenerateHelpText(env))

+if 'armv7a' in env['arch'] and env['os'] == 'android':
+    print("WARNING: armv7a on Android is no longer maintained")
+
 if env['linker_script'] and env['os'] != 'bare_metal':
     print("Linker script is only supported for bare_metal builds")
     Exit(1)
@@ -201,14 +217,14 @@ if not env['exceptions']:
     env.Append(CPPDEFINES = ['ARM_COMPUTE_EXCEPTIONS_DISABLED'])
     env.Append(CXXFLAGS = ['-fno-exceptions'])

-env.Append(CXXFLAGS = ['-Wall','-DARCH_ARM',
+env.Append(CXXFLAGS = ['-DARCH_ARM',
          '-Wextra','-Wdisabled-optimization','-Wformat=2',
          '-Winit-self','-Wstrict-overflow=2','-Wswitch-default',
          '-Woverloaded-virtual', '-Wformat-security',
          '-Wctor-dtor-privacy','-Wsign-promo','-Weffc++','-Wno-overlength-strings'])

 if not 'windows' in env['os']:
-    env.Append(CXXFLAGS = ['-std=c++14', '-pedantic' ])
+    env.Append(CXXFLAGS = ['-Wall','-std=c++14', '-pedantic' ])

 env.Append(CPPDEFINES = ['_GLIBCXX_USE_NANOSLEEP'])
@@ -290,6 +306,17 @@ else: # NONE "multi_isa" builds
             env.Append(CXXFLAGS = ['-mfloat-abi=softfp'])
         else:
             env.Append(CXXFLAGS = ['-mfloat-abi=hard'])
+    elif 'v8.6-a' in env['arch']:
+        if 'armv8.6-a-sve2' == env['arch']:
+            env.Append(CXXFLAGS = ['-march=armv8.6-a+sve2'])
+        elif 'armv8.6-a-sve' == env['arch']:
+            env.Append(CXXFLAGS = ['-march=armv8.6-a+sve'])
+        elif 'armv8.6-a' == env['arch']:
+            env.Append(CXXFLAGS = ['-march=armv8.6-a'])
+
+        env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_I8MM', 'ARM_COMPUTE_ENABLE_BF16','ARM_COMPUTE_ENABLE_FP16'])
+        if "disable_mmla_fp" not in env['custom_options']:
+            env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_SVEF32MM'])
     elif 'v8' in env['arch']:
         # Preserve the V8 archs for non-multi-ISA variants
         if 'sve2' in env['arch']:
@@ -303,10 +330,6 @@ else: # NONE "multi_isa" builds
         else:
             env.Append(CXXFLAGS = ['-march=armv8-a'])

-    if 'v8.6-a' in env['arch']:
-        env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_I8MM', 'ARM_COMPUTE_ENABLE_BF16'])
-        if "disable_mmla_fp" not in env['custom_options']:
-            env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_SVEF32MM'])
     if 'v8.' in env['arch']:
         env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FP16'])

@@ -322,47 +345,57 @@ else: # NONE "multi_isa" builds
# Define toolchain
# The reason why we distinguish toolchain_prefix from compiler_prefix is for cases where the linkers/archivers use a
# different prefix than the compilers.
An example is the NDK r20 toolchain -toolchain_prefix = "" +auto_toolchain_prefix = "" if 'x86' not in env['arch']: if env['estate'] == '32': if env['os'] == 'linux': - toolchain_prefix = "arm-linux-gnueabihf-" if 'v7' in env['arch'] else "armv8l-linux-gnueabihf-" + auto_toolchain_prefix = "arm-linux-gnueabihf-" if 'v7' in env['arch'] else "armv8l-linux-gnueabihf-" elif env['os'] == 'bare_metal': - toolchain_prefix = "arm-eabi-" + auto_toolchain_prefix = "arm-eabi-" elif env['os'] == 'android': - toolchain_prefix = "arm-linux-androideabi-" + auto_toolchain_prefix = "arm-linux-androideabi-" elif env['os'] == 'tizen': - toolchain_prefix = "armv7l-tizen-linux-gnueabi-" + auto_toolchain_prefix = "armv7l-tizen-linux-gnueabi-" elif env['estate'] == '64' and 'v8' in env['arch']: if env['os'] == 'linux': - toolchain_prefix = "aarch64-linux-gnu-" + auto_toolchain_prefix = "aarch64-linux-gnu-" elif env['os'] == 'bare_metal': - toolchain_prefix = "aarch64-elf-" + auto_toolchain_prefix = "aarch64-elf-" elif env['os'] == 'android': - toolchain_prefix = "aarch64-linux-android-" + auto_toolchain_prefix = "aarch64-linux-android-" elif env['os'] == 'tizen': - toolchain_prefix = "aarch64-tizen-linux-gnu-" + auto_toolchain_prefix = "aarch64-tizen-linux-gnu-" -if env['build'] == 'native': +if env['build'] == 'native' or env["toolchain_prefix"] == "": toolchain_prefix = "" - -if env["toolchain_prefix"] != "": +elif env["toolchain_prefix"] == "auto": + toolchain_prefix = auto_toolchain_prefix +else: toolchain_prefix = env["toolchain_prefix"] -compiler_prefix = toolchain_prefix -if env["compiler_prefix"] != "": +if env['build'] == 'native' or env["compiler_prefix"] == "": + compiler_prefix = "" +elif env["compiler_prefix"] == "auto": + compiler_prefix = toolchain_prefix +else: compiler_prefix = env["compiler_prefix"] env['CC'] = env['compiler_cache']+ " " + compiler_prefix + c_compiler env['CXX'] = env['compiler_cache']+ " " + compiler_prefix + cpp_compiler env['LD'] = toolchain_prefix + "ld" env['AS'] = toolchain_prefix + "as" + if env['os'] == 'windows': - env['AR'] = "LIB" + env['AR'] = "llvm-lib" + env['RANLIB'] = "llvm-ranlib" else: env['AR'] = toolchain_prefix + "ar" + env['RANLIB'] = toolchain_prefix + "ranlib" +print("Using compilers:") +print("CC", env['CC']) +print("CXX", env['CXX']) if not GetOption("help"): try: @@ -393,12 +426,21 @@ if not GetOption("help"): if not version_at_least(compiler_ver, '7.0.0') and env['os'] == 'bare_metal': env.Append(LINKFLAGS = ['-fstack-protector-strong']) - # For NDK >= r21, clang 9 or above is used - if env['os'] == 'android' and version_at_least(compiler_ver, '9.0.0'): - env['ndk_above_r21'] = True + # Add Android NDK toolchain specific flags + if 'clang++' in cpp_compiler and env['os'] == 'android': + # For NDK >= r21, clang 9 or above is used + if version_at_least(compiler_ver, '9.0.0'): + env['ndk_above_r21'] = True + + if env['openmp']: + env.Append(LINKFLAGS = ['-static-openmp']) - if env['openmp']: - env.Append(LINKFLAGS = ['-static-openmp']) + # For NDK >= r23, clang 12 or above is used. This condition detects NDK < r23 + if not version_at_least(compiler_ver, '12.0.0'): + # System assembler is deprecated and integrated assembler is preferred since r23. + # However integrated assembler has always been suppressed for NDK < r23. 
+ # Thus for backward compatibility, we include this flag only for NDK < r23 + env.Append(CXXFLAGS = ['-no-integrated-as']) if env['high_priority'] and env['build_config']: print("The high priority library cannot be built in conjunction with a user-specified build configuration") @@ -423,10 +465,10 @@ else: env = update_data_type_layout_flags(env, data_types, data_layouts) if env['standalone']: - if not 'windows' in env['os']: + if not 'windows' in env['os']: env.Append(CXXFLAGS = ['-fPIC']) - env.Append(LINKFLAGS = ['-static-libgcc','-static-libstdc++']) - + env.Append(LINKFLAGS = ['-static-libgcc','-static-libstdc++']) + if env['Werror']: env.Append(CXXFLAGS = ['-Werror']) @@ -464,7 +506,7 @@ if env['opencl']: print("Cannot link OpenCL statically, which is required for bare metal / standalone builds") Exit(1) -if env["os"] not in ["android", "bare_metal"] and (env['opencl'] or env['cppthreads']): +if env["os"] not in ["windows","android", "bare_metal"] and (env['opencl'] or env['cppthreads']): env.Append(LIBS = ['pthread']) if env['os'] == 'openbsd': @@ -480,7 +522,12 @@ if env['opencl']: if env['debug']: env['asserts'] = True - env.Append(CXXFLAGS = ['-O0','-g','-gdwarf-2']) + if not 'windows' in env['os']: + env.Append(CXXFLAGS = ['-O0','-g','-gdwarf-2']) + else: + env.Append(CXXFLAGS = ['-Z7','-MTd','-fms-compatibility','-fdelayed-template-parsing']) + env.Append(LINKFLAGS = ['-DEBUG']) + env.Append(CPPDEFINES = ['ARM_COMPUTE_DEBUG_ENABLED']) else: if not 'windows' in env['os']: @@ -488,10 +535,11 @@ else: else: # on windows we use clang-cl which does not support the option -O3 env.Append(CXXFLAGS = ['-O2']) - + if env['asserts']: env.Append(CPPDEFINES = ['ARM_COMPUTE_ASSERTS_ENABLED']) - env.Append(CXXFLAGS = ['-fstack-protector-strong']) + if not 'windows' in env['os']: + env.Append(CXXFLAGS = ['-fstack-protector-strong']) if env['logging']: env.Append(CPPDEFINES = ['ARM_COMPUTE_LOGGING_ENABLED']) diff --git a/arm_compute/AclOpenClExt.h b/arm_compute/AclOpenClExt.h index c349f76d86..ef80fd2443 100644 --- a/arm_compute/AclOpenClExt.h +++ b/arm_compute/AclOpenClExt.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,7 +27,7 @@ #include "arm_compute/AclTypes.h" #ifndef CL_TARGET_OPENCL_VERSION -#define CL_TARGET_OPENCL_VERSION 200 +#define CL_TARGET_OPENCL_VERSION 300 #define CL_USE_DEPRECATED_OPENCL_1_1_APIS #define CL_USE_DEPRECATED_OPENCL_1_2_APIS #endif /* CL_TARGET_OPENCL_VERSION */ diff --git a/arm_compute/core/CL/CLDevice.h b/arm_compute/core/CL/CLDevice.h index 06aaac88f4..5e0f86e6d9 100644 --- a/arm_compute/core/CL/CLDevice.h +++ b/arm_compute/core/CL/CLDevice.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -143,6 +143,32 @@ class CLDevice : public IDevice return _options.extensions.count(extension) != 0; } + /** Returns whether non-uniform workgroup is supported and the build options. + * + * If the feature is supported, the appropriate build options will be + * appended to the specified string. + * + * @return A tuple (supported, build_options) indicating whether the feature + * is supported and the corresponding build options to enable it. 
+ */
+    std::tuple<bool, std::string> is_non_uniform_workgroup_supported() const
+    {
+        if(version() == CLVersion::CL30 && get_cl_non_uniform_work_group_supported(_device))
+        {
+            return {true, " -cl-std=CL3.0 "};
+        }
+        else if(version() == CLVersion::CL20)
+        {
+            return {true, " -cl-std=CL2.0 "};
+        }
+        else if(supported("cl_arm_non_uniform_work_group_size"))
+        {
+            return {true, " -cl-arm-non-uniform-work-group-size "};
+        }
+
+        return {false, ""};
+    }
+
private:
    cl::Device             _device;  /**< OpenCL device. */
    struct CLDeviceOptions _options; /**< OpenCL device options */

diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h
index 729eb13398..edbc705c6f 100644
--- a/arm_compute/core/CL/CLHelpers.h
+++ b/arm_compute/core/CL/CLHelpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -124,6 +124,14 @@ CLVersion get_cl_version(const cl::Device &device);
 */
size_t get_cl_image_pitch_alignment(const cl::Device &device);

+/** Helper function to check whether non-uniform work group is supported
+ *
+ * @param[in] device A CL device
+ *
+ * @return True if the feature is supported
+ */
+bool get_cl_non_uniform_work_group_supported(const cl::Device &device);
+
/** Helper function to check whether a given extension is supported
 *
 * @param[in] device A CL device
@@ -252,5 +260,12 @@ bool export_weights_to_cl_image(const ITensorInfo *tensor);
 */
void set_unroll_with_pragma(CLBuildOptions &built_opts, std::initializer_list values);

+/** Helper function to check whether the cl_arm_matrix_multiply extension is supported
+ *
+ * @param[in] device A CL device
+ *
+ * @return True if the extension is supported
+ */
+bool arm_matrix_multiply_supported(const cl::Device &device);
} // namespace arm_compute
#endif /* ARM_COMPUTE_CLHELPERS_H */
diff --git a/arm_compute/core/CL/CLTypes.h b/arm_compute/core/CL/CLTypes.h
index ede8d0a9e4..00b7cda2e1 100644
--- a/arm_compute/core/CL/CLTypes.h
+++ b/arm_compute/core/CL/CLTypes.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -41,7 +41,8 @@ enum class CLVersion
    CL10, /* the OpenCL 1.0 */
    CL11, /* the OpenCL 1.1 */
    CL12, /* the OpenCL 1.2 */
-   CL20, /* the OpenCL 2.0 and above */
+   CL20, /* the OpenCL 2.x */
+   CL30, /* the OpenCL 3.x */
    UNKNOWN /* unknown version */
};
@@ -81,7 +82,7 @@ enum CLKernelType
    UNKNOWN,     /**< Unknown CL kernel type */
    DEPTHWISE,   /**< Depthwise CL kernel type */
    DIRECT,      /**< Direct Convolution CL kernel type */
-   ELEMENTWISE, /**< Elementeise CL kernel type */
+   ELEMENTWISE, /**< Elementwise CL kernel type */
    GEMM,        /**< GEMM CL kernel type */
    POOL,        /**< Pool CL kernel type */
    WINOGRAD     /**< Winograd CL kernel type */
diff --git a/arm_compute/core/CL/OpenCL.h b/arm_compute/core/CL/OpenCL.h
index 4ff42c6b8a..058214bd07 100644
--- a/arm_compute/core/CL/OpenCL.h
+++ b/arm_compute/core/CL/OpenCL.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -31,7 +31,7 @@
#ifndef ARM_COMPUTE_NO_EXCEPTIONS
#define CL_HPP_ENABLE_EXCEPTIONS
#endif // ARM_COMPUTE_NO_EXCEPTIONS
-#define CL_TARGET_OPENCL_VERSION 200
+#define CL_TARGET_OPENCL_VERSION 300
#define CL_HPP_TARGET_OPENCL_VERSION 110
#define CL_HPP_MINIMUM_OPENCL_VERSION 110
#pragma GCC diagnostic push
@@ -75,11 +75,12 @@ class CLSymbols final
    static CLSymbols &get();
    /** Load symbols from the given OpenCL library path.
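     *
     * A usage sketch (the library path below is only an example, not a
     * default searched by the implementation):
     *
     *     CLSymbols::get().load("/vendor/lib64/libOpenCL.so", false);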
     *
-    * @param[in] library Path to the OpenCL library.
+    * @param[in] library    Path to the OpenCL library.
+    * @param[in] use_loader Use symbol loader function loadOpenCLPointer.
     *
     * @return True if loading the library is successful.
     */
-    bool load(const std::string &library);
+    bool load(const std::string &library, bool use_loader = false);

    /** Load symbols from any of the default OpenCL library names.
     *
     * @return True if loading any library is successful.
diff --git a/arm_compute/core/CPP/CPPTypes.h b/arm_compute/core/CPP/CPPTypes.h
index a021bdf5e4..afefb1aeb0 100644
--- a/arm_compute/core/CPP/CPPTypes.h
+++ b/arm_compute/core/CPP/CPPTypes.h
@@ -127,6 +127,16 @@ class CPUInfo final
     * @return true if the cpu supports sve2, false otherwise
     */
    bool has_sve2() const;
+    /** Checks if the cpu model supports sme.
+     *
+     * @return true if the cpu supports sme, false otherwise
+     */
+    bool has_sme() const;
+    /** Checks if the cpu model supports sme2.
+     *
+     * @return true if the cpu supports sme2, false otherwise
+     */
+    bool has_sme2() const;
    /** Gets the cpu model for a given cpuid.
     *
     * @param[in] cpuid the id of the cpu core to be retrieved,
diff --git a/arm_compute/core/CPP/ICPPKernel.h b/arm_compute/core/CPP/ICPPKernel.h
index 4697316379..00a10555e3 100644
--- a/arm_compute/core/CPP/ICPPKernel.h
+++ b/arm_compute/core/CPP/ICPPKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -95,7 +95,7 @@ class ICPPKernel : public IKernel
     * @param[in] platform     The CPU platform used to create the context.
     * @param[in] thread_count Number of threads in the execution.
     *
-    * @return[out] mws Minimum workload size for requsted configuration.
+    * @return Minimum workload size for requested configuration.
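+    *
+    * For example, a scheduler could use the returned hint to avoid creating
+    * more workloads than the kernel can usefully split into (illustrative
+    * sketch; cpu_info and total_work are hypothetical caller-side variables):
+    *
+    *     const size_t mws       = kernel->get_mws(cpu_info, thread_count);
+    *     const size_t n_threads = std::max<size_t>(1, std::min<size_t>(thread_count, total_work / mws));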
*/ virtual size_t get_mws(const CPUInfo &platform, size_t thread_count) const { diff --git a/arm_compute/core/GPUTarget.h b/arm_compute/core/GPUTarget.h index c4f5b8ca42..affa79a89e 100644 --- a/arm_compute/core/GPUTarget.h +++ b/arm_compute/core/GPUTarget.h @@ -33,26 +33,35 @@ namespace arm_compute /** Available GPU Targets */ enum class GPUTarget { - UNKNOWN = 0x101, - GPU_ARCH_MASK = 0xF00, - MIDGARD = 0x100, - BIFROST = 0x200, - VALHALL = 0x300, - T600 = 0x110, - T700 = 0x120, - T800 = 0x130, - G71 = 0x210, - G72 = 0x220, - G51 = 0x230, - G51BIG = 0x231, - G51LIT = 0x232, - G52 = 0x240, - G52LIT = 0x241, - G31 = 0x242, - G76 = 0x250, - G77 = 0x310, - G78 = 0x320, - G710 = 0x330, + UNKNOWN = 0x101, + GPU_ARCH_MASK = 0xF00, + GPU_GENERATION_MASK = 0x0F0, + MIDGARD = 0x100, + BIFROST = 0x200, + VALHALL = 0x300, + T600 = 0x110, + T700 = 0x120, + T800 = 0x130, + G71 = 0x210, + G72 = 0x220, + G51 = 0x221, + G51BIG = 0x222, + G51LIT = 0x223, + G31 = 0x224, + G76 = 0x230, + G52 = 0x231, + G52LIT = 0x232, + G77 = 0x310, + G57 = 0x311, + G78 = 0x320, + G68 = 0x321, + G78AE = 0x330, + G710 = 0x340, + G610 = 0x341, + G510 = 0x342, + G310 = 0x343, + G715 = 0x350, + G615 = 0x351, }; /** Enable bitwise operations on GPUTarget enumerations */ diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h index b1086494e4..c45be9c06f 100644 --- a/arm_compute/core/KernelDescriptors.h +++ b/arm_compute/core/KernelDescriptors.h @@ -109,6 +109,15 @@ struct DWCComputeKernelInfo bool export_weights_to_cl_image{ false }; /**< Export the weights to cl_image */ }; +/** Compute descriptor used by the direct convolution kernel */ +struct DirectConvComputeKernelInfo +{ + int32_t m0{ 1 }; /**< Number of rows to be processed by the kernel */ + int32_t n0{ 1 }; /**< Number of columns to be processed by the kernel */ + int32_t k0{ 1 }; /**< Number of partial accumulations to be processed in a single iteration by the kernel */ + bool export_weights_to_cl_image{ false }; /**< Flag to export the weights to cl_image */ +}; + /** Descriptor used by the softmax kernels */ struct SoftmaxKernelInfo { diff --git a/arm_compute/core/QuantizationInfo.h b/arm_compute/core/QuantizationInfo.h index b331f7d923..21d962d08b 100644 --- a/arm_compute/core/QuantizationInfo.h +++ b/arm_compute/core/QuantizationInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -399,6 +399,27 @@ inline float dequantize_qsymm8(int8_t value, const UniformQuantizationInfo &qinf return value * qinfo.scale; } +inline qasymm8_t qasymm8_hard_swish(qasymm8_t in, + const UniformQuantizationInfo &qi_in, + const UniformQuantizationInfo &qi_out) +{ + float tmp_f = dequantize_qasymm8(in, qi_in); + tmp_f = tmp_f * ((std::min(std::max((tmp_f + 3), 0.0f), 6.0f)) * 0.166666667f); + const qasymm8_t tmp = quantize_qasymm8(tmp_f, qi_out); + return tmp; +} + +inline qasymm8_t qasymm8_leaky_relu(qasymm8_t in, + const UniformQuantizationInfo &qi_in, + const UniformQuantizationInfo &qi_out, + float alpha) +{ + float tmp_f = dequantize_qasymm8(in, qi_in); + tmp_f = tmp_f > 0 ? 
tmp_f : tmp_f * alpha; + const qasymm8_t tmp = quantize_qasymm8(tmp_f, qi_out); + return tmp; +} + /** Dequantize a value given a 8-bit symmetric quantization scheme * * @param[in] value Value to dequantize diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 7ae6a7e67e..952c174194 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -774,10 +774,10 @@ class PadStrideInfo private: std::pair _stride; - unsigned int _pad_left; - unsigned int _pad_top; - unsigned int _pad_right; - unsigned int _pad_bottom; + unsigned int _pad_left; + unsigned int _pad_top; + unsigned int _pad_right; + unsigned int _pad_bottom; DimensionRoundingType _round_type; }; @@ -919,14 +919,14 @@ class PriorBoxLayerInfo final } private: - std::vector _min_sizes; - std::vector _variances; - float _offset; - bool _flip; - bool _clip; - std::vector _max_sizes; - std::vector _aspect_ratios; - Coordinates2D _img_size; + std::vector _min_sizes; + std::vector _variances; + float _offset; + bool _flip; + bool _clip; + std::vector _max_sizes; + std::vector _aspect_ratios; + Coordinates2D _img_size; std::array _steps; }; @@ -1171,15 +1171,15 @@ class DetectionPostProcessLayerInfo final } private: - unsigned int _max_detections; - unsigned int _max_classes_per_detection; - float _nms_score_threshold; - float _iou_threshold; - unsigned int _num_classes; + unsigned int _max_detections; + unsigned int _max_classes_per_detection; + float _nms_score_threshold; + float _iou_threshold; + unsigned int _num_classes; std::array _scales_values; - bool _use_regular_nms; - unsigned int _detection_per_class; - bool _dequantize_scores; + bool _use_regular_nms; + unsigned int _detection_per_class; + bool _dequantize_scores; }; /** Pooling Layer Information struct*/ @@ -1612,13 +1612,13 @@ class BoundingBoxTransformInfo final } private: - float _img_width; - float _img_height; - float _scale; - bool _apply_scale; - bool _correct_transform_coords; + float _img_width; + float _img_height; + float _scale; + bool _apply_scale; + bool _correct_transform_coords; std::array _weights; - float _bbox_xform_clip; + float _bbox_xform_clip; }; /** Activation Layer Information class */ @@ -1644,6 +1644,9 @@ class ActivationLayerInfo HARD_SWISH /**< Hard-swish ( \f$ f(x) = (x * relu6(x+3))/6 \f$ ) */ }; + /** Lookup table */ + using LookupTable256 = std::array; + ActivationLayerInfo() = default; /** Default Constructor * @@ -1677,11 +1680,62 @@ class ActivationLayerInfo return _enabled; } +#ifdef __aarch64__ + const LookupTable256 &lut() const + { + return _lut; + } + + void init_lut(const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out) + { + if(_act == ActivationFunction::HARD_SWISH) + { + qasymm8_hard_swish_populate_table(_lut, qi_in, qi_out); + } + else if(_act == ActivationFunction::LEAKY_RELU) + { + qasymm8_leaky_relu_populate_table(_lut, qi_in, qi_out, _a); + } + } +#endif // __aarch64__ + + static inline bool is_lut_supported(ActivationFunction act_func, DataType data_type) + { +#ifdef __aarch64__ + auto supported = (data_type == DataType::QASYMM8 && (act_func == ActivationFunction::HARD_SWISH || act_func == ActivationFunction::LEAKY_RELU)); + return supported; +#else // __aarch64__ + ARM_COMPUTE_UNUSED(act_func); + ARM_COMPUTE_UNUSED(data_type); + return false; +#endif // __aarch64__ + } + private: ActivationFunction _act = { ActivationLayerInfo::ActivationFunction::IDENTITY }; float _a = {}; float _b = {}; bool _enabled = { false }; + +#ifdef __aarch64__ + LookupTable256 _lut = {}; + + static 
inline void qasymm8_hard_swish_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out) + { + for(size_t i = 0; i < lut.size(); ++i) + { + lut[i] = qasymm8_hard_swish(i, qi_in, qi_out); + } + } + + static inline void qasymm8_leaky_relu_populate_table(LookupTable256 &lut, const UniformQuantizationInfo &qi_in, const UniformQuantizationInfo &qi_out, float alpha) + { + for(size_t i = 0; i < lut.size(); ++i) + { + lut[i] = qasymm8_leaky_relu(i, qi_in, qi_out, alpha); + } + } +#endif // __aarch64__ }; /** Fully connected layer info */ @@ -1840,13 +1894,121 @@ class StridedSliceLayerInfo int32_t _shrink_axis_mask; }; +/** Memory layouts for the weights tensor. + * + * * UNSPECIFIED is used to select kernels that do not run in + * variable weights mode. + * + * * ANY is used to query the kernel database to retrieve any of the + * kernels that runs in variable weights mode. Once a kernel is + * found, the specific format expected by the kernel can be + * retrieved by the user for reordering the weights tensor + * accordingly. + * + * The other values OHWIo{interleave_by}i{block_by} describe the + * memory layout of a 4D tensor with layout OHWI that has been + * transformed into a 4D tensor with dimensions O'HWI' where: + * + * O' = first multiple of {interleave_by} s.t. O<=O' + * I' = first multiple of {block_by} s.t. I<=I' + * + * The total size of the dst tensor is O' x H x W x I' + * + * The access function of the tensor with layout + * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter + * access function, where the 6 parameters are computed as follows: + * + * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by} + * + * x4 = h RANGE [0, H-1] SIZE: H + * x3 = w RANGE [0, W-1] SIZE: W + * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by} + * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by} + * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by} + * TOTAL SIZE: O' * H * W * I' + * + * 4D 6D + * ----------------- ----------------------------------- + * value(o, h, w, i) = x5 * H * W * I' * {interleave_by} + * + x4 * W * I' * {interleave_by} + * + x3 * I' * {interleave_by} + * + x2 * {interleave_by} * {block_by} + * + x1 * {block_by} + * + x0 + * + * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created + * for the OHWIo{interleave_by}i{block_by} format is in reality seen + * as a 2D tensor, where the number of rows is O'/{interleave_by} + * and the number of columns is {interleave_by} * H * W * I'. + * + * The postfix *_bf16 is for the memory layout needed for the + * fast-mode kernels, in which the weights are passed in bfloat16 + * format. 
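+ *
+ * Worked example (illustrative): with OHWIo4i2, i.e. {interleave_by} = 4
+ * and {block_by} = 2, a weights tensor with O = 10 and I = 3 is stored as
+ *
+ *     O' = 12 (first multiple of 4 with O <= O')
+ *     I' = 4  (first multiple of 2 with I <= I')
+ *
+ * so the dst tensor holds 12 x H x W x 4 elements, the extra rows and
+ * columns being padding.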
+ */ +enum class WeightFormat +{ + UNSPECIFIED = 0x1, + ANY = 0x2, + OHWI = 0x100100, + OHWIo2 = 0x100200, + OHWIo4 = 0x100400, + OHWIo8 = 0x100800, + OHWIo16 = 0x101000, + OHWIo32 = 0x102000, + OHWIo64 = 0x104000, + OHWIo128 = 0x108000, + OHWIo4i2 = 0x200400, + OHWIo4i2_bf16 = 0x200410, + OHWIo8i2 = 0x200800, + OHWIo8i2_bf16 = 0x200810, + OHWIo16i2 = 0x201000, + OHWIo16i2_bf16 = 0x201010, + OHWIo32i2 = 0x202000, + OHWIo32i2_bf16 = 0x202010, + OHWIo64i2 = 0x204000, + OHWIo64i2_bf16 = 0x204010, + OHWIo4i4 = 0x400400, + OHWIo4i4_bf16 = 0x400410, + OHWIo8i4 = 0x400800, + OHWIo8i4_bf16 = 0x400810, + OHWIo16i4 = 0x401000, + OHWIo16i4_bf16 = 0x401010, + OHWIo32i4 = 0x402000, + OHWIo32i4_bf16 = 0x402010, + OHWIo64i4 = 0x404000, + OHWIo64i4_bf16 = 0x404010, + OHWIo2i8 = 0x800200, + OHWIo4i8 = 0x800400, + OHWIo8i8 = 0x800800, + OHWIo16i8 = 0x801000, + OHWIo32i8 = 0x802000, + OHWIo64i8 = 0x804000 +}; +// OHWIoi +inline int interleave_by(const WeightFormat wf) +{ + return (static_cast(wf) >> 8) & 0xFFF; +} +inline int block_by(const WeightFormat wf) +{ + return (static_cast(wf) >> 20) & 0xF; +} +inline bool is_fixed_format(const WeightFormat &wf) +{ + return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY; +} +inline bool is_fixed_format_fast_math(const WeightFormat &wf) +{ + return (static_cast(wf) >> 4) & 0x1; +} + /** Convolution Layer Weights Information class. This class stores the necessary information to compute convolution layer when the weights are already reshaped */ class WeightsInfo { public: /** Default constructor */ WeightsInfo() - : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false) + : _are_reshaped(false), _kernel_width(0), _kernel_height(0), _num_kernels(0), _retain_internal_weights(false), _weight_format(arm_compute::WeightFormat::UNSPECIFIED) { } /** Constructor @@ -1856,9 +2018,11 @@ class WeightsInfo * @param[in] kernel_height Kernel height. * @param[in] num_kernels Number of convolution kernels. * @param[in] retain_internal_weights (Optional) True if internal reshaped weights must be retained. Used for reconfiguration purposes. Default is false. + * @param[in] weight_format (Optional) arm_gemm:WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED. */ - WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false) - : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights) + WeightsInfo(bool are_reshaped, unsigned int kernel_width, unsigned int kernel_height, unsigned int num_kernels, bool retain_internal_weights = false, + arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) + : _are_reshaped(are_reshaped), _kernel_width(kernel_width), _kernel_height(kernel_height), _num_kernels(num_kernels), _retain_internal_weights(retain_internal_weights), _weight_format(weight_format) { } /** Flag which specifies if the weights tensor has been reshaped. 
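Taken together with the encoding documented above, the three helpers can be checked by hand; a minimal sketch (illustrative only, using just the enum values and helpers introduced in this hunk):

    // OHWIo8i4_bf16 == 0x400810
    const arm_compute::WeightFormat wf = arm_compute::WeightFormat::OHWIo8i4_bf16;
    const int  interleave = arm_compute::interleave_by(wf);             // (0x400810 >> 8) & 0xFFF == 8
    const int  block      = arm_compute::block_by(wf);                  // (0x400810 >> 20) & 0xF == 4
    const bool fast_math  = arm_compute::is_fixed_format_fast_math(wf); // (0x400810 >> 4) & 0x1 == 1, i.e. true
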
@@ -1889,13 +2053,31 @@ class WeightsInfo
    {
        return _retain_internal_weights;
    }
+    arm_compute::WeightFormat weight_format() const
+    {
+        return _weight_format;
+    }
+    void set_weight_format(arm_compute::WeightFormat weight_format)
+    {
+        _weight_format = weight_format;
+    }
+
+    unsigned int kernel_width() const
+    {
+        return _kernel_width;
+    }
+    unsigned int kernel_height() const
+    {
+        return _kernel_height;
+    }

private:
-    bool         _are_reshaped;
-    unsigned int _kernel_width;
-    unsigned int _kernel_height;
-    unsigned int _num_kernels;
-    bool         _retain_internal_weights;
+    bool                      _are_reshaped;
+    unsigned int              _kernel_width;
+    unsigned int              _kernel_height;
+    unsigned int              _num_kernels;
+    bool                      _retain_internal_weights;
+    arm_compute::WeightFormat _weight_format;
};

/** GEMM reshape information class. This class stores the necessary information about matrix A and matrix B reshape.
@@ -2105,7 +2287,9 @@ class GEMMInfo
        _pretranspose_A(false),
        _pretranspose_B(false),
        _activation_info(),
-        _post_ops()
+        _post_ops(),
+        _fixed_format(false),
+        _weight_format(arm_compute::WeightFormat::UNSPECIFIED)
    {
    }
    /** Constructor
@@ -2124,10 +2308,13 @@
     * @param[in] broadcast_bias          (Optional) Broadcast the shape of the bias tensor from a vector to a matrix.
     * @param[in] activation_info         (Optional) Activation to apply after the matrix multiplication
     * @param[in] post_ops                (Optional) A sequence of post operations that are performed after the main operation.
+     * @param[in] fixed_format           (Optional) Specify the selection of fixed format kernels for variable weights support in GEMM. These kernels expect the weights tensor to be in a memory format that is fixed by the kernel itself. For more information, see arm_compute::WeightFormat.
+     * @param[in] weight_format          (Optional) arm_gemm::WeightFormat enumeration requested by the user. Default is arm_compute::WeightFormat::UNSPECIFIED.
     */
    GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false,
             GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool fast_math = false, bool broadcast_bias = false,
-             const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList &post_ops = experimental::PostOpList()) noexcept
+             const ActivationLayerInfo &activation_info = ActivationLayerInfo(), const experimental::PostOpList &post_ops = experimental::PostOpList(),
+             bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED) noexcept
        : _is_a_reshaped(is_a_reshaped),
          _is_b_reshaped(is_b_reshaped),
          _reshape_b_only_on_first_run(reshape_b_only_on_first_run),
@@ -2141,7 +2328,9 @@ class GEMMInfo
          _pretranspose_A(false),
          _pretranspose_B(false),
          _activation_info(activation_info),
-          _post_ops(post_ops)
+          _post_ops(post_ops),
+          _fixed_format(fixed_format),
+          _weight_format(weight_format)
    {
    }
    /** Flag which specifies if the matrix A has been reshaped
@@ -2306,6 +2495,37 @@ class GEMMInfo
    {
        _post_ops = post_ops;
    }
+    /** Flag which specifies if the GEMM operation is running fixed-format kernels.
+     *
+     * @return True if the GEMM operation is running a fixed-format kernel, false otherwise.
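+     *
+     * A configuration sketch (illustrative only) using the setters declared
+     * below:
+     *
+     *     GEMMInfo gemm_info{};
+     *     gemm_info.set_fixed_format(true);
+     *     gemm_info.set_weight_format(arm_compute::WeightFormat::OHWIo4);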
+ */ + bool fixed_format() const + { + return _fixed_format; + } + + /** Set fixed-format flag + * + * @param[in] fixed_format sets whether or not to use fixed-format kernels + */ + void set_fixed_format(bool fixed_format) + { + _fixed_format = fixed_format; + } + + arm_compute::WeightFormat weight_format() const + { + return _weight_format; + } + + /** Set weight format to be used + * + * @param[in] weight_format arm_compute::WeightFormat enumeration + */ + void set_weight_format(arm_compute::WeightFormat weight_format) + { + _weight_format = weight_format; + } private: bool _is_a_reshaped; @@ -2322,6 +2542,8 @@ class GEMMInfo bool _pretranspose_B; ActivationLayerInfo _activation_info; experimental::PostOpList _post_ops; + bool _fixed_format; + arm_compute::WeightFormat _weight_format; }; /** Winograd information */ diff --git a/arm_compute/core/Window.h b/arm_compute/core/Window.h index c566cffa88..440b942dcf 100644 --- a/arm_compute/core/Window.h +++ b/arm_compute/core/Window.h @@ -90,6 +90,7 @@ class Window : _start(start), _end(end), _step(step) { } + Dimension(const Dimension &d) = default; /** Default assignment operator to allow dimensions to be copied */ Dimension &operator=(const Dimension &d) = default; /** Return the start of the dimension */ diff --git a/arm_compute/core/experimental/OperatorGraph.h b/arm_compute/core/experimental/OperatorGraph.h index fd8fcd5c47..cab83c7f8b 100644 --- a/arm_compute/core/experimental/OperatorGraph.h +++ b/arm_compute/core/experimental/OperatorGraph.h @@ -176,17 +176,18 @@ Operator add_op_conv2d(OperatorGraph &graph, const Conv2dDescriptor &desc, OpTen */ void force_conv2d_method(OperatorGraph &graph, Operator conv2d, ConvolutionMethod method); -/** Descriptor for Addition operation +/** Descriptor for Elementwise binary operation * */ -struct AddDescriptor +struct ElementwiseDescriptor { /* TOSA compliant attribute parameters start */ /* TOSA compliant attribute parameters end */ /* Non-TOSA compliant attribute parameters start */ + ArithmeticOperation op; /* Non-TOSA compliant attribute parameters end */ }; -/** Add op Add to @p graph, and optionally describes fusion through passing of intermediate @ref OpTensor s +/** Add op Elementwise to @p graph, and optionally describes fusion through passing of intermediate @ref OpTensor s * * @param[in,out] graph OperatorGraph where the operator is added to * @param[in] desc Operator descriptor @@ -196,12 +197,34 @@ struct AddDescriptor * * @return Operator */ -Operator add_op_elementwise_add(OperatorGraph &graph, const AddDescriptor &desc, OpTensor lhs, OpTensor rhs, OpTensor dst); +Operator add_op_elementwise_op(OperatorGraph &graph, const ElementwiseDescriptor &desc, OpTensor lhs, OpTensor rhs, OpTensor dst); + +/** Descriptor for Floor operation + * + */ +struct FloorDescriptor +{ + /* TOSA compliant attribute parameters start */ + /* TOSA compliant attribute parameters end */ + /* Non-TOSA compliant attribute parameters start */ + /* Non-TOSA compliant attribute parameters end */ +}; +/** Add op Floor to @p graph, and optionally describes fusion through passing of intermediate @ref OpTensor s + * + * @param[in,out] graph OperatorGraph where the operator is added to + * @param[in] desc Operator descriptor + * @param[in] src Source OpTensor + * @param[in] dst Destination OpTensor + * + * @return Operator + */ +Operator add_op_floor(OperatorGraph &graph, const FloorDescriptor &desc, OpTensor src, OpTensor dst); bool operator==(const OpTensor &t0, const OpTensor &t1); bool operator==(const Padding2D 
&pad0, const Padding2D &pad1);
bool operator==(const Conv2dDescriptor &conv2d0, const Conv2dDescriptor &conv2d1);
-bool operator==(const AddDescriptor &, const AddDescriptor &);
+bool operator==(const ElementwiseDescriptor &, const ElementwiseDescriptor &);
+bool operator==(const FloorDescriptor &, const FloorDescriptor &);
} // namespace dynamic_fusion
} // namespace experimental
diff --git a/arm_compute/core/utils/misc/MMappedFile.h b/arm_compute/core/utils/misc/MMappedFile.h
index b3e0994b5b..3efdbc5bda 100644
--- a/arm_compute/core/utils/misc/MMappedFile.h
+++ b/arm_compute/core/utils/misc/MMappedFile.h
@@ -24,7 +24,7 @@
#ifndef ARM_COMPUTE_MISC_MMAPPED_FILE_H
#define ARM_COMPUTE_MISC_MMAPPED_FILE_H

-#if !defined(BARE_METAL)
+#if !defined(_WIN64) && !defined(BARE_METAL)

#include
#include
@@ -105,6 +105,6 @@ class MMappedFile
} // namespace mmap_io
} // namespace utils
} // namespace arm_compute
-#endif // !defined(BARE_METAL)
+#endif // !defined(_WIN64) && !defined(BARE_METAL)
#endif /* ARM_COMPUTE_MISC_MMAPPED_FILE_H */
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index df907c106e..9f9f53ed8b 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1494,15 +1494,53 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
    return output_shape;
}

+/** Calculate the gather output shape of a tensor
+ *
+ * @param[in] input_shape   Input tensor shape
+ * @param[in] indices_shape Indices tensor shape. 1d, 2d and 3d indices are supported
+ * @param[in] actual_axis   Axis to be used in the computation
+ *
+ * @note Let input_shape be (X,Y,Z), indices_shape be (W,O,P) and axis be 1.
+ * The new shape is computed by replacing the axis in the input shape with
+ * the indices shape, so the output shape will be (X,W,O,P,Z)
+ *
+ * @return the calculated shape
+ */
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
-    ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 1);
    ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
    ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
-
-    TensorShape output_shape  = input_shape;
-    output_shape[actual_axis] = indices_shape[0];
-
+    ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
+    TensorShape output_shape = input_shape;
+    if(indices_shape.num_dimensions() == 1u)
+    {
+        output_shape[actual_axis] = indices_shape[0];
+    }
+    else
+    {
+        const auto ind_num_dims
+        {
+            indices_shape.num_dimensions()
+        };
+        output_shape.shift_right(ind_num_dims - 1);
+        switch(actual_axis)
+        {
+            case 1:
+            {
+                output_shape[0] = input_shape[0];
+                for(size_t idx = 0; idx < ind_num_dims; ++idx)
+                {
+                    output_shape.set(actual_axis + idx, indices_shape[idx], false);
+                }
+                break;
+            }
+            default:
+            {
+                // 2d and 3d indices are only supported for axis == 1
+                ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
+            }
+        }
+    }
    return output_shape;
}
} // namespace shape_calculator
diff --git a/arm_compute/runtime/CL/CLTypes.h b/arm_compute/runtime/CL/CLTypes.h
index bba25c6d64..d298ecd614 100644
--- a/arm_compute/runtime/CL/CLTypes.h
+++ b/arm_compute/runtime/CL/CLTypes.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -35,7 +35,9 @@ enum class CLGEMMKernelType
    /** Reshaped GEMM kernel where both lhs and rhs matrices are reshaped. Configurable reshape and block size */
    RESHAPED,
    /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Configurable reshape and block size */
-    RESHAPED_ONLY_RHS
+    RESHAPED_ONLY_RHS,
+    /** Reshaped GEMM kernel where only the rhs matrix is reshaped. Using MMUL with configurable block size. */
+    RESHAPED_ONLY_RHS_MMUL
};

/** OpenCL GEMM kernel selection parameters. This information is retrieved to select the GEMM kernel on OpenCL */
diff --git a/arm_compute/runtime/FunctionDescriptors.h b/arm_compute/runtime/FunctionDescriptors.h
index face8a6fb4..af79820bc3 100644
--- a/arm_compute/runtime/FunctionDescriptors.h
+++ b/arm_compute/runtime/FunctionDescriptors.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -62,8 +62,9 @@ struct Conv2dInfo
              const ActivationLayerInfo &act_info,
              bool                       enable_fast_math,
              unsigned int               num_groups,
-              const experimental::PostOpList &post_ops = experimental::PostOpList {})
-        : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops)
+              const experimental::PostOpList &post_ops = experimental::PostOpList {},
+              const WeightsInfo &weights_info = WeightsInfo())
+        : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), post_ops(post_ops), weights_info(weights_info)
    {
    }
@@ -73,6 +74,7 @@ struct Conv2dInfo
    bool                     enable_fast_math{ false };
    unsigned int             num_groups{ 1 };
    experimental::PostOpList post_ops{};
+    WeightsInfo              weights_info{};
};

/** Descriptor used by the 3d Convolution function */
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index aa96716d38..2b4f848b22 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -112,20 +112,21 @@ class NEFullyConnectedLayer : public IFunction
     * |QASYMM8        |QASYMM8        |S32    |QASYMM8        |
     * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32    |QASYMM8_SIGNED |
     *
-    * @param[in]  input   Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
-    * @param[in]  weights Weights tensor. The weights must be 2 dimensional.
-    *                     If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions.
-    *                     If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
-    *                     Data type supported: Same as @p input.
-    * @param[in]  biases  Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
-    * @param[out] output  Destination tensor. Its shape should be equal to the output of a matrix multiplication between:
-    *                     - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
-    *                     - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
-    *                     Data type supported: Same as @p input.
-    * @param[in]  fc_info (Optional) Fully connected layer additional info
+    * @param[in]  input        Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
+    * @param[in]  weights      Weights tensor. The weights must be 2 dimensional.
+    *                          If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the input's first 3 dimensions.
+    *                          If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension.
+    *                          Data type supported: Same as @p input.
+    * @param[in]  biases       Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+    * @param[out] output       Destination tensor. Its shape should be equal to the output of a matrix multiplication between:
+    *                          - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer
+    *                          - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer.
+    *                          Data type supported: Same as @p input.
+    * @param[in]  fc_info      (Optional) Fully connected layer additional info
+    * @param[in]  weights_info (Optional) Stores necessary compute information when weights are already reshaped
     */
    void configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output,
-                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());
+                   FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo());
    /** Static function to check if given info will lead to a valid configuration of @ref NEFullyConnectedLayer
     *
     * Similar to @ref NEFullyConnectedLayer
@@ -135,6 +136,21 @@ class NEFullyConnectedLayer : public IFunction
    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
                           FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo());

+    /** Static function that queries whether a fixed-format kernel exists for a given problem description
+     *
+     * @param[out] expected_weight_format Format in which the weights should be stored for the fixed-format kernel that was found
+     * @param[in]  input                  Source tensor
+     * @param[in]  weights                Weights tensor.
+     * @param[in]  biases                 Bias tensor. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]  output                 Destination tensor
+     * @param[in]  fc_info                Fully connected layer additional info
+     * @param[in]  weights_info           Describes weights shape
+     *
+     * @return a status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights,
+                               const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, const WeightsInfo &weights_info);
+
    //Inherited methods override
    void run() override;
    void prepare() override;
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index ce68a61923..7ce2521148 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -84,6 +84,15 @@ class NEGEMM : public IFunction
     */
    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());

+    /** Static function that queries whether a fixed-format kernel exists and, if it does, returns in the first argument the format in which
+     * the weights are expected to be reshaped, as defined by the WeightFormat class. Apart from the first argument, the rest of the arguments
+     * are the same as in @ref NEGEMM::validate(), except that all arguments are required.
+     *
+     * @return a status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output,
+                               float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo());
+
    // Inherited methods overridden:
    void run() override;
    void prepare() override;
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index cf5fb82398..a28266265d 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -122,6 +122,65 @@ class NEGEMMConvolutionLayer : public IFunction
                   const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
                   bool enable_fast_math = false, unsigned int num_groups = 1);
+    /** Static function to check if there is an optimized version of
+     * GEMM available for the input parameters.
+     *
+     * The method is intended to be used to find out the optimal
+     * memory layout to be used for the weights tensor when running
+     * variable weights execution.
+     *
+     * The user can query the database of optimised kernels in
+     * arm_gemm by specifying one of the enumerations of
+     * arm_compute::WeightFormat in the weight_format field of the input
+     * parameter weights_info. In case of success, the method
+     * writes the expected format in the output parameter
+     * expected_weight_format. The expected_weight_format can then be
+     * used in the configure method of the class for retrieving the
+     * optimal kernel.
+     *
+     * Use case one - query for a specific format:
+     *
+     *     WeightsInfo weights_info(..., arm_compute::WeightFormat::OHWIo4, ...); // Set the value of the input query.
+     *     if (NEGEMMConvolutionLayer::has_opt_impl(WeightFormat(), ...., weights_info, ...))
+     *     {
+     *       auto conv = std::make_unique<NEGEMMConvolutionLayer>();
+     *       conv->configure(..., weights_info, ...); // uses the same WeightFormat the user wanted originally, OHWIo4.
+     *       conv->run(...);
+     *     }
+     *
+     * Use case two - query for any format that would be optimal for the GEMM to execute:
+     *
+     *     WeightsInfo weights_info(..., arm_compute::WeightFormat::ANY, ...); // Set the value of the input query.
+     *     arm_compute::WeightFormat expected_wf;
+     *     if (NEGEMMConvolutionLayer::has_opt_impl(expected_wf, ...., weights_info, ...))
+     *     {
+     *       auto conv = std::make_unique<NEGEMMConvolutionLayer>();
+     *       // ... code to convert the layout of the weights tensor to the layout returned by has_opt_impl
+     *       WeightsInfo new_weights_info(..., expected_wf, ...); // Set the value of the WeightFormat returned by has_opt_impl.
+     *       conv->configure(..., new_weights_info, ...);
+     *       conv->run(...);
+     *     }
+     *
+     * Notice that a GEMM configured with a WeightFormat other than
+     * UNSPECIFIED will run GEMM with variable weights mode.
+     *
+     * @param[out] expected_weight_format The arm_compute::WeightFormat expected by the kernel.
+     * @param[in]  src                    Source tensor info.
+     * @param[in]  weights                Weights tensor info.
+     * @param[in]  biases                 Biases tensor info. Shared biases supported.
+     * @param[in]  dst                    Destination tensor info.
+     * @param[in]  conv_info              Contains padding and stride information described in @ref PadStrideInfo.
+     * @param[in]  weights_info           (Optional) Specifies additional configuration parameters for the weights of the GEMM computation.
+     * @param[in]  dilation               (Optional) Dilation, in elements, across x and y. Defaults to (1, 1).
+     * @param[in]  act_info               (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported. And no activation (i.e. Linear) which is the default value.
+     * @param[in]  enable_fast_math       (Optional) Enable fast math computation. If this flag is set, the function may dispatch the fastest available implementation.
+     *
+     * @return a Status
+     */
+    static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
+                               const PadStrideInfo &conv_info,
+                               const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+                               bool enable_fast_math = false);
    // Inherited methods overridden:
    void run() override;
    void prepare() override;
diff --git a/arm_compute/runtime/NEON/functions/NEGather.h b/arm_compute/runtime/NEON/functions/NEGather.h
index 393a38ee4d..8253e986df 100644
--- a/arm_compute/runtime/NEON/functions/NEGather.h
+++ b/arm_compute/runtime/NEON/functions/NEGather.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -49,18 +49,17 @@ class NEGather : public INESimpleFunctionNoBorder
     * |All            |All            |
     *
     * @param[in]  input   Source tensor. Supported tensor rank: up to 4. Data type supported: All
-    * @param[in]  indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+    * @param[in]  indices Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
+    * @note The "axis" must be in the range [0, input.rank -1] when indices is a vector, and must be 1 when indices is a 2D or 3D tensor.
     * @param[out] output  Destination tensor. Data type supported: Same as @p input
     * @param[in]  axis    (Optional) The axis in @p input to gather @p indices from. Defaults to 0
+    *
     */
    void configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis = 0);

-    /** Static function to check if given info will lead to a valid configuration of @ref NEGatherKernel
+    /** Static function to check if given info will lead to a valid configuration
     *
-    * @param[in] input   Source tensor info. Supported tensor rank: up to 4. Data type supported: All
-    * @param[in] indices Indices tensor info. Supported tensor rank: up to 1. Must be one of the following types: U32/S32. Each value Must be in range [0, input.shape[@p axis])
-    * @param[in] output  Destination tensor info. Data type supported: Same as @p input
-    * @param[in] axis    (Optional) The axis in @p input to gather @p indices from.
Defaults to 0 + * Similar to @ref NEGather::configure() * * @return a status */ diff --git a/arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h index 2a49f2be59..85b4d047ef 100644 --- a/arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h +++ b/arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -38,7 +38,6 @@ class ITensor; /** Basic function to simulate a convolution layer. This function calls the following kernels: * - * -# @ref cpu::CpuWinogradConv2dTransformWeightsKernel (executed only once in the first call to the run() method ) * -# @ref cpu::CpuWinogradConv2dTransformInputKernel * -# @ref cpu::CpuWinogradConv2dTransformOutputKernel * -# @ref cpu::CpuGemmAssemblyDispatch diff --git a/docs/03_scripts.dox b/docs/03_scripts.dox deleted file mode 100644 index e66bb402fe..0000000000 --- a/docs/03_scripts.dox +++ /dev/null @@ -1,178 +0,0 @@ -/// -/// Copyright (c) 2017-2020 Arm Limited. -/// -/// SPDX-License-Identifier: MIT -/// -/// Permission is hereby granted, free of charge, to any person obtaining a copy -/// of this software and associated documentation files (the "Software"), to -/// deal in the Software without restriction, including without limitation the -/// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -/// sell copies of the Software, and to permit persons to whom the Software is -/// furnished to do so, subject to the following conditions: -/// -/// The above copyright notice and this permission notice shall be included in all -/// copies or substantial portions of the Software. -/// -/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -/// SOFTWARE. -/// -namespace arm_compute -{ -/** -@page data_import Importing data from existing models - -@tableofcontents - -@section caffe_data_extractor Extract data from pre-trained caffe model - -One can find caffe pre-trained models on -caffe's official github repository. - -The caffe_data_extractor.py provided in the scripts folder is an example script that shows how to -extract the parameter values from a trained model. - -@note complex networks might require altering the script to properly work. - -@subsection caffe_how_to How to use the script - -Install caffe following caffe's document. -Make sure the pycaffe has been added into the PYTHONPATH. - -Download the pre-trained caffe model. - -Run the caffe_data_extractor.py script by - - python caffe_data_extractor.py -m -n - -For example, to extract the data from pre-trained caffe Alex model to binary file: - - python caffe_data_extractor.py -m /path/to/bvlc_alexnet.caffemodel -n /path/to/caffe/models/bvlc_alexnet/deploy.prototxt - -The script has been tested under Python2.7. 
- -@subsection caffe_result What is the expected output from the script - -If the script runs successfully, it prints the names and shapes of each layer onto the standard -output and generates *.npy files containing the weights and biases of each layer. - -The arm_compute::utils::load_trained_data shows how one could load -the weights and biases into tensor from the .npy file by the help of Accessor. - -@section tensorflow_data_extractor Extract data from pre-trained tensorflow model - -The script tensorflow_data_extractor.py extracts trainable parameters (e.g. values of weights and biases) from a -trained tensorflow model. A tensorflow model consists of the following two files: - -{model_name}.data-{step}-{global_step}: A binary file containing values of each variable. - -{model_name}.meta: A binary file containing a MetaGraph struct which defines the graph structure of the neural -network. - -@note Since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of: - {model_name}.data-{step}-of-{max_step} -instead of: - {model_name}.ckpt -When dealing with binary files with version >= 0.11, only pass {model_name} to -m option; -when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option. - -@note This script relies on the parameters to be extracted being in the -'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless -specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other -collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly. - -@subsection tensorflow_how_to How to use the script - -Install tensorflow and numpy. - -Download the pre-trained tensorflow model. - -Run tensorflow_data_extractor.py with - - python tensorflow_data_extractor -m -n - -For example, to extract the data from pre-trained tensorflow Alex model to binary files: - - python tensorflow_data_extractor -m /path/to/bvlc_alexnet -n /path/to/bvlc_alexnet.meta - -Or for binary checkpoint files before Tensorflow 0.11: - - python tensorflow_data_extractor -m /path/to/bvlc_alexnet.ckpt -n /path/to/bvlc_alexnet.meta - -@note with versions >= Tensorflow 0.11 only model name is passed to the -m option - -The script has been tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. - -@subsection tensorflow_result What is the expected output from the script - -If the script runs successfully, it prints the names and shapes of each parameter onto the standard output and generates - *.npy files containing the weights and biases of each layer. - -The arm_compute::utils::load_trained_data shows how one could load -the weights and biases into tensor from the .npy file by the help of Accessor. - -@section tf_frozen_model_extractor Extract data from pre-trained frozen tensorflow model - -The script tf_frozen_model_extractor.py extracts trainable parameters (e.g. values of weights and biases) from a -frozen trained Tensorflow model. - -@subsection tensorflow_frozen_how_to How to use the script - -Install Tensorflow and NumPy. - -Download the pre-trained Tensorflow model and freeze the model using the architecture and the checkpoint file. 
- -Run tf_frozen_model_extractor.py with - - python tf_frozen_model_extractor -m -d - -For example, to extract the data from pre-trained Tensorflow model to binary files: - - python tf_frozen_model_extractor -m /path/to/inceptionv3.pb -d ./data - -@subsection tensorflow_frozen_result What is the expected output from the script - -If the script runs successfully, it prints the names and shapes of each parameter onto the standard output and generates - *.npy files containing the weights and biases of each layer. - -The arm_compute::utils::load_trained_data shows how one could load -the weights and biases into tensor from the .npy file by the help of Accessor. - -@section validate_examples Validating examples - -Compute Library provides a list of graph examples that are used in the context of integration and performance testing. -The provenance of each model is part of its documentation and no structural or data alterations have been applied to any -of them unless explicitly specified otherwise in the documentation. - -Using one of the provided scripts will generate files containing the trainable parameters. - -You can validate a given graph example on a list of inputs by running: - - LD_LIBRARY_PATH=lib ./ --validation-range='' --validation-file='' --validation-path='/path/to/test/images/' --data='/path/to/weights/' - -e.g: - -LD_LIBRARY_PATH=lib ./bin/graph_alexnet --target=CL --layout=NHWC --type=F32 --threads=4 --validation-range='16666,24998' --validation-file='val.txt' --validation-path='images/' --data='data/' - -where: - validation file is a plain document containing a list of images along with their expected label value. - e.g: - - val_00000001.JPEG 65 - val_00000002.JPEG 970 - val_00000003.JPEG 230 - val_00000004.JPEG 809 - val_00000005.JPEG 516 - - --validation-range is the index range of the images within the validation file you want to check: - e.g: - - --validation-range='100,200' will validate 100 images starting from 100th one in the validation file. - - This can be useful when parallelizing the validation process is needed. -*/ -} diff --git a/docs/ComputeLibrary.dir b/docs/ComputeLibrary.dir index e92cd72c37..ab9dfc1b93 100644 --- a/docs/ComputeLibrary.dir +++ b/docs/ComputeLibrary.dir @@ -198,14 +198,6 @@ * @brief Utility scripts. */ -/** @file scripts/caffe_data_extractor.py - * @brief Basic script to export weights from Caffe to npy files. - */ - -/** @file scripts/tensorflow_data_extractor.py - * @brief Basic script to export weights from TensorFlow to npy files. - */ - /** @dir src * @brief Source code implementing all the arm_compute headers. */ diff --git a/docs/Doxyfile b/docs/Doxyfile index 73a3e1efa0..bb4129a2e1 100644 --- a/docs/Doxyfile +++ b/docs/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "Compute Library" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 22.05 +PROJECT_NUMBER = 22.08 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/docs/contributor_guide/contribution_guidelines.dox b/docs/contributor_guide/contribution_guidelines.dox index 02d8495fc5..4a3ae4db2e 100644 --- a/docs/contributor_guide/contribution_guidelines.dox +++ b/docs/contributor_guide/contribution_guidelines.dox @@ -42,7 +42,7 @@ As part of the initiative to use inclusive language, there are certain phrases a Please also follow this guideline when committing changes to Compute Library. 
It is worth mentioning that the term "master" is still used in some comments but only in reference to external code links that Arm has no governance on. -Futhermore, starting from this release (22.05), 'master' branch will no longer be used, it has been replaced by 'main'. Please update your clone jobs accordingly. +Furthermore, starting from release 22.05, the 'master' branch is no longer used; it has been replaced by 'main'. Please update your clone jobs accordingly. @section S5_1_coding_standards Coding standards and guidelines Best practices (as suggested by clang-tidy): @@ -272,7 +272,7 @@ auto d = vdup_n_u8(0); // NO: It's not obvious what type this function returns. - No '*' in front of argument names - [in], [out] or [in,out] *in front* of arguments - - Skip a line between the description and params and between params and @return (If there is a return) + - Skip a line between the description and params and between params and \@return (If there is a return) - Align params names and params descriptions (Using spaces), and with a single space between the widest column and the next one. - Use an upper case at the beginning of the description diff --git a/docs/user_guide/errata.dox b/docs/user_guide/errata.dox index 1d8900877c..4a2d008fec 100644 --- a/docs/user_guide/errata.dox +++ b/docs/user_guide/errata.dox @@ -31,7 +31,7 @@ namespace arm_compute @section S7_1_errata Errata - (COMPMID-5324) Issue identified with direct and depthwise convolutions for certain Arm® Mali™ DDK versions. - - Versions Affected: All. + - Versions Affected: < v22.08 - Conditions: - Arm® Mali™ DDK Versions : >= r23p0 && <= r38p0 - Mali™ GPUs: Bifrost GPU family with the exception of G71 diff --git a/docs/user_guide/how_to_build_and_run_examples.dox b/docs/user_guide/how_to_build_and_run_examples.dox index f2f88c9b32..077baf9d47 100644 --- a/docs/user_guide/how_to_build_and_run_examples.dox +++ b/docs/user_guide/how_to_build_and_run_examples.dox @@ -30,213 +30,7 @@ namespace arm_compute @section S1_1_build_options Build options scons 2.3 or above is required to build the library. -To see the build options available simply run ```scons -h```: - - debug: Debug (yes|no) - default: False - - asserts: Enable asserts (this flag is forced to 1 for debug=1) (yes|no) - default: False - - logging: Logging (this flag is forced to 1 for debug=1) (yes|no) - default: False - - arch: Target Architecture (armv7a|x86_32|x86_64|armv8a|armv8.2-a|armv8.2-a-sve|armv8.2-a-sve2|armv8.6-a|armv8.6-a-sve|armv8.6-a-sve2|armv8r64|x86) - default: armv7a - - estate: Execution State (auto|32|64) - default: auto - - os: Target OS (linux|android|macos|tizen|bare_metal) - default: linux - - build: Build type (native|cross_compile|embed_only) - default: cross_compile - - examples: Build example programs (yes|no) - default: True - - gemm_tuner: Build gemm_tuner programs (yes|no) - default: True - - Werror: Enable/disable the -Werror compilation flag (yes|no) - default: True - - standalone: Builds the tests as standalone executables, links statically with libgcc, libstdc++ and libarm_compute (yes|no) - default: False - - opencl: Enable OpenCL support (yes|no) - default: True - - neon: Enable Arm® Neon™ support (yes|no) - default: False - - embed_kernels: Embed OpenCL kernels in library binary (yes|no) - default: True - - compress_kernels: Compress embedded OpenCL kernels in library binary.
Note embed_kernels should be enabled as well (yes|no) - default: False - - set_soname: Set the library's soname and shlibversion (requires SCons 2.4 or above) (yes|no) - default: False - - openmp: Enable OpenMP backend (yes|no) - default: False - - cppthreads: Enable C++11 threads backend (yes|no) - default: True - - build_dir: Specify sub-folder for the build ( /path/to/build_dir ) - default: . - - install_dir: Specify sub-folder for the install ( /path/to/install_dir ) - default: - - exceptions: Enable/disable C++ exception support (yes|no) - default: True - - linker_script: Use an external linker script ( /path/to/linker_script ) - default: - - custom_options: Custom options that can be used to turn on/off features - (all|none|comma-separated list of names) - allowed names: disable_mmla_fp - default: none - - data_type_support: Enable a list of data types to support - (all|none|comma-separated list of names) - allowed names: qasymm8 qasymm8_signed qsymm16 fp16 fp32 - default: all - - toolchain_prefix: Override the toolchain prefix - default: - - compiler_prefix: Override the compiler prefix - default: - - extra_cxx_flags: Extra CXX flags to be appended to the build command - default: - - extra_link_flags: Extra LD flags to be appended to the build command - default: - - compiler_cache: Command to prefix to the C and C++ compiler (e.g ccache) - default: - - specs_file: Specs file to use - default: rdimon.specs - - benchmark_examples: Build benchmark examples programs (yes|no) - default: False - - validate_examples: Build validate examples programs (yes|no) - default: False - - reference_openmp: Build reference validation with openmp (yes|no) - default: True - - validation_tests: Build validation test programs (yes|no) - default: False - - benchmark_tests: Build benchmark test programs (yes|no) - default: False - - test_filter: Pattern to specify the tests' filenames to be compiled - default: *.cpp - - pmu: Enable PMU counters (yes|no) - default: False - - mali: Enable Arm® Mali™ hardware counters (yes|no) - default: False - - external_tests_dir: Add examples, benchmarks and tests to the tests suite from an external path ( /path/to/external_tests_dir ) - default: - - high_priority: Generate a library using only the high priority operators - default: False - - data_layout_support: Enable a list of data layout to support - default: False - -@b debug / @b asserts: - - With debug=1 asserts are enabled, and the library is built with symbols and no optimisations enabled. - - With debug=0 and asserts=1: Optimisations are enabled and symbols are removed, however all the asserts are still present (This is about 20% slower than the release build) - - With debug=0 and asserts=0: All optimisations are enable and no validation is performed, if the application misuses the library it is likely to result in a crash. (Only use this mode once you are sure your application is working as expected). - -@b arch: The x86_32 and x86_64 targets can only be used with neon=0 and opencl=1. - -@b os: Choose the operating system you are targeting: Linux, Android or bare metal. -@note bare metal can only be used for Arm® Neon™ (not OpenCL), only static libraries get built and Neon™'s multi-threading support is disabled. - -@b build: you can either build directly on your device (native) or cross compile from your desktop machine (cross-compile). In both cases make sure the compiler is available in your path. 
- -@note If you want to natively compile for 32bit on a 64bit Arm device running a 64bit OS then you will have to use cross-compile too. - -There is also an 'embed_only' option which will generate all the .embed files for the OpenCL kernels. This might be useful if using a different build system to compile the library. - -In addition the option 'compress_kernels' will compress the embedded OpenCL kernel files using zlib and inject them in the library. This is useful for reducing the binary size. Note, this option is only available for Android when 'embed_kernels' is enabled. - -@b Werror: If you are compiling using the same toolchains as the ones used in this guide then there shouldn't be any warning and therefore you should be able to keep Werror=1. If with a different compiler version the library fails to build because of warnings interpreted as errors then, if you are sure the warnings are not important, you might want to try to build with Werror=0 (But please do report the issue on Github). - -@b opencl / @b neon: Choose which SIMD technology you want to target. (Neon™ for Arm® Cortex®-A CPUs or OpenCL for Arm® Mali™ GPUs) - -@b embed_kernels: For OpenCL only: set embed_kernels=1 if you want the OpenCL kernels to be built in the library's binaries instead of being read from separate ".cl" / ".cs" files. If embed_kernels is set to 0 then the application can set the path to the folder containing the OpenCL kernel files by calling CLKernelLibrary::init(). By default the path is set to "./cl_kernels". - -@b set_soname: Do you want to build the versioned version of the library ? - -If enabled the library will contain a SONAME and SHLIBVERSION and some symlinks will automatically be created between the objects. -Example: - libarm_compute_core.so -> libarm_compute_core.so.1.0.0 - libarm_compute_core.so.1 -> libarm_compute_core.so.1.0.0 - libarm_compute_core.so.1.0.0 - -@note This options is disabled by default as it requires SCons version 2.4 or above. - -@b extra_cxx_flags: Custom CXX flags which will be appended to the end of the build command. - -@b build_dir: Build the library in a subfolder of the "build" folder. (Allows to build several configurations in parallel). - -@b examples: Build or not the examples - -@b validation_tests: Enable the build of the validation suite. - -@b benchmark_tests: Enable the build of the benchmark tests - -@b pmu: Enable the PMU cycle counter to measure execution time in benchmark tests. (Your device needs to support it) - -@b mali: Enable the collection of Arm® Mali™ hardware counters to measure execution time in benchmark tests. (Your device needs to have a Arm® Mali™ driver that supports it) - -@b openmp: Build in the OpenMP scheduler for Neon™. - -@note Only works when building with g++ not clang++ - -@b cppthreads: Build in the C++11 scheduler for Neon™. - -@sa Scheduler::set - -@b external_tests_dir: Add examples, benchmarks and tests to the tests suite from an external path ( /path/to/external_tests_dir ) - -In order to use this option, the external tests directory must have the following structure: - - EXTERNAL_TESTS_DIR: - └── tests - ├── benchmark - │   ├── CL - │   ├── datasets - │   ├── fixtures - │   └── Neon - └── validation -    ├── CL -     ├── datasets -     ├── fixtures -     └── Neon - -Then, build the library with `external_tests_dir=`. 
- -@b high_priority: Generate a library using only the high priority operators - -@b data_layout_support: Enable a list of data layout to support +To see the build options available simply run ```scons -h``` @section S1_2_linux Building for Linux @@ -377,7 +171,6 @@ An example build command with SVE is: @section S1_3_android Building for Android For Android, the library was successfully built and tested using Google's standalone toolchains: - - clang++ from NDK r18b for armv7a - clang++ from NDK r20b for armv8a - clang++ from NDK r20b for armv8.2-a with FP16 support @@ -387,12 +180,26 @@ For NDK r18 or older, here is a guide to Download the NDK package for your development platform, without the need to launch the make_standalone_toolchain.py script. You can find all the prebuilt binaries inside $NDK/toolchains/llvm/prebuilt/$OS_ARCH/bin/. -@attention the building script will look for a binary named "aarch64-linux-android-clang++", while the prebuilt binaries will have their API version as a suffix to their filename (e.g. "aarch64-linux-android21-clang++"). You should copy/rename the binary removing this suffix, or - alternatively - create an alias for it. +@parblock +@attention The building script will look for a binary named "aarch64-linux-android-clang++", while the prebuilt binaries will have their API version as a suffix to their filename (e.g. "aarch64-linux-android21-clang++"). You can instruct scons to use the correct version by using a combination of the toolchain_prefix and the "CC" "CXX" environment variables. +@attention For this particular example, you can specify: + + CC=clang CXX=clang++ scons toolchain_prefix=aarch64-linux-android21- + +@attention or: + + CC=aarch64-linux-android21-clang CXX=aarch64-linux-android21-clang++ scons toolchain_prefix="" + +@endparblock + +@parblock @attention We used to use gnustl but as of NDK r17 it is deprecated so we switched to libc++ +@endparblock @note Make sure to add the toolchains to your PATH: @@ -503,14 +310,14 @@ To cross-compile the library with Arm® Neon™ support for baremetal armv8a: Examples are disabled when building for bare metal. If you want to build the examples you need to provide a custom bootcode depending on the target architecture and link against the compute library. More information about bare metal bootcode can be found here. -@section S1_6_windows_host Building on a Windows host system +@section S1_6_windows_host Building on a Windows host system (cross-compile) Using `scons` directly from the Windows command line is known to cause problems. The reason seems to be that if `scons` is setup for cross-compilation it gets confused about Windows style paths (using backslashes). Thus it is recommended to follow one of the options outlined below. -@subsection S1_6_1_ubuntu_on_windows Bash on Ubuntu on Windows +@subsection S1_6_1_ubuntu_on_windows Bash on Ubuntu on Windows (cross-compile) The best and easiest option is to use Ubuntu on Windows. @@ -518,7 +325,7 @@ This feature is still marked as *beta* and thus might not be available. However, if it is building the library is as simple as opening a *Bash on Ubuntu on Windows* shell and following the general guidelines given above. -@subsection S1_6_2_cygwin Cygwin +@subsection S1_6_2_cygwin Cygwin (cross-compile) If the Windows subsystem for Linux is not available Cygwin can be used to install and run `scons`, the minimum Cygwin version must be 3.0.7 or later. In addition @@ -531,6 +338,38 @@ compiler is included in the Android standalone toolchain. 
After everything has been set up in the Cygwin terminal the general guide on building the library can be followed. +@subsection S1_6_3_WoA Windows on ARM (native build) + + Native builds on Windows are experimental and some features of the library that interact with the OS are missing. + +It is possible to build Compute Library natively on a Windows system running on Arm. + +Windows on ARM (WoA) systems provide compatibility by emulating x86 binaries on aarch64. Unfortunately, Visual Studio 2022 does not work on aarch64 systems because it is an x86_64 application, and such binaries cannot be executed on WoA yet. + +Because we cannot use Visual Studio to build Compute Library, we have to set up a native standalone toolchain to compile C++ code for arm64 on Windows. + +Native arm64 toolchain installation for WoA: +- LLVM+Clang-12 which can be downloaded from: https://github.com/llvm/llvm-project/releases/download/llvmorg-12.0.0/LLVM-12.0.0-woa64.exe +- Arm64 VC Runtime which can be downloaded from https://aka.ms/vs/17/release/vc_redist.arm64.exe + +- While full VS22 cannot be installed on WoA, we can install some components: + -# Desktop development with C++ and all Arm64 components for Visual Studio, refer to: https://developer.arm.com/documentation/102528/0100/Install-Visual-Studio + -# VS22 build tools: https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2022 + +There are some additional tools we need to install to build Compute Library: + +- git https://git-scm.com/download/win +- Python 3 https://www.python.org/downloads/windows/ +- scons, which can be installed with pip install scons + +In order to use clang to build Windows binaries natively, we have to initialize the environment variables from VS22 correctly so that the compiler can find the arm64 C++ libraries. This can be done by pressing Windows + R and running the command: + + cmd /k "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Auxiliary\Build\vcvarsx86_arm64.bat" + +To build Compute Library type: + + scons opencl=0 neon=1 os=windows examples=0 validation_tests=1 benchmark_examples=0 build=native arch=armv8a Werror=0 exceptions=1 standalone=1 + @section S1_7_cl_requirements OpenCL DDK Requirements @subsection S1_7_1_cl_hard_requirements Hard Requirements diff --git a/docs/user_guide/introduction.dox b/docs/user_guide/introduction.dox index fb483fc134..7086f86c7c 100644 --- a/docs/user_guide/introduction.dox +++ b/docs/user_guide/introduction.dox @@ -39,6 +39,8 @@ Several builds of the library are available using various configurations: - Technology: Arm® Neon™ / OpenCL / Arm® Neon™ and OpenCL. - Debug / Asserts / Release: Use a build with asserts enabled to debug your application and enable extra validation. Once you are sure your application works as expected you can switch to a release build of the library for maximum performance. +@warning From the 22.08 release, armv7a Android builds are no longer tested or maintained. + @b Minimum toolchains requirements are shown below: @@ -59,11 +61,9 @@ Several builds of the library are available using various configurations: - + @@ -78,8 +78,8 @@ Please create an issue on allocator()->allocate(); } // [Initialize and Allocate Auxiliary CLTensor objects] + TOCK(tensor_allocation, measurements); + TICK(dummy_run); /// @page example_dynamic_fusion_cl_conv2d_elementwise_add /// Run the ClCompositeOperator prepare job. This performs any jobs that are required for the first run, like /// reshaping tensors for a more performant format.
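The TICK/TOCK pairs added throughout these dynamic fusion examples are the examples' own timing macros; the detail worth noting in the next hunk is that CLScheduler::get().sync() runs before the closing TOCK(dummy_run, measurements). OpenCL kernel launches are asynchronous, so without that sync the timer would capture only the enqueue cost, not the GPU execution. A minimal sketch of the same measure-after-sync pattern in plain C++ follows; gpu_sync() is a hypothetical stand-in for the scheduler sync, not a library API:

    #include <chrono>
    #include <iostream>

    // Hypothetical stand-in for CLScheduler::get().sync(): in a real program this
    // would block until all enqueued OpenCL work has finished.
    static void gpu_sync() {}

    int main()
    {
        const auto start = std::chrono::high_resolution_clock::now(); // ~ TICK(dummy_run)

        // enqueue_kernels(); // asynchronous: control returns before the GPU finishes

        gpu_sync(); // without this, the measurement covers enqueue time only
        const auto stop = std::chrono::high_resolution_clock::now(); // ~ TOCK(dummy_run, measurements)

        const auto elapsed = std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
        std::cout << "dummy_run: " << elapsed.count() << " us\n";
        return 0;
    }

The dummy run being timed here exists so that the CLTuner can settle on kernel configurations before the steady-state do_run() calls.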
@@ -327,6 +332,8 @@ class ClFusedConv2dEltwiseAddExample : public Example // [Run ClCompositeOperator] op.run(run_pack_map); // [Run ClCompositeOperator] + CLScheduler::get().sync(); + TOCK(dummy_run, measurements); TOCK(startup_time, measurements); return true; } diff --git a/examples/dynamic_fusion/cl_ref_conv2d_elementwise_add.cpp b/examples/dynamic_fusion/cl_ref_conv2d_elementwise_add.cpp index 4f68372b49..3aedcc0f41 100644 --- a/examples/dynamic_fusion/cl_ref_conv2d_elementwise_add.cpp +++ b/examples/dynamic_fusion/cl_ref_conv2d_elementwise_add.cpp @@ -52,6 +52,9 @@ using namespace utils; using std::chrono::duration_cast; using std::chrono::microseconds; +/** A reference for comparing against the fusion of a direct convolution with an elementwise addition: + * examples/dynamic_fusion/cl_fused_conv2d_elementwise_add.cpp + */ class ClRefConv2dEltwiseAddExample : public Example { public: @@ -69,7 +72,7 @@ class ClRefConv2dEltwiseAddExample : public Example if(argc < 10) { // Print help - std::cout << "Usage: ./cl_conv2d_elementwise_add ih iw ifm wh ww ofm tuner_choice(0=Disable, 1=Rapid, 2=Normal, 3=Exhaustive)\n"; + std::cout << "Usage: ./cl_ref_conv2d_elementwise_add ih iw ifm wh ww ofm tuner_choice(0=Disable, 1=Rapid, 2=Normal, 3=Exhaustive) pad_x pad_y\n"; std::cout << "Too few or no input_matrices provided. Using shape config = SRGAN_0, tuner_choice=2\n\n"; ih = 512; iw = 512; @@ -126,6 +129,7 @@ class ClRefConv2dEltwiseAddExample : public Example CLScheduler::get().default_init(tuner_to_use); TICK(startup_time); + TICK(configure); /* Computation: * out = add_desc(addend, conv2d1x1(direct_conv)(input, weights, bias)) @@ -133,54 +137,64 @@ class ClRefConv2dEltwiseAddExample : public Example const auto data_type = DataType::F32; const auto data_layout = DataLayout::NHWC; const PadStrideInfo conv_info{ 1, 1, pad_x, pad_y }; - // const auto t_input_shape = TensorShape(384, 12, 12); - // const auto t_weight_shape = TensorShape(384, 1, 1, 64); - // const auto t_dst_shape = TensorShape(64, 12, 12); - const auto t_input_shape = TensorShape(ifm, iw, ih); - const auto t_weight_shape = TensorShape(ifm, ww, wh, ofm); - const auto t_dst_shape = misc::shape_calculator::compute_deep_convolution_shape(t_input_shape, data_layout, t_weight_shape, conv_info); + const auto t_input_shape = TensorShape(ifm, iw, ih); + const auto t_weight_shape = TensorShape(ifm, ww, wh, ofm); + const auto t_bias_shape = TensorShape(ofm); + const auto t_l1_addend_shape = TensorShape(ofm, iw); + const auto t_dst_shape = misc::shape_calculator::compute_deep_convolution_shape(t_input_shape, data_layout, t_weight_shape, conv_info); std::cout << "input_shape: " << t_input_shape << std::endl; std::cout << "weight_shape: " << t_weight_shape << std::endl; + std::cout << "bias_shape: " << t_bias_shape << std::endl; + std::cout << "addend_shape: " << t_l1_addend_shape << std::endl; std::cout << "dst_shape: " << t_dst_shape << std::endl; auto t_input_info = TensorInfo(t_input_shape, 1, data_type, data_layout); auto t_weight_info = TensorInfo(t_weight_shape, 1, data_type, data_layout); + auto t_bias_info = TensorInfo(t_bias_shape, 1, data_type, data_layout); auto t_l0_dst_info = TensorInfo(t_dst_shape, 1, data_type, data_layout); // Intermediate tensor for cond3 - auto t_l1_addend_info = TensorInfo(t_dst_shape, 1, data_type, data_layout); + auto t_l1_addend_info = TensorInfo(t_l1_addend_shape, 1, data_type, data_layout); auto t_dst_info = TensorInfo(t_dst_shape, 1, data_type, data_layout); // Init tensors { 
t_input.allocator()->init(t_input_info); t_weight.allocator()->init(t_weight_info); + t_bias.allocator()->init(t_bias_info); t_l1_addend.allocator()->init(t_dst_info); t_l0_dst.allocator()->init(t_l0_dst_info); t_dst.allocator()->init(t_dst_info); } - op0.configure(&t_input, &t_weight, nullptr, &t_l0_dst, conv_info); + op0.configure(&t_input, &t_weight, &t_bias, &t_l0_dst, conv_info); op1.configure(&t_l0_dst, &t_l1_addend, &t_dst, ConvertPolicy{}); + TOCK(configure, measurements); + TICK(tensor_allocation); // Construct tensors // Allocate and fill tensors { t_input.allocator()->allocate(); t_weight.allocator()->allocate(); + t_bias.allocator()->allocate(); t_l1_addend.allocator()->allocate(); t_l0_dst.allocator()->allocate(); t_dst.allocator()->allocate(); fill_random_tensor(t_input, -1.f, 1.f); fill_random_tensor(t_weight, -1.f, 1.f); + fill_random_tensor(t_bias, -1.f, 1.f); fill_random_tensor(t_l1_addend, -1.f, 1.f); } + TOCK(tensor_allocation, measurements); // Dummy run for CLTuner + TICK(dummy_run); op0.run(); - op1.run(); + CLScheduler::get().sync(); + TOCK(dummy_run, measurements); TOCK(startup_time, measurements); return true; } void do_run() override { - // Run the fused op + // Run the ops op0.run(); op1.run(); @@ -199,6 +213,7 @@ class ClRefConv2dEltwiseAddExample : public Example private: CLTensor t_input{}; CLTensor t_weight{}; + CLTensor t_bias{}; CLTensor t_l1_addend{}; CLTensor t_l0_dst{}; CLTensor t_dst{}; diff --git a/filedefs.json b/filedefs.json index 76dccfffee..3422eeb252 100644 --- a/filedefs.json +++ b/filedefs.json @@ -23,7 +23,8 @@ }, "armv8.6-a": { "cxxflags": ["-march=armv8.6-a+fp16"], - "cppdefines": ["ARM_COMPUTE_ENABLE_FP16"] + "cppdefines": ["ARM_COMPUTE_ENABLE_FP16", "ARM_COMPUTE_ENABLE_BF16", + "ARM_COMPUTE_ENABLE_I8MM"] }, "armv8.6-a-sve": { "cxxflags": ["-march=armv8.6-a+sve+fp16+dotprod"], @@ -37,4 +38,4 @@ } } } -} \ No newline at end of file +} diff --git a/filelist.json b/filelist.json index dc4be05f58..c5de028928 100644 --- a/filelist.json +++ b/filelist.json @@ -460,6 +460,8 @@ "deps": [ "Cast" ], "files": { "common": [ + "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp", + "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp", "src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp", "src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeBifrost.cpp", "src/gpu/cl/kernels/gemm/native/ClGemmDefaultConfigNativeMidgard.cpp", @@ -471,12 +473,14 @@ "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.cpp", + "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpOffsetContributionKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpOffsetContributionOutputStageKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFixedPointKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleByFloatKernel.cpp", "src/gpu/cl/kernels/ClGemmLowpQuantizeDownInt32ScaleKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.cpp", + "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.cpp", "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.cpp", "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.cpp", @@ -665,7 +669,7 @@ "Reduction": { "deps": [ "Reshape" ], "files": { - "common": [ + "common": [ "src/core/CL/kernels/CLReductionOperationKernel.cpp", 
"src/runtime/CL/functions/CLReductionOperation.cpp" ] @@ -1059,33 +1063,42 @@ "src/core/NEON/kernels/convolution/common/qasymm8.cpp", "src/core/NEON/kernels/convolution/common/qsymm8.cpp", "src/core/NEON/kernels/convolution/common/utils.cpp", - "src/core/NEON/kernels/convolution/winograd/padding.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_1x8_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp", - "src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp", + "src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp16_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp32_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_1x8.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_4x4.cpp", + "src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/a64_fp16_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x2_1x7.cpp", + 
"src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x4_1x5.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x6_1x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_5x5.cpp", + "src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/a64_fp16_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp", + "src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp", "src/cpu/kernels/directconv2d/nhwc/neon/impl.cpp", "src/cpu/kernels/directconv2d/nchw/all.cpp" ], "fp32": [ "src/cpu/kernels/directconv2d/nhwc/neon/fp32.cpp" ] + }, + "sve": { + "common": ["src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp"] } } }, @@ -1217,10 +1230,10 @@ "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic.cpp", "src/core/NEON/kernels/arm_conv/depthwise/interleaves/generic_quantized_dot_product.cpp", "src/cpu/kernels/depthwiseconv2d/generic/neon/impl.cpp" - ], + ], "fp16":["src/cpu/kernels/depthwiseconv2d/generic/neon/fp16.cpp"], - "fp32":["src/cpu/kernels/depthwiseconv2d/generic/neon/fp32.cpp"], - "qasymm8":["src/cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp"], + "fp32":["src/cpu/kernels/depthwiseconv2d/generic/neon/fp32.cpp"], + "qasymm8":["src/cpu/kernels/depthwiseconv2d/generic/neon/qasymm8.cpp"], "qasymm8_signed":["src/cpu/kernels/depthwiseconv2d/generic/neon/qasymm8_signed.cpp"] }, "sve": { @@ -1321,7 +1334,7 @@ "fp32": ["src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp"], "fp16": ["src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp"] - }, + }, "sve2":{ "qasymm8": ["src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp"], "qasymm8_signed": ["src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp"] @@ -1421,7 +1434,7 @@ }, "Gemm": { "deps": [ "Quantize", "Add"], - "files": { + "files": { "common": [ "src/cpu/kernels/CpuConvertQuantizedSignednessKernel.cpp", "src/cpu/kernels/CpuGemmMatrixAdditionKernel.cpp", @@ -1525,7 +1538,7 @@ "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_6x4/a55.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_6x4/generic.cpp", "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/a55.cpp", - "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_smallK_hybrid_u8u32_dot_8x4/generic.cpp", "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.cpp", "src/cpu/kernels/gemm_matrix_add/generic/neon/impl.cpp" ], @@ -1540,6 +1553,16 @@ ], "estate64": [ "src/core/NEON/kernels/arm_gemm/kernels/a64_sgemv_pretransposed/generic.cpp" + ], + "experimental_fixed_format_kernels": [ + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp", + 
"src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp" ] }, "sve": { @@ -1584,7 +1607,20 @@ "src/core/NEON/kernels/arm_gemm/kernels/sve_smallK_hybrid_u8u32_dot_8x1VL/generic.cpp", "src/core/NEON/kernels/arm_gemm/mergeresults-sve.cpp", "src/core/NEON/kernels/arm_gemm/transform-sve.cpp" - ] + ], + "experimental_fixed_format_kernels": [ + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp", + "src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp" + ] } } }, @@ -1609,7 +1645,7 @@ "common": [ "src/core/NEON/kernels/NEInstanceNormalizationLayerKernel.cpp", "src/runtime/NEON/functions/NEInstanceNormalizationLayer.cpp" - ], + ], "neon":{ "common":["src/cpu/kernels/instancenorm/generic/neon/impl.cpp"], "fp16":["src/cpu/kernels/instancenorm/generic/neon/fp16.cpp"], @@ -1669,7 +1705,7 @@ "files": { "common": [ "src/cpu/kernels/CpuMaxUnpoolingLayerKernel.cpp", - "src/runtime/NEON/functions/NEMaxUnpoolingLayer.cpp", + "src/runtime/NEON/functions/NEMaxUnpoolingLayer.cpp", "src/cpu/operators/CpuMaxUnpooling.cpp" ], "neon":{ @@ -1770,12 +1806,12 @@ "src/core/NEON/kernels/arm_conv/pooling/kernels/a64_u8_nhwc_max_2x2_s1_output2x2_depthfirst/generic.cpp", "src/core/NEON/kernels/arm_conv/pooling/kernels/a64_u8_nhwc_max_generic_depthfirst/generic.cpp", "src/core/NEON/kernels/arm_conv/pooling/kernels/a64_u8q_nhwc_avg_generic_depthfirst/generic.cpp", - "src/core/NEON/kernels/arm_conv/pooling/kernels/a64_u8q_nhwc_max_generic_depthfirst/generic.cpp" + "src/core/NEON/kernels/arm_conv/pooling/kernels/a64_u8q_nhwc_max_generic_depthfirst/generic.cpp" ], "nchw": [ "src/cpu/kernels/pool2d/neon/nchw/all.cpp" ], "fp16": [ "src/cpu/kernels/pool2d/neon/fp16.cpp" ], - "fp32": [ "src/cpu/kernels/pool2d/neon/fp32.cpp" ], - "qasymm8":[ "src/cpu/kernels/pool2d/neon/qasymm8.cpp" ], + "fp32": [ "src/cpu/kernels/pool2d/neon/fp32.cpp" ], + "qasymm8":[ "src/cpu/kernels/pool2d/neon/qasymm8.cpp" ], "qasymm8_signed":["src/cpu/kernels/pool2d/neon/qasymm8_signed.cpp"] }, "sve": { @@ -1975,8 +2011,8 @@ "neon":{ "common":["src/cpu/kernels/softmax/generic/neon/impl.cpp"], "fp32": ["src/cpu/kernels/softmax/generic/neon/fp32.cpp"], - "fp16": ["src/cpu/kernels/softmax/generic/neon/fp16.cpp"], - "qasymm8":[ "src/cpu/kernels/softmax/generic/neon/qasymm8.cpp"], + "fp16": ["src/cpu/kernels/softmax/generic/neon/fp16.cpp"], + 
"qasymm8":[ "src/cpu/kernels/softmax/generic/neon/qasymm8.cpp"], "qasymm8_signed":["src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp"] }, "sve": { @@ -1988,7 +2024,7 @@ }, "sve2":{ "common" :["src/cpu/kernels/softmax/generic/sve2/impl.cpp"], - "qasymm8":[ "src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp"], + "qasymm8":[ "src/cpu/kernels/softmax/generic/sve2/qasymm8.cpp"], "qasymm8_signed":["src/cpu/kernels/softmax/generic/sve2/qasymm8_signed.cpp"] } } @@ -2074,7 +2110,8 @@ "dynamic_fusion": [ "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp", "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp", - "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp", + "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.cpp", + "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp", "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp", "src/gpu/cl/kernels/experimental/dynamic_fusion/ClCompositeKernel.cpp", diff --git a/include/libnpy/npy.hpp b/include/libnpy/npy.hpp index 24244ca272..e4f2215931 100644 --- a/include/libnpy/npy.hpp +++ b/include/libnpy/npy.hpp @@ -35,7 +35,9 @@ #include #include #include - +#include +#include +#include namespace npy { diff --git a/scripts/arm_compute_library_nn_driver.go b/scripts/arm_compute_library_nn_driver.go index cbbe7a71e0..73a5ce43aa 100644 --- a/scripts/arm_compute_library_nn_driver.go +++ b/scripts/arm_compute_library_nn_driver.go @@ -11,6 +11,25 @@ import ( "strings" ) +func isVersionAtLeast(version_name string, target_version int) bool { + name_map := map[string]int { + "L": 5, "5": 5, + "M": 6, "6": 6, + "N": 7, "7": 7, + "O": 8, "8": 8, + "P": 9, "9": 9, + "Q": 10, "10": 10, + "R": 11, "11": 11, + "S": 12, "12": 12, + "T": 13, "13": 13, + } + if _, ok := name_map[version_name]; ok { + return name_map[version_name] >= target_version + } else { + return false + } +} + func globalFlags(ctx android.BaseContext) []string { var cppflags []string @@ -30,6 +49,12 @@ func globalFlags(ctx android.BaseContext) []string { } } + // Since Android T, the underlying NDK stops supporting system assembler like GAS, in favor of integrated assembler + // However for Android < Android T we still want to suppress integrated assembler for backward compatibility + if ! isVersionAtLeast(ctx.AConfig().PlatformVersionName(), 13) { + cppflags = append(cppflags, "-no-integrated-as") + } + data_types := strings.Split(ctx.AConfig().GetenvWithDefault("COMPUTE_LIB_DATA_TYPE", "ALL"), ",") for _, x := range data_types { diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py deleted file mode 100755 index 47d24b265f..0000000000 --- a/scripts/caffe_data_extractor.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -"""Extracts trainable parameters from Caffe models and stores them in numpy arrays. -Usage - python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist - -Saves each variable to a {variable_name}.npy binary file. 
- -Tested with Caffe 1.0 on Python 2.7 -""" -import argparse -import caffe -import os -import numpy as np - - -if __name__ == "__main__": - # Parse arguments - parser = argparse.ArgumentParser('Extract Caffe net parameters') - parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file') - parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist') - args = parser.parse_args() - - # Create Caffe Net - net = caffe.Net(args.netFile, 1, weights=args.modelFile) - - # Read and dump blobs - for name, blobs in net.params.iteritems(): - print('Name: {0}, Blobs: {1}'.format(name, len(blobs))) - for i in range(len(blobs)): - # Weights - if i == 0: - outname = name + "_w" - # Bias - elif i == 1: - outname = name + "_b" - else: - continue - - varname = outname - if os.path.sep in varname: - varname = varname.replace(os.path.sep, '_') - print("Renaming variable {0} to {1}".format(outname, varname)) - print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape)) - # Dump as binary - np.save(varname, blobs[i].data) diff --git a/scripts/tensorflow_data_extractor.py b/scripts/tensorflow_data_extractor.py deleted file mode 100755 index 1dbf0e127e..0000000000 --- a/scripts/tensorflow_data_extractor.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays. -Usage - python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file - -Saves each variable to a {variable_name}.npy binary file. - -Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of: - {model_name}.data-{step}-of-{max_step} -instead of: - {model_name}.ckpt -When dealing with binary files with version >= 0.11, only pass {model_name} to -m option; -when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option. - -Also note that this script relies on the parameters to be extracted being in the -'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless -specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other -collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly. - -Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. -""" -import argparse -import numpy as np -import os -import tensorflow as tf - - -if __name__ == "__main__": - # Parse arguments - parser = argparse.ArgumentParser('Extract Tensorflow net parameters') - parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\ - file. 
For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\ - model name with ".ckpt" extension') - parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file') - args = parser.parse_args() - - # Load Tensorflow Net - saver = tf.train.import_meta_graph(args.netFile) - with tf.Session() as sess: - # Restore session - saver.restore(sess, args.modelFile) - print('Model restored.') - # Save trainable variables to numpy arrays - for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): - varname = t.name - if os.path.sep in t.name: - varname = varname.replace(os.path.sep, '_') - print("Renaming variable {0} to {1}".format(t.name, varname)) - print("Saving variable {0} with shape {1} ...".format(varname, t.shape)) - # Dump as binary - np.save(varname, sess.run(t)) diff --git a/src/core/CL/CLCompileContext.cpp b/src/core/CL/CLCompileContext.cpp index b9b2b5651a..81eb748ab8 100644 --- a/src/core/CL/CLCompileContext.cpp +++ b/src/core/CL/CLCompileContext.cpp @@ -232,6 +232,8 @@ void CLCompileContext::set_context(cl::Context context) std::string CLCompileContext::generate_build_options(const StringSet &build_options_set, const std::string &kernel_path) const { std::string concat_str; + bool ext_supported = false; + std::string ext_buildopts; #if defined(ARM_COMPUTE_DEBUG_ENABLED) // Enable debug properties in CL kernels @@ -247,7 +249,7 @@ std::string CLCompileContext::generate_build_options(const StringSet &build_opti concat_str += " -DARM_COMPUTE_OPENCL_FP16_ENABLED=1 "; } - if(_device.supported("cl_arm_integer_dot_product_int8")) + if(_device.supported("cl_arm_integer_dot_product_int8") || _device.supported("cl_khr_integer_dot_product")) { concat_str += " -DARM_COMPUTE_OPENCL_DOT8_ENABLED=1 "; } @@ -257,13 +259,11 @@ std::string CLCompileContext::generate_build_options(const StringSet &build_opti concat_str += " -DARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED=1 "; } - if(_device.version() == CLVersion::CL20) - { - concat_str += " -cl-std=CL2.0 "; - } - else if(_device.supported("cl_arm_non_uniform_work_group_size")) + std::tie(ext_supported, ext_buildopts) = _device.is_non_uniform_workgroup_supported(); + + if(ext_supported) { - concat_str += " -cl-arm-non-uniform-work-group-size "; + concat_str += ext_buildopts; } else { diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index 8685180b7f..94675d60cc 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -256,7 +256,11 @@ bool dot8_acc_supported(const cl::Device &device) CLVersion get_cl_version(const cl::Device &device) { std::string version_str = device.getInfo(); - if(version_str.find("OpenCL 2") != std::string::npos) + if(version_str.find("OpenCL 3") != std::string::npos) + { + return CLVersion::CL30; + } + else if(version_str.find("OpenCL 2") != std::string::npos) { return CLVersion::CL20; } @@ -388,6 +392,15 @@ size_t get_cl_image_pitch_alignment(const cl::Device &device) } } +bool get_cl_non_uniform_work_group_supported(const cl::Device &device) +{ + cl_bool supported = CL_FALSE; + + cl_int err = clGetDeviceInfo(device(), CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), &supported, nullptr); + + return (err == CL_SUCCESS && supported == CL_TRUE); +} + cl::Kernel create_kernel(const CLCompileContext &ctx, const std::string &kernel_name, const std::set &build_opts) { opencl::ClKernelLibrary &klib = opencl::ClKernelLibrary::get(); @@ -478,4 +491,8 @@ void set_unroll_with_pragma(CLBuildOptions &built_opts, std::initializer_listclBuildProgram_ptr == nullptr, "Failed to load OpenCL symbols from shared library"); return true; } } +#ifdef __ANDROID__ + // When running in NDK environment, the above libraries are not accessible. + static const std::vector android_libraries{ "libOpenCL-pixel.so", "libOpenCL-car.so" }; + + for(const auto &lib : android_libraries) + { + if(load(lib, /* use_loader */true)) + { + ARM_COMPUTE_ERROR_ON_MSG(this->clBuildProgram_ptr == nullptr, "Failed to load OpenCL symbols from android shared library"); + return true; + } + } +#endif /* __ANDROID__ */ + std::cerr << "Couldn't find any OpenCL library.\n"; return false; } -bool CLSymbols::load(const std::string &library) +bool CLSymbols::load(const std::string &library, bool use_loader) { void *handle = dlopen(library.c_str(), RTLD_LAZY | RTLD_LOCAL); @@ -85,8 +99,28 @@ bool CLSymbols::load(const std::string &library) return false; } +#ifdef __ANDROID__ + typedef void* (*loadOpenCLPointer_t)(const char* name); + loadOpenCLPointer_t loadOpenCLPointer; + if (use_loader) { + typedef void (*enableOpenCL_t)(); + enableOpenCL_t enableOpenCL = + reinterpret_cast(dlsym(handle, "enableOpenCL")); + enableOpenCL(); + + loadOpenCLPointer = reinterpret_cast( + dlsym(handle, "loadOpenCLPointer")); + } else { + loadOpenCLPointer = nullptr; + } +#define LOAD_FUNCTION_PTR(func_name, _handle) \ + func_name##_ptr = reinterpret_cast( use_loader ? \ + loadOpenCLPointer(#func_name) : dlsym(handle, #func_name)); +#else /* __ANDROID__ */ + (void)use_loader; // Avoid unused warning #define LOAD_FUNCTION_PTR(func_name, handle) \ func_name##_ptr = reinterpret_cast(dlsym(handle, #func_name)); +#endif /* __ANDROID__ */ LOAD_FUNCTION_PTR(clCreateContext, handle); LOAD_FUNCTION_PTR(clCreateContextFromType, handle); diff --git a/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl new file mode 100644 index 0000000000..8919023d4c --- /dev/null +++ b/src/core/CL/cl_kernels/common/gemm_reshaped_only_rhs_mmul.cl @@ -0,0 +1,528 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "activation_float_helpers.h" +#include "helpers.h" +#include "tile_helpers.h" + +#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL) +/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension: + * + * The LHS matrix is NOT reshaped + * The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is NOT transposed + * + * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4). + * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2) + * @note The number of output columns processed by the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=2) + * @note The number of output rows processed by the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=2) + * @note The number of lhs columns (or rhs rows) processed by the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=2) + * @note Only the following configurations of M0, N0 and K0 are currently supported: + * - M0 > 0 + * - N0 = 1, 2, 3, 4, 8, 16 + * - K0 = 1 + * + * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. + * The activation function is performed after the bias addition + * + * @param[in] lhs_ptr Pointer to the LHS tensor. Supported data types: F16/F32 + * @param[in] lhs_stride_y Stride of the LHS tensor in Y dimension (in bytes) + * @param[in] lhs_stride_z Stride of the LHS tensor in Z dimension (in bytes) + * @param[in] lhs_w The size of the width dimension of the LHS tensor + * @param[in] lhs_h The size of the height dimension of the LHS tensor + * @param[in] lhs_n The size of the depth dimension of the LHS tensor + * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor + * @param[in] rhs_ptr Pointer to the RHS reshaped tensor.
Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_y Stride of the RHS tensor in Y dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS tensor in Z dimension (in bytes) + * @param[in] rhs_w The size of the width dimension of the RHS tensor + * @param[in] rhs_h The size of the height dimension of the RHS tensor + * @param[in] rhs_n The size of the depth dimension of the RHS tensor + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor + * @param[in] bia_ptr (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr + * @param[in] bia_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes) + * @param[in] bia_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes) + * @param[in] bia_w (Optional) The size of the width dimension of the bias tensor + * @param[in] bia_h (Optional) The size of the height dimension of the bias tensor + * @param[in] bia_n (Optional) The size of the depth dimension of the bias tensor + * @param[in] bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor + * @param[out] dst_ptr Pointer to the destination tensor. Supported data type: same as @p lhs_ptr + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_w The size of the width dimension of the destination tensor + * @param[in] dst_h The size of the height dimension of the destination tensor + * @param[in] dst_n The size of the depth dimension of the destination tensor + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] M Number of rows in LHS matrix not reshaped + * @param[in] N Number of columns in RHS matrix not reshaped + * @param[in] K Number of columns in LHS matrix and rows in RHS matrix not reshaped + */ +__kernel void gemm_mm_reshaped_only_rhs_nt_mmul( + TENSOR3D_T(lhs, BUFFER), + TENSOR3D_T(rhs, BUFFER), +#if defined(BETA) + TENSOR3D_T(bia, BUFFER), +#endif // defined(BETA) + TENSOR3D_T(dst, BUFFER), + const int M, + const int N, + const int K) +{ +#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0) + + uint x0 = get_global_id(0); // (N / N0) * MMUL_K0 + uint y0 = get_global_id(1); // (M / M0) / MMUL_M0 + uint z = get_global_id(2); // Batch + + // Get block ID and thread ID within the block + uint block_id = (x0 / MMUL_BLOCK_SIZE); + uint thread_id = (x0 % MMUL_BLOCK_SIZE); + + // Coordinate within a block + uint block_x = thread_id % MMUL_N0; + uint block_y = (thread_id / MMUL_M0); + + // Starting destination coordinates + uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1)); + uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0)); + + // Note: We need to clamp dst_x and dst_y because we always need to execute a complete MMUL block! Only after the matrix multiplication + // part can we exit the kernel if it is out-of-bound. Remember, we have a cooperative matrix multiplication. 
Therefore, we need a full block to get the correct results + + // Starting LHS coordinates + uint lhs_x = block_x; + uint lhs_y = dst_y; + + // Starting RHS coordinates + uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0; + uint rhs_y = block_id; + + // Compute LHS/RHS/DST matrix address + lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z; + rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z; + dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z; + + // Note: If RHS derives from the weights of a convolution 2d layer, RHS will always be 2D and rhs_stride_z will always be equal to 0 for + // not sliding the tensor + + // Initialize the accumulators + // The MMUL extension accumulates the result in F32 for both F32 and F16 + TILE(float, M0, N0, c_f32); + +#if !defined(HALF_PRECISION) +#define c c_f32 +#endif // !defined(HALF_PRECISION) + + LOOP_UNROLLING(int, i, 0, 1, M0, + { + c_f32[i].v = 0; + }) + + for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0) + { + TILE(DATA_TYPE, M0, 1, a); + TILE(DATA_TYPE, 1, N0, b); + + // Load tile from the lhs/rhs tensors + T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a); + T_LOAD(DATA_TYPE, 1, N0, BUFFER, rhs, 0, 0, 1, 0, b); + + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]); + }) + }) + + lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE); + rhs_offset_first_element_in_bytes += MMUL_K0 * MMUL_N0 * N0 * sizeof(DATA_TYPE); + } + + if(block_x * N0 + block_id * MMUL_N0 * N0 >= N) + { + return; + } + + if(block_y * M0 + y0 * M0 * MMUL_M0 >= M) + { + return; + } + +#if defined(HALF_PRECISION) + TILE(DATA_TYPE, M0, N0, c); + + // Conversion required for the half precision + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + LOOP_UNROLLING(int, n0, 0, 1, N0, + { + c[m0].s[n0] = c_f32[m0].s[n0]; + }) + }) +#endif // defined(HALF_PRECISION) + + // Multiply by the weight of matrix-matrix product and store the result +#if defined(ALPHA) + T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c); +#endif // defined(ALPHA) + + // Add beta*bias +#if defined(BETA) +#if defined(BROADCAST_BIAS) + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE); + + TILE(DATA_TYPE, 1, N0, bias0); + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + else + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes)); + } + +#ifndef UNIT_BETA + T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0); +#endif // UNIT_BETA + + // c = c + bias[broadcasted] + T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c); +#else // defined(BROADCAST_BIAS) + TILE(DATA_TYPE, M0, N0, bias0); + + bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z; + + if(dst_x + N0 <= N || N0_LEFTOVER == 0) + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y)); + } + }) + } + else + { + LOOP_UNROLLING(int, m0, 0, 1, M0, + { + if(dst_y + m0 < M || M0_LEFTOVER == 0) + { + VLOAD_PARTIAL(N0, N0_LEFTOVER) + (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 *
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
+
+    T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
+
+    // Store
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#undef MMUL_BLOCK_SIZE
+}
+#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_MMUL)
+
+#if defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)
+/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension and the OpenCL image for RHS:
+ *
+ * The LHS matrix is NOT reshaped
+ * The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is NOT transposed
+ *
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
+ * @note The number of output columns processed by the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=2)
+ * @note The number of output rows processed by the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=2)
+ * @note The number of lhs columns (or rhs rows) processed by the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=2)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 > 0
+ *  - N0 = 1, 2, 3, 4, 8, 16
+ *  - K0 = 1
+ *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ * The activation function is performed after the bias addition
+ *
+ * @param[in] lhs_ptr Pointer to the LHS tensor. Supported data types: F16/F32
+ * @param[in] lhs_stride_y Stride of the LHS tensor in Y dimension (in bytes)
+ * @param[in] lhs_stride_z Stride of the LHS tensor in Z dimension (in bytes)
+ * @param[in] lhs_w The size of the width dimension of the LHS tensor
+ * @param[in] lhs_h The size of the height dimension of the LHS tensor
+ * @param[in] lhs_n The size of the depth dimension of the LHS tensor
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor
+ * @param[in] rhs_ptr Pointer to the RHS reshaped tensor. Supported data type: same as @p lhs_ptr
+ * @param[in] rhs_stride_y Stride of the RHS tensor in Y dimension (in bytes)
+ * @param[in] rhs_stride_z Stride of the RHS tensor in Z dimension (in bytes)
+ * @param[in] rhs_w The size of the width dimension of the RHS tensor
+ * @param[in] rhs_h The size of the height dimension of the RHS tensor
+ * @param[in] rhs_n The size of the depth dimension of the RHS tensor
+ * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor
+ * @param[in] bia_ptr (Optional) Pointer to the bias tensor. Supported data type: same as @p lhs_ptr
+ * @param[in] bia_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes)
+ * @param[in] bia_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes)
+ * @param[in] bia_w (Optional) The size of the width dimension of the bias tensor
+ * @param[in] bia_h (Optional) The size of the height dimension of the bias tensor
+ * @param[in] bia_n (Optional) The size of the depth dimension of the bias tensor
+ * @param[in] bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data type: same as @p lhs_ptr
+ * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
+ * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
+ * @param[in] dst_w The size of the width dimension of the destination tensor
+ * @param[in] dst_h The size of the height dimension of the destination tensor
+ * @param[in] dst_n The size of the depth dimension of the destination tensor
+ * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
+ * @param[in] M Number of rows in LHS matrix not reshaped
+ * @param[in] N Number of columns in RHS matrix not reshaped
+ * @param[in] K Number of columns in LHS matrix and rows in RHS matrix not reshaped
+ */
+__kernel void gemm_mm_reshaped_only_rhs_nt_mmul_texture(
+    TENSOR3D_T(lhs, BUFFER),
+    TENSOR3D_T(rhs, IMAGE),
+#if defined(BETA)
+    TENSOR3D_T(bia, BUFFER),
+#endif // defined(BETA)
+    TENSOR3D_T(dst, BUFFER),
+    const int M,
+    const int N,
+    const int K)
+{
+#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_K0)
+
+    uint x0 = get_global_id(0); // (N / N0) * MMUL_K0
+    uint y0 = get_global_id(1); // (M / M0) / MMUL_M0
+    uint z  = get_global_id(2); // Batch
+
+    // Get block ID and thread ID within the block
+    uint block_id  = (x0 / MMUL_BLOCK_SIZE);
+    uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+
+    // Coordinate within a block
+    uint block_x = thread_id % MMUL_N0;
+    uint block_y = (thread_id / MMUL_M0);
+
+    // Starting destination coordinates
+    uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1));
+    uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0));
+
+    // Note: We need to clamp dst_x and dst_y because we always need to run a complete MMUL block! Only after the matrix multiplication
+    // part can we exit the kernel if the block is out of bounds. Remember, this is a cooperative matrix multiplication,
+    // so we need a full block to get correct results.
+
+    // Starting LHS coordinates
+    uint lhs_x = block_x;
+    uint lhs_y = dst_y;
+
+    // Starting RHS coordinates
+    uint rhs_x = block_y * N0 * MMUL_N0 + block_x * N0;
+    uint rhs_y = block_id + z * rhs_h;
+
+    // Compute LHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    // Initialize the accumulators
+    // Note: The MMUL extension accumulates the result in F32 for both F32 and F16
+    TILE(float, M0, N0, c_f32);
+
+#if !defined(HALF_PRECISION)
+#define c c_f32
+#endif // !defined(HALF_PRECISION)
+
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_f32[i].v = 0;
+    })
+
+    for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, M0, 1, a);
+        TILE(DATA_TYPE, 1, N0, b);
+
+        // Load tile from the lhs/rhs tensors
+        T_LOAD(DATA_TYPE, M0, 1, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+        T_LOAD(DATA_TYPE, 1, N0, IMAGE, rhs, rhs_x, rhs_y, 1, rhs_stride_y, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                c_f32[m0].s[n0] = arm_matrix_multiply(a[m0].s[0], b[0].s[n0], c_f32[m0].s[n0]);
+            })
+        })
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_x += MMUL_K0 * MMUL_N0 * N0;
+    }
+
+    if(block_x * N0 + block_id * MMUL_N0 * N0 >= N)
+    {
+        return;
+    }
+
+    if(block_y * M0 + y0 * M0 * MMUL_M0 >= M)
+    {
+        return;
+    }
+
+#if defined(HALF_PRECISION)
+    TILE(DATA_TYPE, M0, N0, c);
+
+    // Conversion required for half precision
+    LOOP_UNROLLING(int, m0, 0, 1, M0,
+    {
+        LOOP_UNROLLING(int, n0, 0, 1, N0,
+        {
+            c[m0].s[n0] = c_f32[m0].s[n0];
+        })
+    })
+#endif // defined(HALF_PRECISION)
+
+    // Multiply by the weight (alpha) of the matrix-matrix product
+#if defined(ALPHA)
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, c, (DATA_TYPE)ALPHA, c);
+#endif // defined(ALPHA)
+
+    // Add beta*bias
+#if defined(BETA)
+#if defined(BROADCAST_BIAS)
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE);
+
+    TILE(DATA_TYPE, 1, N0, bias0);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        bias0[0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+    else
+    {
+        VLOAD_PARTIAL(N0, N0_LEFTOVER)
+        (bias0[0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, 1, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias[broadcasted]
+    T_ELTWISE_BROADCAST_X(V_ADD, DATA_TYPE, M0, N0, c, bias0, c);
+#else // defined(BROADCAST_BIAS)
+    TILE(DATA_TYPE, M0, N0, bias0);
+
+    bia_offset_first_element_in_bytes += dst_x * sizeof(DATA_TYPE) + dst_y * bia_stride_y + z * bia_stride_z;
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                bias0[m0].v = VLOAD(N0)(0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VLOAD_PARTIAL(N0, N0_LEFTOVER)
+                (bias0[m0].v, 0, (DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes + m0 * bia_stride_y));
+            }
+        })
+    }
+
+#ifndef UNIT_BETA
+    T_SCALE_CONSTANT(DATA_TYPE, M0, N0, bias0, (DATA_TYPE)BETA, bias0);
+#endif // UNIT_BETA
+
+    // c = c + bias
+    T_ADD(DATA_TYPE, M0, N0, c, bias0, c);
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(BETA)
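+
+    // Note: each thread of the cooperative MMUL block now holds its M0 x N0 tile of the result in c. Together with
+    // the alpha scaling and bias addition above and the activation below, the epilogue computes
+    // c = ACTIVATION(ALPHA * (LHS * RHS) + BETA * bias), i.e. the activation runs only after the bias addition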
+
+    T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
+
+    // Store
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#undef MMUL_BLOCK_SIZE
+}
+#endif // defined(GEMM_MM_RESHAPED_ONLY_RHS_NT_MMUL_TEXTURE)
\ No newline at end of file
diff --git a/src/core/CL/cl_kernels/common/gemmlowp.cl b/src/core/CL/cl_kernels/common/gemmlowp.cl
index f9d18ec976..53ce296948 100644
--- a/src/core/CL/cl_kernels/common/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/common/gemmlowp.cl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -703,7 +703,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t
 
 a_offset_s32[0].v *= A_OFFSET;
 
-    T_ADD_BROADCAST_X(int, M0, N0, offset_s32, a_offset_s32, offset_s32);
+    T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, offset_s32, a_offset_s32, offset_s32);
 #endif // defined(A_OFFSET)
 
 #if defined(B_OFFSET)
@@ -728,7 +728,7 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t
 
 T_LOAD(int, 1, N0, BUFFER, biases, xo, 0, 1, 0, bias);
 
-    T_ADD_BROADCAST_X(int, M0, N0, offset_s32, bias, offset_s32);
+    T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, offset_s32, bias, offset_s32);
 #endif // defined(ADD_BIAS)
 
 LOOP_UNROLLING(int, i, 0, 1, M0,
@@ -1096,17 +1096,17 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
 VEC_DATA_TYPE(DATA_TYPE, 16)
 a0 = vload16(0, matrix_a + i);
 
-        sum_row += arm_dot(a0.s0123, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.s4567, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.s89AB, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.sCDEF, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s0123, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s4567, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s89AB, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.sCDEF, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
 
 a0 = vload16(1, matrix_a + i);
 
-        sum_row += arm_dot(a0.s0123, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.s4567, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.s89AB, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
-        sum_row += arm_dot(a0.sCDEF, (VEC_DATA_TYPE(DATA_TYPE, 4))(1));
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s0123, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s4567, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.s89AB, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
+        DOT_PRODUCT4_INTEGER8(DATA_TYPE, DATA_TYPE, DATA_TYPE, a0.sCDEF, (VEC_DATA_TYPE(DATA_TYPE, 4))(1), sum_row);
 }
 
 // This for loop performs the leftover accumulations
diff --git a/src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl b/src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl
new file mode 100644
index 0000000000..72fe3d3b89
--- /dev/null
+++ b/src/core/CL/cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.cl
@@ -0,0 +1,309 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "activation_float_helpers.h"
+#include "helpers.h"
+#include "tile_helpers.h"
+#if defined(GEMMLOWP_MM_RESHAPED_ONLY_RHS_MMUL)
+/** This OpenCL kernel computes the matrix multiplication between 2 matrices using the MMUL extension:
+ *
+ * The LHS matrix is NOT reshaped
+ * The RHS is reshaped with @ref ClGemmMatrixMultiplyReshapedOnlyRhsKernel and the block K0xN0 is transposed
+ *
+ * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=1, -DK0=1).
+ * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=1)
+ * @note The number of output columns processed by the cooperative mmul extension must be passed at compile time using -DMMUL_N0 (e.g., -DMMUL_N0=4)
+ * @note The number of output rows processed by the cooperative mmul extension must be passed at compile time using -DMMUL_M0 (e.g., -DMMUL_M0=4)
+ * @note The number of lhs columns (or rhs rows) processed by the cooperative mmul extension must be passed at compile time using -DMMUL_K0 (e.g., -DMMUL_K0=16)
+ * @note Only the following configurations of M0, N0 and K0 are currently supported:
+ *  - M0 = 1, 2, 4
+ *  - N0 = 1, 4, 8
+ *  - K0 = 4
+ *
+ * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
+ * The activation function is performed after the bias addition
+ *
+ * @param[in] lhs_ptr Pointer to the LHS tensor. Supported data types: QASYMM8/QASYMM8_SIGNED
+ * @param[in] lhs_stride_y Stride of the LHS tensor in Y dimension (in bytes)
+ * @param[in] lhs_stride_z Stride of the LHS tensor in Z dimension (in bytes)
+ * @param[in] lhs_w The size of the width dimension of the LHS tensor
+ * @param[in] lhs_h The size of the height dimension of the LHS tensor
+ * @param[in] lhs_n The size of the depth dimension of the LHS tensor
+ * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS tensor
+ * @param[in] rhs_ptr Pointer to the RHS reshaped tensor.
Supported data type: same as @p lhs_ptr + * @param[in] rhs_stride_y Stride of the RHS tensor in Y dimension (in bytes) + * @param[in] rhs_stride_z Stride of the RHS tensor in Z dimension (in bytes) + * @param[in] rhs_w The size of the width dimension of the RHS tensor + * @param[in] rhs_h The size of the height dimension of the RHS tensor + * @param[in] rhs_n The size of the depth dimension of the RHS tensor + * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS tensor + * @param[in] bia_ptr (Optional) Pointer to the bias tensor. Supported data type: S32 + * @param[in] bia_stride_y (Optional) Stride of the bias tensor in Y dimension (in bytes) + * @param[in] bia_stride_z (Optional) Stride of the bias tensor in Z dimension (in bytes) + * @param[in] bia_w (Optional) The size of the width dimension of the bias tensor + * @param[in] bia_h (Optional) The size of the height dimension of the bias tensor + * @param[in] bia_n (Optional) The size of the depth dimension of the bias tensor + * @param[in] bia_offset_first_element_in_bytes (Optional) The offset of the first element in the bias tensor + * @param[out] dst_ptr Pointer to the destination tensor. Supported data type: same as @p lhs_ptr or S32 + * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) + * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) + * @param[in] dst_w The size of the width dimension of the destination tensor + * @param[in] dst_h The size of the height dimension of the destination tensor + * @param[in] dst_n The size of the depth dimension of the destination tensor + * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor + * @param[in] M Number of rows in LHS matrix not reshaped + * @param[in] N Number of columns in RHS matrix not reshaped + * @param[in] K Number of columns in LHS matrix and rows in RHS matrix not reshaped + * @param[in] sum_col_ptr (Optional) Pointer to the source tensor. Supported data type: S32 + * @param[in] sum_col_stride_x (Optional) Stride of the source tensor in X dimension (in bytes) + * @param[in] sum_col_step_x (Optional) sum_col_stride_x * number of elements along X processed per workitem(in bytes) + * @param[in] sum_col_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes) + * @param[in] sum_col_step_y (Optional) sum_col_stride_y * number of elements along Y processed per workitem(in bytes) + * @param[in] sum_col_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor + * @param[in] sum_row_ptr (Optional) Pointer to the source tensor. 
Supported data type: S32
+ * @param[in] sum_row_stride_x (Optional) Stride of the source tensor in X dimension (in bytes)
+ * @param[in] sum_row_step_x (Optional) sum_row_stride_x * number of elements along X processed per workitem(in bytes)
+ * @param[in] sum_row_stride_y (Optional) Stride of the source tensor in Y dimension (in bytes)
+ * @param[in] sum_row_step_y (Optional) sum_row_stride_y * number of elements along Y processed per workitem(in bytes)
+ * @param[in] sum_row_offset_first_element_in_bytes (Optional) The offset of the first element in the source tensor
+ */
+__kernel void gemmlowp_mm_reshaped_only_rhs_mmul(
+    TENSOR3D_T(lhs, BUFFER),
+    TENSOR3D_T(rhs, BUFFER),
+#if defined(ADD_BIAS)
+    TENSOR3D_T(bia, BUFFER),
+#endif // defined(ADD_BIAS)
+    TENSOR3D_T(dst, BUFFER),
+    const int M,
+    const int N,
+    const int K
+#if defined(A_OFFSET)
+    ,
+    TENSOR3D_T(sum_col, BUFFER)
+#endif // defined(A_OFFSET)
+#if defined(B_OFFSET)
+    ,
+    TENSOR3D_T(sum_row, BUFFER)
+#endif // defined(B_OFFSET)
+)
+{
+#define MMUL_BLOCK_SIZE (MMUL_N0 * MMUL_M0)
+#define VEC_SIZE 4 // For int8 types, the input to the mmul instruction is a length-4 vector
+
+    uint x0 = get_global_id(0);
+    uint y0 = get_global_id(1);
+    uint z  = get_global_id(2);
+
+    // Get block ID and thread ID within the block
+    uint block_id  = (x0 / MMUL_BLOCK_SIZE);
+    uint thread_id = (x0 % MMUL_BLOCK_SIZE);
+
+    // Coordinate within a block
+    uint block_x = thread_id % MMUL_N0;
+    uint block_y = (thread_id / MMUL_M0);
+
+    // Starting destination coordinates
+    uint dst_x = min(block_x * N0 + block_id * MMUL_N0 * N0, (uint)(N - 1));
+    uint dst_y = min(block_y * M0 + y0 * M0 * MMUL_M0, (uint)(M - M0));
+
+    uint lhs_x = VEC_SIZE * block_x;
+    uint lhs_y = dst_y;
+
+    uint rhs_x = VEC_SIZE * N0 * block_y;
+    uint rhs_y = 4 * block_id + block_x;
+
+    // Compute LHS/RHS/DST matrix address
+    lhs_offset_first_element_in_bytes += lhs_x * sizeof(DATA_TYPE) + lhs_y * lhs_stride_y + z * lhs_stride_z;
+    rhs_offset_first_element_in_bytes += rhs_x * sizeof(DATA_TYPE) + rhs_y * rhs_stride_y + z * rhs_stride_z;
+    dst_offset_first_element_in_bytes += dst_x * sizeof(OUT_DATA_TYPE) + dst_y * dst_stride_y + z * dst_stride_z;
+
+    TILE(ACC_DATA_TYPE, M0, N0, c);
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c[i].v = 0;
+    })
+
+    for(int k = 0; k <= K - MMUL_K0; k += MMUL_K0)
+    {
+        TILE(DATA_TYPE, M0, VEC_SIZE, a);
+        T_LOAD(DATA_TYPE, M0, VEC_SIZE, BUFFER, lhs, 0, 0, 1, lhs_stride_y, a);
+
+        TILE(DATA_TYPE, N0, VEC_SIZE, b);
+        T_LOAD(DATA_TYPE, N0, VEC_SIZE, BUFFER, rhs, 0, 0, 1, VEC_SIZE, b);
+
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            LOOP_UNROLLING(int, n0, 0, 1, N0,
+            {
+                VEC_TYPE vec_a = (VEC_TYPE)(a[m0].s[0], a[m0].s[1], a[m0].s[2], a[m0].s[3]);
+                VEC_TYPE vec_b = (VEC_TYPE)(b[n0].s[0], b[n0].s[1], b[n0].s[2], b[n0].s[3]);
+                c[m0].s[n0] = arm_matrix_multiply(vec_a, vec_b, c[m0].s[n0]);
+            })
+        })
+
+        lhs_offset_first_element_in_bytes += MMUL_K0 * sizeof(DATA_TYPE);
+        rhs_offset_first_element_in_bytes += MMUL_K0 * N0 * sizeof(DATA_TYPE);
+    }
+
+    if(block_x * N0 + block_id * MMUL_N0 * N0 >= N)
+    {
+        return;
+    }
+
+    if(block_y * M0 + y0 * M0 * MMUL_M0 >= M)
+    {
+        return;
+    }
+
+#if defined(FUSED_OUTPUT_STAGE_FIXED_POINT)
+
+    TILE(int, M0, N0, offset_s32);
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        offset_s32[i].v = (VEC_DATA_TYPE(int, N0))K_OFFSET;
+    })
+
+#if defined(A_OFFSET)
+
+    TILE(int, 1, N0, a_offset_s32);
+
+    T_LOAD(int, 1, N0, BUFFER, sum_col, dst_x, z, 1, sum_col_stride_z, a_offset_s32);
+
+    a_offset_s32[0].v *= A_OFFSET;
+
+    T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, offset_s32, a_offset_s32, offset_s32);
+#endif // defined(A_OFFSET)
+
+#if defined(B_OFFSET)
+
+    TILE(int, M0, 1, b_offset_s32);
+
+    T_LOAD(int, M0, 1, BUFFER, sum_row, dst_y, z * M, 1, 4, b_offset_s32);
+
+    LOOP_UNROLLING(int, m0, 0, 1, M0,
+    {
+        offset_s32[m0].v += b_offset_s32[m0].v * B_OFFSET;
+    })
+
+#endif // defined(B_OFFSET)
+
+#if defined(ADD_BIAS)
+#if defined(BROADCAST_BIAS)
+    bia_offset_first_element_in_bytes += dst_x * sizeof(ACC_DATA_TYPE) + z * bia_stride_y;
+
+    TILE(int, M0, N0, bias);
+
+    T_LOAD(int, M0, N0, BUFFER, bia, dst_x, dst_y, 1, 1, bias);
+
+    T_ADD(ACC_DATA_TYPE, M0, N0, offset_s32, bias, offset_s32);
+
+#else // defined(BROADCAST_BIAS)
+    bia_offset_first_element_in_bytes += dst_x * sizeof(ACC_DATA_TYPE);
+
+    TILE(int, 1, N0, bias);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        bias[0].v = VLOAD(N0)(0, (ACC_DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+    else
+    {
+        VLOAD_PARTIAL(N0, N0_LEFTOVER)
+        (bias[0].v, 0, (ACC_DATA_TYPE *)(bia_ptr + bia_offset_first_element_in_bytes));
+    }
+
+    T_ELTWISE_BROADCAST_ADD_X(int, M0, N0, offset_s32, bias, offset_s32);
+
+#endif // defined(BROADCAST_BIAS)
+#endif // defined(ADD_BIAS)
+
+    T_ADD(ACC_DATA_TYPE, M0, N0, c, offset_s32, c);
+    TILE(OUT_DATA_TYPE, M0, N0, c_lp);
+    T_QUANTIZE8(ACC_DATA_TYPE, OUT_DATA_TYPE, PER_TENSOR, M0, N0, RESULT_OFFSET, RESULT_SHIFT, RESULT_MULTIPLIER, c, 0, 0, c_lp);
+
+#if defined(MIN_BOUND)
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_lp[i].v = max(c_lp[i].v, (VEC_DATA_TYPE(OUT_DATA_TYPE, N0))MIN_BOUND);
+    })
+#endif // defined(MIN_BOUND)
+#if defined(MAX_BOUND)
+    LOOP_UNROLLING(int, i, 0, 1, M0,
+    {
+        c_lp[i].v = min(c_lp[i].v, (VEC_DATA_TYPE(OUT_DATA_TYPE, N0))MAX_BOUND);
+    })
+#endif // defined(MAX_BOUND)
+
+    T_ACTIVATION(DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c);
+
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c_lp[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c_lp[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+
+#else // FUSED_OUTPUT_STAGE_FIXED_POINT
+    // Store
+    if(dst_x + N0 <= N || N0_LEFTOVER == 0)
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE(N0)
+                (c[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+    else
+    {
+        LOOP_UNROLLING(int, m0, 0, 1, M0,
+        {
+            if(dst_y + m0 < M || M0_LEFTOVER == 0)
+            {
+                VSTORE_PARTIAL(N0, N0_LEFTOVER)
+                (c[m0].v, 0, (__global OUT_DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + m0 * dst_stride_y));
+            }
+        })
+    }
+#endif // FUSED_OUTPUT_STAGE_FIXED_POINT
+}
+
+#endif // defined(GEMMLOWP_MM_RESHAPED_ONLY_RHS_MMUL)
diff --git a/src/core/CL/cl_kernels/helpers.h b/src/core/CL/cl_kernels/helpers.h
index 4018c40b16..298edc244f 100644
--- a/src/core/CL/cl_kernels/helpers.h
+++ b/src/core/CL/cl_kernels/helpers.h
@@ -44,6 +44,7 @@
 #define GPU_ARCH_MIDGARD 0x100
 #define GPU_ARCH_BIFROST 0x200
+#define GPU_ARCH_VALHALL 0x300
 
 /** Concatenate two inputs.
* diff --git a/src/core/CL/cl_kernels/nhwc/direct_convolution.cl b/src/core/CL/cl_kernels/nhwc/direct_convolution.cl index f1b422a68f..e602fbb525 100644 --- a/src/core/CL/cl_kernels/nhwc/direct_convolution.cl +++ b/src/core/CL/cl_kernels/nhwc/direct_convolution.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -169,11 +169,17 @@ __kernel void direct_convolution_nhwc( TILE(SRC_DATA_TYPE, M0, K0, a); TILE(WEI_DATA_TYPE, N0, K0, b); + // Initialize tiles LOOP_UNROLLING(int, i, 0, 1, M0, { a[i].v = ZERO_VALUE; }) + LOOP_UNROLLING(int, i, 0, 1, N0, + { + b[i].v = ZERO_VALUE; + }) + // Load tile from the src tensor T_LOAD_NHWC_INDIRECT(SRC_DATA_TYPE, M0, K0, SRC_TENSOR_TYPE, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, xi, yi, a); @@ -199,11 +205,17 @@ __kernel void direct_convolution_nhwc( TILE(SRC_DATA_TYPE, M0, 1, a); TILE(WEI_DATA_TYPE, N0, 1, b); + // Initialize tiles LOOP_UNROLLING(int, i, 0, 1, M0, { a[i].v = ZERO_VALUE; }) + LOOP_UNROLLING(int, i, 0, 1, N0, + { + b[i].v = ZERO_VALUE; + }) + // Load tile from the src tensor T_LOAD_NHWC_INDIRECT(SRC_DATA_TYPE, M0, 1, SRC_TENSOR_TYPE, src, bout, yk, xk, ck, _ISRC_WIDTH, _ISRC_HEIGHT, src_stride_y, xi, yi, a); @@ -233,7 +245,7 @@ __kernel void direct_convolution_nhwc( T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, cout, 0, 1, 0, bias0); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); + T_ELTWISE_BROADCAST_ADD_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); #endif // HAS_BIAS diff --git a/src/core/CL/cl_kernels/nhwc/direct_convolution3d.cl b/src/core/CL/cl_kernels/nhwc/direct_convolution3d.cl index 587f3984ab..807b990e82 100644 --- a/src/core/CL/cl_kernels/nhwc/direct_convolution3d.cl +++ b/src/core/CL/cl_kernels/nhwc/direct_convolution3d.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -254,7 +254,7 @@ __kernel void direct_convolution3d_ndhwc( } // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); + T_ELTWISE_BROADCAST_ADD_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); #endif // HAS_BIAS diff --git a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl index 4f57a81e7b..b24a6ae85f 100644 --- a/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl +++ b/src/core/CL/cl_kernels/nhwc/dwc_native_fp_nhwc.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -175,7 +175,7 @@ __kernel void dwc_native_fp_nhwc( T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, (cout * DEPTH_MULTIPLIER) + d, 0, 0, 0, bias0); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); + T_ELTWISE_BROADCAST_ADD_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); #endif // HAS_BIAS T_ACTIVATION(ACC_DATA_TYPE, M0, N0, ACTIVATION_TYPE, A_VAL, B_VAL, c, c); diff --git a/src/core/CL/cl_kernels/nhwc/dwc_native_quantized_nhwc.cl b/src/core/CL/cl_kernels/nhwc/dwc_native_quantized_nhwc.cl index ec2593af71..263a23ef28 100644 --- a/src/core/CL/cl_kernels/nhwc/dwc_native_quantized_nhwc.cl +++ b/src/core/CL/cl_kernels/nhwc/dwc_native_quantized_nhwc.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -236,7 +236,7 @@ __kernel void dwc_native_quantized_nhwc( T_LOAD(BIA_DATA_TYPE, 1, N0, BUFFER, bia, cout * DEPTH_MULTIPLIER + d, 0, 0, 0, bias0); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); + T_ELTWISE_BROADCAST_ADD_X(ACC_DATA_TYPE, M0, N0, c, bias0, c); #endif // HAS_BIAS T_LOAD_MULTIPLIERS_SHIFT(QUANTIZATION_TYPE); diff --git a/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl b/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl index ed6da9fd12..0883cd99c8 100644 --- a/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl +++ b/src/core/CL/cl_kernels/nhwc/winograd_output_transform.cl @@ -61,7 +61,6 @@ * @param[in] _ISRC_HEIGHT The source tensor's height * @param[in] _IDST_WIDTH The destination tensor's width * @param[in] _IDST_HEIGHT The destination tensor's height - * @param[in] _INUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_2x2_7x7_nhwc( TENSOR4D(src, BUFFER), @@ -72,15 +71,14 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc( int dst_size, const int _ISRC_HEIGHT, const int _IDST_WIDTH, - const int _IDST_HEIGHT, - const int _INUM_TILES_X) + const int _IDST_HEIGHT) { const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES const int bout = GET_SPATIAL_IDX(2, 1, 0); // BATCH SIZE IDX - int x_out = (mout % _INUM_TILES_X) * OUTPUT_TILE_W; - int y_out = (mout / _INUM_TILES_X) * OUTPUT_TILE_H; + int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W; + int y_out = (mout / NUM_TILES_X) * OUTPUT_TILE_H; #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) TILE(DATA_TYPE, 8, N0, in); @@ -113,7 +111,7 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); - T_ADD_BROADCAST_X(DATA_TYPE, 2, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 2, N0, out, b, out); #endif // defined(HAS_BIAS) T_ACTIVATION(DATA_TYPE, 2, N0, ACTIVATION_TYPE, A_VAL, B_VAL, out, out); @@ -179,7 +177,7 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); - T_ADD_BROADCAST_X(DATA_TYPE, 4, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 4, N0, out, b, out); #endif // defined(HAS_BIAS) T_ACTIVATION(DATA_TYPE, 4, N0, ACTIVATION_TYPE, A_VAL, B_VAL, out, out); @@ -240,7 +238,6 @@ __kernel void winograd_output_transform_2x2_7x7_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_4x4_3x3_nhwc( TENSOR4D(src, BUFFER), @@ -251,8 +248,7 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES @@ -291,7 +287,7 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(DATA_TYPE, 4, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 4, N0, out, b, out); #endif // HAS_BIAS int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W; @@ -378,7 +374,7 @@ __kernel void 
winograd_output_transform_4x4_3x3_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(DATA_TYPE, 16, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 16, N0, out, b, out); #endif // HAS_BIAS int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W; @@ -439,7 +435,6 @@ __kernel void winograd_output_transform_4x4_3x3_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_4x4_5x5_nhwc( TENSOR4D(src, BUFFER), @@ -450,8 +445,7 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { const int cout = GET_SPATIAL_IDX(0, N0, 0); // OFM const int mout = GET_SPATIAL_IDX(1, 1, 0); // WINOGRAD OUTPUT TILES @@ -494,7 +488,7 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(DATA_TYPE, 4, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 4, N0, out, b, out); #endif // HAS_BIAS int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W; @@ -592,7 +586,7 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc( T_LOAD(DATA_TYPE, 1, N0, BUFFER, bias, cout, 0, 1, 0, b); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X(DATA_TYPE, 16, N0, out, b, out); + T_ELTWISE_BROADCAST_ADD_X(DATA_TYPE, 16, N0, out, b, out); #endif // HAS_BIAS int x_out = (mout % NUM_TILES_X) * OUTPUT_TILE_W; @@ -656,7 +650,6 @@ __kernel void winograd_output_transform_4x4_5x5_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_2x1_7x1_nhwc( TENSOR4D_DECLARATION(src), @@ -667,8 +660,7 @@ __kernel void winograd_output_transform_2x1_7x1_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_2x2_7x7_nhwc(src_ptr, src_stride_x, @@ -699,8 +691,7 @@ __kernel void winograd_output_transform_2x1_7x1_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_2X1_7X1_NHWC) #endif // defined(VEC_SIZE) && VEC_SIZE == 2 @@ -739,7 +730,6 @@ __kernel void winograd_output_transform_2x1_7x1_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_4x1_3x1_nhwc( TENSOR4D_DECLARATION(src), @@ -750,8 +740,7 @@ __kernel void winograd_output_transform_4x1_3x1_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_4x4_3x3_nhwc(src_ptr, src_stride_x, @@ -782,8 +771,7 @@ __kernel void winograd_output_transform_4x1_3x1_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_4X1_3X1_NHWC) @@ -820,7 +808,6 @@ __kernel void 
winograd_output_transform_4x1_3x1_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_4x1_5x1_nhwc( TENSOR4D_DECLARATION(src), @@ -831,8 +818,7 @@ __kernel void winograd_output_transform_4x1_5x1_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_4x4_5x5_nhwc(src_ptr, src_stride_x, @@ -863,8 +849,7 @@ __kernel void winograd_output_transform_4x1_5x1_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_4X1_5X1_NHWC) #endif // defined(VEC_SIZE) && VEC_SIZE == 4 @@ -905,7 +890,6 @@ __kernel void winograd_output_transform_4x1_5x1_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_1x2_1x7_nhwc( TENSOR4D_DECLARATION(src), @@ -916,8 +900,7 @@ __kernel void winograd_output_transform_1x2_1x7_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_2x2_7x7_nhwc(src_ptr, src_stride_x, @@ -948,8 +931,7 @@ __kernel void winograd_output_transform_1x2_1x7_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_1X2_1X7_NHWC) #endif // defined(VEC_SIZE) && VEC_SIZE == 2 @@ -988,7 +970,6 @@ __kernel void winograd_output_transform_1x2_1x7_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_1x4_1x3_nhwc( TENSOR4D_DECLARATION(src), @@ -999,8 +980,7 @@ __kernel void winograd_output_transform_1x4_1x3_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_4x4_3x3_nhwc(src_ptr, src_stride_x, @@ -1031,8 +1011,7 @@ __kernel void winograd_output_transform_1x4_1x3_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_1X4_1X3_NHWC) @@ -1069,7 +1048,6 @@ __kernel void winograd_output_transform_1x4_1x3_nhwc( * @param[in] SRC_HEIGHT The source tensor's height * @param[in] DST_WIDTH The destination tensor's width * @param[in] DST_HEIGHT The destination tensor's height - * @param[in] NUM_TILES_X The number of tiles along the X direction */ __kernel void winograd_output_transform_1x4_1x5_nhwc( TENSOR4D_DECLARATION(src), @@ -1080,8 +1058,7 @@ __kernel void winograd_output_transform_1x4_1x5_nhwc( int dst_size, const int SRC_HEIGHT, const int DST_WIDTH, - const int DST_HEIGHT, - const int NUM_TILES_X) + const int DST_HEIGHT) { winograd_output_transform_4x4_5x5_nhwc(src_ptr, src_stride_x, @@ -1112,8 +1089,7 @@ __kernel void winograd_output_transform_1x4_1x5_nhwc( dst_size, SRC_HEIGHT, DST_WIDTH, - DST_HEIGHT, - NUM_TILES_X); + DST_HEIGHT); } #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_1X4_1X5_NHWC) #endif // 
defined(VEC_SIZE) && VEC_SIZE == 4
diff --git a/src/core/CL/cl_kernels/tile_helpers.h b/src/core/CL/cl_kernels/tile_helpers.h
index ec57022f63..4b6144a22d 100644
--- a/src/core/CL/cl_kernels/tile_helpers.h
+++ b/src/core/CL/cl_kernels/tile_helpers.h
@@ -333,7 +333,11 @@
 ({ \
 c += (C_DATA_TYPE)(a) * (C_DATA_TYPE)(b); \
 })
-#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
+#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_khr_integer_dot_product)
+#define DOT_PRODUCT2_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c += dot((A_DATA_TYPE##4)((a).s01, (A_DATA_TYPE##2)(0)), (B_DATA_TYPE##4)(((b).s01), (B_DATA_TYPE##2)(0)));
+#define DOT_PRODUCT3_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c += dot((A_DATA_TYPE##4)((a).s012, (A_DATA_TYPE)0), (B_DATA_TYPE##4)(((b).s012), (B_DATA_TYPE)0));
+#define DOT_PRODUCT4_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c += dot((a), (b));
+#elif defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_khr_integer_dot_product)
 #define DOT_PRODUCT2_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c = arm_dot_acc((A_DATA_TYPE##4)((a).s01, (A_DATA_TYPE##2)(0)), (B_DATA_TYPE##4)(((b).s01), (B_DATA_TYPE##2)(0)), (c));
 #define DOT_PRODUCT3_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c = arm_dot_acc((A_DATA_TYPE##4)((a).s012, (A_DATA_TYPE)0), (B_DATA_TYPE##4)(((b).s012), (B_DATA_TYPE)0), (c));
 #define DOT_PRODUCT4_INTEGER8(A_DATA_TYPE, B_DATA_TYPE, C_DATA_TYPE, a, b, c) c = arm_dot_acc((a), (b), (c));
@@ -966,6 +970,9 @@
 #define ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) op##_op_quantized(DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x)
 #define ACTIVATION_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x) ACT_OP_QUANTIZED(op, DATA_TYPE, VEC_SIZE, ZERO_VALUE, A_VAL, B_VAL, x)
 
+#define V_ADD(A_VAL, B_VAL) ((A_VAL) + (B_VAL))
+#define V_DIV(A_VAL, B_VAL) ((A_VAL) / (B_VAL))
+
 /** Element-wise activation for quantized types
 *
 * @note Performs: activation(LHS) = DST
 }) \
 })
 
+/** Element-wise addition between two tiles
+ *
+ * @note Performs: LHS + RHS = DST
+ *
+ * @param[in] DATA_TYPE LHS/RHS/DST data type
+ * @param[in] M0 Number of LHS rows
+ * @param[in] N0 Number of LHS columns
+ * @param[in] lhs LHS tile
+ * @param[in] rhs RHS tile
+ * @param[out] dst DST tile
+ */
+#define T_ADD(DATA_TYPE, M0, N0, lhs, rhs, dst) \
+    ({ \
+        LOOP_UNROLLING(int, _m0, 0, 1, M0, \
+        { \
+            dst[_m0].v = lhs[_m0].v + rhs[_m0].v; \
+        }) \
+    })
+
 /** Element-wise addition with a constant value
 *
 * @note Performs: LHS + constant = DST
 ({ \
 LOOP_UNROLLING(int, _m0, 0, 1, M0, \
 { \
-        LOOP_UNROLLING(int, _n0, 0, 1, N0, \
-        { \
-            dst[_m0].s[_n0] = lhs[_m0].s[_n0] + rhs_constant; \
-        }) \
+        dst[_m0].v = lhs[_m0].v + (DATA_TYPE)rhs_constant; \
 }) \
 })
 
-/** Element-wise addition with RHS broadcasted (RHS has the X dimension only)
+#define T_ELTWISE_BROADCAST_ADD_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_BROADCAST_DIV_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE_BROADCAST_X(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+
+/** Element-wise scale with a constant value
+ *
+ * @note Performs: LHS * constant = DST
 *
- * @note Performs: LHS + RHS[broadcasted] = DST
+ * @param[in] DATA_TYPE LHS/RHS/DST data type
+ * @param[in] M0 Number of LHS rows
+ * @param[in] N0 Number of LHS columns
+ * @param[in] lhs LHS tile
+ * @param[in] rhs_constant Constant value
+ * @param[out] dst DST tile
+ */
+#define T_SCALE_CONSTANT(DATA_TYPE, M0, N0, lhs, rhs_constant, dst) \
+    ({ \
+        LOOP_UNROLLING(int, _m0, 0, 1, M0, \
+        { \
+            dst[_m0].v = lhs[_m0].v * (DATA_TYPE)rhs_constant; \
+        }) \
+    })
+
+/** Element-wise operation with RHS broadcasted (RHS has the X dimension only)
+ *
+ * @note Performs: LHS OP RHS[broadcasted] = DST
 * @note Both tiles must have same data type
 *
+ * @param[in] T_ELWISE_OP Elementwise operator to perform
 * @param[in] DST_DATA_TYPE DST data type
 * @param[in] M0 Number of LHS rows
 * @param[in] N0 Number of LHS columns
@@ -1022,19 +1068,23 @@
 * @param[in] rhs RHS tile
 * @param[out] dst DST tile
 */
-#define T_ADD_BROADCAST_X(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) \
+#define T_ELTWISE_BROADCAST_X(T_ELWISE_OP, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) \
 ({ \
 LOOP_UNROLLING(int, _m0, 0, 1, M0, \
 { \
-        dst[_m0].v = CONVERT(lhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)) + CONVERT(rhs[0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)); \
+        dst[_m0].v = T_ELWISE_OP(CONVERT(lhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)), CONVERT(rhs[0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0))); \
 }) \
 })
 
-/** Element-wise addition between two tiles (LHS and RHS)
+#define T_ELTWISE_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_ADD, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+#define T_ELTWISE_DIV(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) T_ELTWISE(V_DIV, DST_DATA_TYPE, M0, N0, lhs, rhs, dst)
+
+/** Element-wise operation between two tiles (LHS and RHS)
 *
- * @note Performs: LHS + RHS = DST
+ * @note Performs: LHS OP RHS = DST
 * @note Both tiles must have same data type
 *
+ * @param[in] T_ELWISE_OP Elementwise operator to perform
 * @param[in] DST_DATA_TYPE DST data type
 * @param[in] M0 Number of LHS rows
 * @param[in] N0 Number of LHS columns
@@ -1042,11 +1092,30 @@
 * @param[in] rhs RHS tile
 * @param[out] dst DST tile
 */
-#define T_ADD(DST_DATA_TYPE, M0, N0, lhs, rhs, dst) \
+#define T_ELTWISE(T_ELWISE_OP, DST_DATA_TYPE, M0, N0, lhs, rhs, dst) \
+    ({ \
+        LOOP_UNROLLING(int, _m0, 0, 1, M0, \
+        { \
+            dst[_m0].v = T_ELWISE_OP(CONVERT(lhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)), CONVERT(rhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0))); \
+        }) \
+    })
+
+/** Floor operation on a tile
+ *
+ * @note Performs: floor(SRC) = DST
+ * @note Both tiles must have same data type
+ *
+ * @param[in] DST_DATA_TYPE DST data type
+ * @param[in] M0 Number of SRC rows
+ * @param[in] N0 Number of SRC columns
+ * @param[in] src SRC tile
+ * @param[out] dst DST tile
+ */
+#define T_FLOOR(DST_DATA_TYPE, M0, N0, src, dst) \
 ({ \
 LOOP_UNROLLING(int, _m0, 0, 1, M0, \
 { \
-        dst[_m0].v = CONVERT(lhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)) + CONVERT(rhs[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0)); \
+        dst[_m0].v = floor(CONVERT(src[_m0].v, VEC_DATA_TYPE(DST_DATA_TYPE, N0))); \
 }) \
 })
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
index 61c8d90f78..d1f0338739 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayerNativeKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
* * SPDX-License-Identifier: MIT * @@ -211,8 +211,25 @@ void CLDepthwiseConvolutionLayerNativeKernel::configure(const CLCompileContext & arm_compute::opencl::kernels::gemm::update_padding_for_cl_image(weights->info()); } - build_opts.add_option("-cl-fast-relaxed-math"); - build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(conv_info.act_info.activation()))); + // Conditions of -cl-fast-relaxed-math causing accuracy issues can be traced from COMPMID-5324 + const GPUTarget gpu_target = get_target(); + const auto act_function = conv_info.act_info.activation(); + const auto dst_data_type = _output->info()->data_type(); + + if((gpu_target != GPUTarget::G71 && (gpu_target & GPUTarget::GPU_ARCH_MASK) == GPUTarget::BIFROST) + && (act_function == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU || act_function == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) + && (dst_data_type == DataType::F32 || dst_data_type == DataType::F16)) + { + // -cl-fast-relaxed-math also sets -cl-finite-math-only and -cl-unsafe-math-optimizations + // to disable -cl-finite-math-only, we only include -cl-unsafe-math-optimizations + build_opts.add_option("-cl-unsafe-math-optimizations"); + } + else + { + build_opts.add_option("-cl-fast-relaxed-math"); + } + + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_function))); build_opts.add_option("-DDEPTH_MULTIPLIER=" + support::cpp11::to_string(conv_info.depth_multiplier)); build_opts.add_option("-DSRC_TENSOR_TYPE=BUFFER"); // Note: SRC_DATA_TYPE must have the same data type of WEI_DATA_TYPE. In quantized, we could @@ -220,7 +237,7 @@ void CLDepthwiseConvolutionLayerNativeKernel::configure(const CLCompileContext & // only works when both have same data type, we have to change the offset to take into account this aspect build_opts.add_option("-DSRC_DATA_TYPE=" + get_cl_type_from_data_type(_input->info()->data_type())); build_opts.add_option("-DDST_TENSOR_TYPE=BUFFER"); - build_opts.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(_output->info()->data_type())); + build_opts.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst_data_type)); build_opts.add_option_if_else(_export_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER"); build_opts.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(_weights->info()->dimension(1))); build_opts.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(_weights->info()->dimension(2))); diff --git a/src/core/CPP/CPPTypes.cpp b/src/core/CPP/CPPTypes.cpp index c197932a13..bd5236fcf8 100644 --- a/src/core/CPP/CPPTypes.cpp +++ b/src/core/CPP/CPPTypes.cpp @@ -101,6 +101,16 @@ bool CPUInfo::has_sve2() const return _impl->info.has_sve2(); } +bool CPUInfo::has_sme() const +{ + return false; +} + +bool CPUInfo::has_sme2() const +{ + return false; +} + CPUModel CPUInfo::get_cpu_model() const { return _impl->info.cpu_model(); diff --git a/src/core/GPUTarget.cpp b/src/core/GPUTarget.cpp index 625c0145df..292acf8633 100644 --- a/src/core/GPUTarget.cpp +++ b/src/core/GPUTarget.cpp @@ -35,6 +35,18 @@ arm_compute::GPUTarget get_valhall_target(const std::string &version) { return arm_compute::GPUTarget::G77; } + else if(version.find("G57") != std::string::npos) + { + return arm_compute::GPUTarget::G57; + } + if(version.find("G68") != std::string::npos) + { + return arm_compute::GPUTarget::G68; + } + if(version.find("G78AE") != std::string::npos) + { + return arm_compute::GPUTarget::G78AE; + } if(version.find("G78") != 
std::string::npos) { return arm_compute::GPUTarget::G78; @@ -43,6 +55,26 @@ arm_compute::GPUTarget get_valhall_target(const std::string &version) { return arm_compute::GPUTarget::G710; } + else if(version.find("G610") != std::string::npos) + { + return arm_compute::GPUTarget::G610; + } + else if(version.find("G510") != std::string::npos) + { + return arm_compute::GPUTarget::G510; + } + else if(version.find("G310") != std::string::npos) + { + return arm_compute::GPUTarget::G310; + } + else if(version.find("G715") != std::string::npos) + { + return arm_compute::GPUTarget::G715; + } + else if(version.find("G615") != std::string::npos) + { + return arm_compute::GPUTarget::G615; + } else { return arm_compute::GPUTarget::UNKNOWN; @@ -131,12 +163,21 @@ const std::string &string_from_target(GPUTarget target) { GPUTarget::G51, "g51" }, { GPUTarget::G51BIG, "g51big" }, { GPUTarget::G51LIT, "g51lit" }, + { GPUTarget::G31, "g31" }, + { GPUTarget::G76, "g76" }, { GPUTarget::G52, "g52" }, { GPUTarget::G52LIT, "g52lit" }, - { GPUTarget::G76, "g76" }, { GPUTarget::G77, "g77" }, + { GPUTarget::G57, "g57" }, { GPUTarget::G78, "g78" }, - { GPUTarget::G710, "g710" } + { GPUTarget::G68, "g68" }, + { GPUTarget::G78AE, "g78ae" }, + { GPUTarget::G710, "g710" }, + { GPUTarget::G610, "g610" }, + { GPUTarget::G510, "g510" }, + { GPUTarget::G310, "g310" }, + { GPUTarget::G715, "g715" }, + { GPUTarget::G615, "g615" }, }; return gpu_target_map[target]; diff --git a/src/core/NEON/kernels/NEGatherKernel.cpp b/src/core/NEON/kernels/NEGatherKernel.cpp index 7090da8015..085ab7cb18 100644 --- a/src/core/NEON/kernels/NEGatherKernel.cpp +++ b/src/core/NEON/kernels/NEGatherKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,19 +44,23 @@ namespace * * @param[in] indices Indices tensor info. 
*/
+template <typename U>
 void validate_indices(const ITensor *indices)
 {
-    for(size_t i = 0; i < indices->info()->tensor_shape()[0]; ++i)
+    Window window;
+    window.use_tensor_dimensions(indices->info()->tensor_shape());
+    execute_window_loop(window, [&](const Coordinates & id)
     {
-        ARM_COMPUTE_ERROR_ON(*(reinterpret_cast<int32_t *>(indices->ptr_to_element(Coordinates(i)))) < 0);
-    }
+        const auto i = *(reinterpret_cast<U *>(indices->ptr_to_element(id)));
+        ARM_COMPUTE_UNUSED(i);
+        ARM_COMPUTE_ERROR_ON(i < 0);
+    });
 }
 
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *indices, const ITensorInfo *output, int axis)
 {
 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, indices, output);
-    ARM_COMPUTE_RETURN_ERROR_ON(indices->num_dimensions() > 1);
 ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4);
 
 if(axis < 0)
@@ -65,6 +69,7 @@
 }
 ARM_COMPUTE_RETURN_ERROR_ON(0 > axis || axis >= static_cast<int32_t>(input->num_dimensions()));
+    ARM_COMPUTE_RETURN_ERROR_ON(axis != 1 && indices->num_dimensions() > 1);
 ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
 
 if(output->total_size() != 0)
@@ -86,6 +91,37 @@ NEGatherKernel::NEGatherKernel()
 {
 }
 
+template <typename U>
+inline void NEGatherKernel::gather_multiindices_1_axis(const Window &window, const ThreadInfo &info)
+{
+    ARM_COMPUTE_UNUSED(info);
+    ARM_COMPUTE_ERROR_ON(_indices->info()->num_dimensions() < 2 || _indices->info()->num_dimensions() > 3);
+    validate_indices<U>(_indices);
+    Window win = window;
+    win.set(Window::DimX, Window::Dimension(0, 1, 1));
+    execute_window_loop(win, [&](const Coordinates & id)
+    {
+        auto *dst_ptr = _output->ptr_to_element(id);
+        Coordinates index_offset;
+        for(uint32_t k = 0; k < _indices->info()->num_dimensions(); ++k)
+        {
+            index_offset.set(k, id[k + 1]);
+        }
+        const uint32_t row = *(reinterpret_cast<U *>(_indices->ptr_to_element(index_offset)));
+        Coordinates src_offset;
+        // Set up input coords to read the row specified by the current index
+        src_offset.set(0, 0);
+        src_offset.set(1, row);
+        for(uint32_t j = 2; j < _input->info()->num_dimensions(); ++j)
+        {
+            src_offset.set(j, id[1 + _indices->info()->num_dimensions() + (j - 2)]);
+        }
+        const auto in_ptr_row = _input->ptr_to_element(src_offset);
+        // Copy a row from input to output
+        memcpy(dst_ptr, in_ptr_row, _input->info()->tensor_shape()[0] * _input->info()->element_size());
+    });
+}
+
 template <typename U>
 inline void NEGatherKernel::gather_0_axis(const Window &window, const ThreadInfo &info)
 {
@@ -147,38 +183,64 @@ void NEGatherKernel::configure(const ITensor *input, const ITensor *indices, ITe
 }
 ARM_COMPUTE_ERROR_ON(0 > _axis || _axis >= static_cast<int32_t>(input->info()->num_dimensions()));
 
-    if(0 == _axis)
+    if(indices->info()->num_dimensions() == 1u)
 {
-        switch(_indices->info()->data_type())
+        if(_axis == 0)
 {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_0_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_0_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_0_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_0_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
+        {
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_n_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_n_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
 }
 else
 {
-        switch(_indices->info()->data_type())
+        if(_axis == 1)
+        {
+            switch(_indices->info()->data_type())
+            {
+                case DataType::U32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<uint32_t>;
+                    break;
+                case DataType::S32:
+                    _func = &NEGatherKernel::gather_multiindices_1_axis<int32_t>;
+                    break;
+                default:
+                    ARM_COMPUTE_ERROR("Not supported");
+                    break;
+            }
+        }
+        else
 {
-            case DataType::U32:
-                _func = &NEGatherKernel::gather_n_axis<uint32_t>;
-                break;
-            case DataType::S32:
-                _func = &NEGatherKernel::gather_n_axis<int32_t>;
-                break;
-            default:
-                ARM_COMPUTE_ERROR("Not supported");
-                break;
+            ARM_COMPUTE_ERROR("Not supported");
 }
 }
+
 // Output auto initialization if not yet initialized
-    TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
+    const TensorShape output_shape = arm_compute::misc::shape_calculator::compute_gather_shape(input->info()->tensor_shape(), indices->info()->tensor_shape(), _axis);
 auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
 
 // Create window
diff --git a/src/core/NEON/kernels/NEGatherKernel.h b/src/core/NEON/kernels/NEGatherKernel.h
index 0711f8190b..3dc0cad7be 100644
--- a/src/core/NEON/kernels/NEGatherKernel.h
+++ b/src/core/NEON/kernels/NEGatherKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -61,17 +61,17 @@ class NEGatherKernel : public INEKernel
 /** Initialise the kernel's inputs and outputs
 *
 * @param[in] input Source tensor. Supported tensor rank: up to 4. Data type supported: All
-    * @param[in] indices Indices tensor. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
+    * @param[in] indices Indices tensor. Supported tensor rank: up to 3. Must be one of the following types: U32/S32. Each value must be in range [0, input.shape[@p axis])
+    * @note 2D or 3D indices are only supported for axis 1.
 * @param[out] output Destination tensor. Data type supported: Same as @p input
-    * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+    * @param[in] axis (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0.
+    *
 */
 void configure(const ITensor *input, const ITensor *indices, ITensor *output, int axis = 0);
-    /** Static function to check if given info will lead to a valid configuration of @ref NEGatherKernel
+
+    /** Static function to check if given info will lead to a valid configuration
 *
-    * @param[in] input   Source tensor info. Supported tensor rank: up to 4. Data type supported: All
-    * @param[in] indices Indices tensor info. Supported tensor rank: up to 1. Must be one of the following type: U32/S32. Each value Must be in range [0, input.shape[@p axis])
-    * @param[in] output  Destination tensor info. Data type supported: Same as @p input
-    * @param[in] axis    (Optional) The axis in @p input to gather @p indices from. Negative values wrap around. Defaults to 0
+    * Similar to @ref NEGatherKernel::configure()
 *
 * @return a status
 */
@@ -85,18 +85,20 @@ class NEGatherKernel : public INEKernel
 *
 * For gather on the 0 axis an element by element copy is performed.
 *
-    * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window())
-    * @param[in] info Info about executing thread and CPU.
+ * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window()) + * @param[in] info Info about running thread and CPU. */ template void gather_0_axis(const Window &window, const ThreadInfo &info); + template + void gather_multiindices_1_axis(const Window &window, const ThreadInfo &info); /** Implementation of the gather operation. * * For 1<=axis a row-wise copy is taking place. * - * @param[in] window Region on which to execute the kernel. (Must be a region of the window returned by window()) - * @param[in] info Info about executing thread and CPU. + * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window()) + * @param[in] info Info about running thread and CPU. */ template void gather_n_axis(const Window &window, const ThreadInfo &info); diff --git a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp index 0665fa3a29..1ee19e5075 100644 --- a/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp +++ b/src/core/NEON/kernels/arm_conv/depthwise/depthwise_implementation.hpp @@ -136,7 +136,14 @@ UniqueDepthwiseCommon depthwise(const DepthwiseArgs &a { const DepthwiseImplementation *impl = nullptr; const bool success = find_implementation(args, os, impl); - return UniqueDepthwiseCommon(success ? impl->get_instance(args, os) : nullptr); + + if(success) + { + auto i = impl->get_instance(args, os); + i->set_name(impl->name); + return UniqueDepthwiseCommon(i); + } + return nullptr; } } // namespace depthwise diff --git a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp index 556ae2a67a..63333c8fb4 100644 --- a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp +++ b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst.hpp @@ -27,8 +27,9 @@ #include "depthfirst_driver.hpp" #include "src/core/NEON/kernels/arm_conv/addressing.hpp" #include "utils.hpp" - +#if !defined(_WIN64) && !defined(__OpenBSD__) #include +#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */ #include namespace arm_conv { diff --git a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp index 227d808e82..65d9a91977 100644 --- a/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp +++ b/src/core/NEON/kernels/arm_conv/pooling/pooling_depthfirst_generic.hpp @@ -27,7 +27,9 @@ #include "arm_compute/core/Error.h" #include "depthfirst_driver.hpp" #include "utils.hpp" +#if !defined(_WIN64) && !defined(__OpenBSD__) #include +#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */ namespace arm_conv { namespace pooling { diff --git a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp index dd72fb5901..58e4861bc0 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp @@ -33,12 +33,21 @@ #include "kernels/a32_sgemm_8x6.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp" +#include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp" +#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp" #include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp" #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp" 
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp" #include "kernels/a64_sgemm_8x12.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp" +#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp" #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp" @@ -80,6 +89,24 @@ GemmImplementation::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_bf16fp32_mmla_8x3VL", + KernelWeightFormat::VL2VL_BL64, + [](const GemmArgs &args) { return args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffhybrid_bf16fp32_mmla_6x4VL", + KernelWeightFormat::VL2VL_BL64, + [](const GemmArgs &args) { return args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, @@ -109,6 +136,32 @@ GemmImplementation::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_bf16fp32_mmla_8x12", + KernelWeightFormat::VL256_BL64, + [](const GemmArgs &args) { return args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffhybrid_bf16fp32_mmla_6x16", + KernelWeightFormat::VL256_BL64, + [](const GemmArgs &args) { return args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_bf16fp32_dot_8x12", + KernelWeightFormat::VL128_BL32, + [](const GemmArgs &args) { return args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS GemmImplementation::with_estimate( GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -144,7 +197,8 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels(const 
GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index 42f4528066..d749dce98d 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -34,9 +34,17 @@ #include "gemm_interleaved.hpp" #include "kernels/a32_sgemm_8x6.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp" +#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hgemm_8x24.hpp" #include "kernels/a64_hybrid_fp16_mla_6x32.hpp" #include "kernels/a64_sgemm_8x12.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp" +#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp" #include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp" @@ -58,6 +66,24 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +GemmImplementation<__fp16, __fp16>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_fp16_mla_8x3VL", + KernelWeightFormat::VL1VL_BL16, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation<__fp16, __fp16>::with_estimate( + GemmMethod::GEMM_HYBRID, + "sve_ffhybrid_fp16_mla_6x4VL", + KernelWeightFormat::VL1VL_BL16, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE #if defined(__aarch64__) GemmImplementation<__fp16, __fp16>::with_estimate( @@ -74,6 +100,24 @@ GemmImplementation<__fp16, __fp16>::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles<__fp16>(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +GemmImplementation<__fp16, __fp16>::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_fp16_mla_8x24", + KernelWeightFormat::VL128_BL16, + [](const GemmArgs &args) { return args._ci->has_fp16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation<__fp16, __fp16>::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp16_mla_6x32", + KernelWeightFormat::VL128_BL16, + [](const GemmArgs &args) { return args._ci->has_fp16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles<__fp16>(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS { GemmMethod::GEMM_INTERLEAVED, "a64_sgemm_8x12", @@ -108,7 +152,8 @@ const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp1 /* Explicitly instantiate the external functions for 
these types. */ template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index 69a2803903..0fc9e8b912 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -31,6 +31,12 @@ #include "gemv_pretransposed.hpp" #include "kernels/a32_sgemm_8x6.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp" +#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp" +#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp" +#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp" #include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp" #include "kernels/a64_hybrid_fp32_mla_4x24.hpp" @@ -42,6 +48,12 @@ #include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp" #include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp" +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp" +#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp" +#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp" +#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp" +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp" #include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp" #include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp" @@ -73,6 +85,7 @@ GemmImplementation::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), + GemmImplementation::with_estimate( GemmMethod::GEMM_HYBRID, "a64_hybrid_fp32bf16fp32_mmla_6x16", @@ -152,6 +165,42 @@ GemmImplementation::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), + #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_BF16 +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_bf16fp32_mmla_8x3VL", + KernelWeightFormat::VL2VL_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "sve_ffhybrid_fp32bf16fp32_mmla_4x6VL", + KernelWeightFormat::VL2VL_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "sve_ffinterleaved_fp32_mla_8x3VL", + 
KernelWeightFormat::VL1VL_BL32, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "sve_ffhybrid_fp32_mla_6x4VL", + KernelWeightFormat::VL1VL_BL32, + [](const GemmArgs &args) { return args._ci->has_sve(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // ARM_COMPUTE_ENABLE_SVE // Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases. { @@ -204,6 +253,43 @@ GemmImplementation::with_estimate( [](const GemmArgs &args) { return GemmInterleaved::estimate_cycles(args); }, [](const GemmArgs &args) { return new GemmInterleaved(args); } ), +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +#ifdef ARM_COMPUTE_ENABLE_BF16 +// "fast mode" (BF16) kernels +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_bf16fp32_mmla_8x12", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32bf16fp32_mmla_4x24", + KernelWeightFormat::VL256_BL64_BF16, + [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); }, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // BF16 +GemmImplementation::with_estimate( + GemmMethod::GEMM_INTERLEAVED, + "a64_ffinterleaved_fp32_mla_8x12", + KernelWeightFormat::VL128_BL32, + nullptr, + [](const GemmArgs &args) { return GemmInterleavedFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmInterleavedFixedFormat(args); } +), +GemmImplementation::with_estimate( + GemmMethod::GEMM_HYBRID, + "a64_ffhybrid_fp32_mla_6x16", + KernelWeightFormat::VL128_BL32, + nullptr, + [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat::estimate_cycles(args); }, + [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat(args); } +), +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS #endif // __aarch64__ #ifdef __arm__ @@ -232,7 +318,8 @@ const GemmImplementation *gemm_implementation_list() /* Explicitly instantiate the external functions for these types. 
*/ template UniqueGemmCommon gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); +template KernelDescription get_gemm_method(const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp index 5b3ef4203d..90e2f07607 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp @@ -33,6 +33,7 @@ #include "arm_gemm.hpp" #include "bias_adder.hpp" #include "convolver.hpp" +#include "kernel_weight_format.hpp" #include "ndrange.hpp" #include "performance_parameters.hpp" #include "transform.hpp" @@ -54,7 +55,7 @@ namespace { // We need to invoke the kernel differently for quantizing and non-quantizing cases, so here is a shim class to do // that. -template +template class run_hybrid_kernel { public: template @@ -63,18 +64,18 @@ class run_hybrid_kernel { profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + unsigned int kern_k, const Tro *b_ptr, size_t b_stride, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, const OutputStage &os, const int32_t *col_bias, unsigned int n_0 ); }; template<> template -inline void run_hybrid_kernel::run( +inline void run_hybrid_kernel::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, const Nothing &, const int32_t *, unsigned int) { #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); @@ -115,12 +116,60 @@ inline void run_hybrid_kernel::run( template<> template -inline void run_hybrid_kernel::run( +inline void run_hybrid_kernel::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg output_arg, const Tr *, Activation, bool, + unsigned int kern_k, const Tro *b_ptr, size_t b_stride, IndirectOutputArg output_arg, const Tr *bias_ptr, Activation act, bool accumulate, + const Nothing &, const int32_t *, unsigned int) { +#ifdef CYCLE_PROFILING + auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); +#endif + UNUSED(kern_k); + + /* Indirect hybrid kernels read the full width of the bias. So we need to detect the case where we are writing + * a partial block and pad the bias for that block. 
*/ + if (bias_ptr && !accumulate && (N % strategy::out_width() != 0)) { + /* Break N into "N_bulk" (a multiple of output width) and "N_remainder" */ + unsigned int N_remainder = N % strategy::out_width(); + unsigned int N_bulk = N - N_remainder; + + /* Output argument to be used for the tail */ + IndirectOutputArg offset_output = output_arg; + + /* If there is a "bulk" to be processed, handle that and update "offset_output" appropriately. */ + if (N_bulk > 0) { + strat.kernel(num_strings, string_ptr, A_arg, M, N_bulk, b_ptr, b_stride, output_arg, bias_ptr, act, accumulate); + + if (output_arg.is_indirect) { + offset_output = IndirectOutputArg(output_arg.indirect.ptr, output_arg.indirect.offset + N_bulk); + } else { + offset_output = IndirectOutputArg(output_arg.direct.base + N_bulk, output_arg.direct.stride); + } + } + + /* Pad the bias buffer for the remainder */ + Tr *bias_pad_buffer = reinterpret_cast(alloca(strategy::out_width() * sizeof(Tr))); + memcpy(bias_pad_buffer, bias_ptr + N_bulk, N_remainder * sizeof(Tr)); + + /* Process the remainder, offsetting the B pointer as needed. */ + strat.kernel(num_strings, string_ptr, A_arg, M, N_remainder, + b_ptr + (N_bulk / strategy::stripe_width()) * b_stride, b_stride, offset_output, + bias_pad_buffer, act, accumulate); + } else { + strat.kernel(num_strings, string_ptr, A_arg, M, N, b_ptr, b_stride, output_arg, bias_ptr, act, accumulate); + } +} + +template<> +template +inline void run_hybrid_kernel::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg output_arg, const Tr *, Activation, bool, const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)M * kern_k * roundup(N, strategy::out_width())); @@ -132,12 +181,12 @@ inline void run_hybrid_kernel::run( template<> template -inline void run_hybrid_kernel::run( +inline void run_hybrid_kernel::run( #ifdef CYCLE_PROFILING profiler &prof, #endif const strategy &strat, unsigned int num_strings, const unsigned int *string_ptr, IndirectInputArg A_arg, unsigned int M, unsigned int N, - unsigned int kern_k, const Tro *b_ptr, IndirectOutputArg output_arg, const Tr *, Activation, bool, + unsigned int kern_k, const Tro *b_ptr, size_t, IndirectOutputArg output_arg, const Tr *, Activation, bool, const Requantize32 &os, const int32_t *col_bias, unsigned int n_0 ) { UNUSED(kern_k); // On this route we will only process one kernel height at a time and will make sure this happens in the driver loop. @@ -180,10 +229,38 @@ inline void run_hybrid_kernel::run( } } +template +struct stripe_width { + static unsigned int get() { + return strategy::stripe_width(); + } +}; + +template +struct stripe_width { + static unsigned int get() { + return 0; + } +}; + +template +struct kernel_weight_format { + static KernelWeightFormat get() { + return strategy::kernel_weight_format(); + } +}; + +template +struct kernel_weight_format { + static KernelWeightFormat get() { + return KernelWeightFormat::NON_FIXED; + } +}; + } // anonymous namespace // Implementation of the GemmCommon abstract class. 
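The remainder handling above splits N into a bulk part that is a multiple of the kernel's output width plus a padded tail, and advances the B pointer one stride per stripe of columns already consumed. A worked instance with assumed numbers (out_width() == 16, stripe_width() == 4, N == 50; bias_ptr, b_ptr and b_stride as in the function above):

#include <cstring>

const unsigned int out_width    = 16;
const unsigned int stripe_width = 4;
const unsigned int N            = 50;
const unsigned int N_remainder  = N % out_width;   // 2
const unsigned int N_bulk       = N - N_remainder; // 48: three full 16-wide blocks

// Tail bias: the kernel always reads a full out_width of bias values, so the
// last 2 live values are copied into a 16-entry buffer; the rest is don't-care
// padding (zeroed here for clarity).
float bias_pad_buffer[16] = {};
std::memcpy(bias_pad_buffer, bias_ptr + N_bulk, N_remainder * sizeof(float));

// Tail B pointer: 48 bulk columns span N_bulk / stripe_width == 12 stripes, so
// the tail call reads from b_ptr + 12 * b_stride, matching the offset above.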
-template +template class GemmHybridIndirect : public GemmCommon { typedef typename strategy::lhs_operand_type Tloi; typedef typename strategy::rhs_operand_type Troi; @@ -373,7 +450,7 @@ class GemmHybridIndirect : public GemmCommon { } /* Make sure we've been set up correctly. */ - assert(_B_transposed); + assert(FixedFormat || _B_transposed); static_assert(std::is_same::value, "gemm_native: Operand types must be the same."); // static_assert(std::is_same::value, "gemm_native: Result types must be the same."); @@ -425,24 +502,32 @@ class GemmHybridIndirect : public GemmCommon { const unsigned int nmax = std::min(n0 + _n_block, _args._Nsize); const unsigned int multi = p.dim(3); - const Troi *b_panel = _B_transposed + - (multi * roundup(_args._Nsize, strategy::out_width()) * _Ktotal) + - (k0 * roundup(_args._Nsize, strategy::out_width())) + - (n0 * kern_k); + const Troi *b_panel; + if (FixedFormat) { + b_panel = reinterpret_cast(this->_Bptr) + + (multi * this->_B_multi_stride) + + ((n0 / stripe_width::get()) * this->_ldb) + + (k0 * stripe_width::get()); + } else { + b_panel = _B_transposed + + (multi * roundup(_args._Nsize, strategy::out_width()) * _Ktotal) + + (k0 * roundup(_args._Nsize, strategy::out_width())) + + (n0 * kern_k); + } - IndirectOutputArg out_arg(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc); + IndirectOutputArg out_arg(this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (m_start * this->_ldc) + n0, this->_ldc); #ifdef CYCLE_PROFILING auto p = prof.ScopedProfiler(PROFILE_KERNEL, (unsigned long)(m_end - m_start) * kern_k * roundup(nmax-n0, strategy::out_width())); #endif if (_indirect_buf) { - run_hybrid_kernel::run( + run_hybrid_kernel::run( #ifdef CYCLE_PROFILING prof, #endif strat, sections, string_lengths.data(), IndirectInputArg(_indirect_buf + (multi * _args._nbatches * _args._Ksections) + (batch * _args._Ksections) + first_section, m_start, first_offset), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -469,13 +554,13 @@ class GemmHybridIndirect : public GemmCommon { } assert(pos == sections); - run_hybrid_kernel::run( + run_hybrid_kernel::run( #ifdef CYCLE_PROFILING prof, #endif strat, sections, string_lengths.data(), IndirectInputArg(in_row_strings.data(), 0, first_offset), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -485,13 +570,13 @@ class GemmHybridIndirect : public GemmCommon { // Length to process. This needs to exclude padding, but 'kmax' potentially includes it. const unsigned int len = (std::min(_args._Ksize, kmax) - k0); - run_hybrid_kernel::run( + run_hybrid_kernel::run( #ifdef CYCLE_PROFILING prof, #endif strat, 1, &len, IndirectInputArg(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + m_start * this->_lda + k0, this->_lda), - (m_end - m_start), (nmax - n0), kern_k, b_panel, out_arg, + (m_end - m_start), (nmax - n0), kern_k, b_panel, this->_ldb, out_arg, (this->_bias && first_pass) ? 
this->_bias + (multi * this->_bias_multi_stride) + n0 : nullptr, last_pass ? _args._act : Activation(), !first_pass, @@ -504,14 +589,18 @@ class GemmHybridIndirect : public GemmCommon { // Interface implementation - pretransposed bool B_is_pretransposed() const override { - return true; + return (FixedFormat == false); } bool B_pretranspose_required() const override { - return (_B_transposed==nullptr); + return (FixedFormat == false) && (_B_transposed==nullptr); } size_t get_B_pretransposed_array_size() const override { + if (FixedFormat) { + return 0; + } + // Start with actual pretransposed buffer... size_t size = roundup(_args._Nsize, strategy::out_width()) * _Ktotal * _args._nmulti * sizeof(Troi); @@ -599,8 +688,7 @@ class GemmHybridIndirect : public GemmCommon { } } } else { - // In the single K section case, can process the whole lot in one go. - // Caution: 'blockwalker::kmax()' rounds up, so clamp to valid _Ksize. + // In the single K section case, can process the whole lot in one go. strat.transforms.PrepareB(buffer, B + (multi * B_multi_stride), ldb, 0, _args._Nsize, k0, std::min(kmax, _args._Ksize)); buffer += roundup(_args._Nsize, strategy::out_width()) * roundup(kmax-k0, strategy::k_unroll()); @@ -694,11 +782,15 @@ class GemmHybridIndirect : public GemmCommon { c.inner_block_size = _k_block; c.outer_block_size = _n_block; c.filter = get_type_name(); + c.weight_format = get_weight_format(kernel_weight_format::get(), sizeof(To)); return c; } }; +template +using GemmHybridIndirectFixedFormat = GemmHybridIndirect; + } // namespace arm_gemm #ifdef __I_DEFINED_UNUSED diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index cb3ff7aa29..19c8fcadd3 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -24,6 +24,8 @@ #include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" + #include #include @@ -37,15 +39,36 @@ template struct GemmImplementation { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function is_supported = {}; std::function cycle_estimate = {}; std::function *(const GemmArgs &, const OutputStage &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { - if (is_supported != nullptr) { - return is_supported(args, os); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args, os)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. 
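The weight-format gate in do_is_supported() above reduces to a small predicate; restated standalone as a sketch ("requested" stands for args._fixed_format, "cfg" for args._cfg, assumed to be a GemmConfig pointer):

bool weight_format_ok(bool requested, KernelWeightFormat kwf,
                      const GemmConfig *cfg, size_t weight_size)
{
    if (!requested)
        return kwf == KernelWeightFormat::NON_FIXED; // only plain kernels qualify
    if (kwf == KernelWeightFormat::NON_FIXED)
        return false;                                // fixed format wanted, kernel can't supply one
    if (cfg == nullptr || cfg->weight_format == WeightFormat::ANY)
        return true;                                 // caller accepts whatever the kernel picks
    // Otherwise the externally visible format must match exactly.
    return cfg->weight_format == get_weight_format(kwf, weight_size);
}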
+ return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -84,6 +107,13 @@ struct GemmImplementation { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function is_supported, std::function is_recommended, + std::function *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -93,15 +123,36 @@ template struct GemmImplementation { const GemmMethod method; const char * name; + const KernelWeightFormat kernel_weight_format = KernelWeightFormat::NON_FIXED; std::function is_supported = {}; std::function cycle_estimate = {}; std::function *(const GemmArgs &)> instantiate = {}; bool do_is_supported(const GemmArgs &args, const Nothing &) const { - if (is_supported != nullptr) { - return is_supported(args); + // Check supplied is_supported() function first. + if (is_supported != nullptr && !is_supported(args)) { + return false; + } + + // Check weight format is appropriate. + if (args._fixed_format == false) { + // Can't return a fixed format kernel if we weren't asked for one. + return (kernel_weight_format == KernelWeightFormat::NON_FIXED); } else { - return true; + // Fixed format kernel requested: if this is a non-fixed format kernel we can't use it. + if (kernel_weight_format == KernelWeightFormat::NON_FIXED) { + return false; + } + + // If there's no config, or the config says ANY then this one is OK. + if (!args._cfg || args._cfg->weight_format == WeightFormat::ANY) { + return true; + } + + // If we get here it means there is a config and it specifies a format. Check it matches this kernel. + // NOTE: this will execute SVE instructions if it's an SVE kernel, so it's important that is_supported() + // was called above first. + return (args._cfg->weight_format == get_weight_format(kernel_weight_format, sizeof(Top))); } } @@ -129,10 +180,22 @@ struct GemmImplementation { return impl; } + static GemmImplementation with_estimate(GemmMethod m, const char *n, KernelWeightFormat f, + std::function is_supported, std::function cycle_estimate, + std::function *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n,f); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + GemmImplementation(const GemmImplementation &) = default; GemmImplementation & operator= (const GemmImplementation &) = default; - GemmImplementation(GemmMethod m, const char * n) : method(m), name(n) {} + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat f=KernelWeightFormat::NON_FIXED) : method(m), name(n), kernel_weight_format(f) {} GemmImplementation(GemmMethod m, const char *n, std::function is_supported, std::function is_recommended, @@ -140,6 +203,13 @@ struct GemmImplementation { method(m), name(n), is_supported(is_supported), cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 
0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), instantiate(instantiate) { } + + GemmImplementation(GemmMethod m, const char *n, KernelWeightFormat kwf, + std::function is_supported, std::function is_recommended, + std::function *(const GemmArgs &)> instantiate) : + method(m), name(n), kernel_weight_format(kwf), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* "Main" function implemented for each valid combination of types. @@ -236,9 +306,12 @@ std::vector get_compatible_kernels(const GemmArgs &args, cons } template -bool has_opt_gemm(const GemmArgs &args, const OutputStage &os) { +bool has_opt_gemm(WeightFormat &wf, const GemmArgs &args, const OutputStage &os) { const GemmImplementation *impl; - return find_implementation(args, os, impl); + const bool success = find_implementation(args, os, impl); + if (success) + wf = UniqueGemmCommon(impl->do_instantiate(args, os))->get_config().weight_format; + return success; } template @@ -252,4 +325,17 @@ UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage &os) { return UniqueGemmCommon(nullptr); } +template +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage &os) { + const GemmImplementation *impl; + + if (find_implementation(args, os, impl)) { + return KernelDescription(impl->method, impl->name); + } + + /* This shouldn't happen - there should always be at least one valid implementation. */ + return KernelDescription(); +} + + } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp index 3915861112..18d8fc9312 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp @@ -56,7 +56,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 0c68e4dd99..24507486ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -159,7 +159,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index c75c320a6b..470cee1557 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. 
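Both entry points instantiated above are queries a caller can make before committing to a GEMM. Illustrative use, assuming an already populated GemmArgs object named args (construction elided):

// Does an optimised kernel exist, and which weight format does it want?
arm_gemm::WeightFormat wf{};
const bool ok = arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(wf, args, arm_gemm::Nothing());

// Which concrete kernel would be selected?
const arm_gemm::KernelDescription kd =
    arm_gemm::get_gemm_method<float, float, arm_gemm::Nothing>(args, arm_gemm::Nothing());
// kd.method / kd.name identify the chosen implementation; a default-constructed
// KernelDescription comes back only if nothing matches.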
* * SPDX-License-Identifier: MIT * @@ -27,7 +27,9 @@ #include #include "arm_gemm.hpp" +#include "bfloat.hpp" #include "convolver.hpp" +#include "kernel_weight_format.hpp" #include "mergeresults.hpp" #include "performance_parameters.hpp" #include "quantized.hpp" @@ -56,7 +58,7 @@ namespace { // Others output directly to the matrix result. This helper class calls the // appropriate functions, using templating to avoid calling non-existent // functions. -template +template class kernel_and_merge { public: template @@ -64,7 +66,7 @@ class kernel_and_merge { #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel, + strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const OutputStage &os, const int32_t *col_bias, @@ -74,11 +76,11 @@ class kernel_and_merge { // Run a kernel and call the separate merge step template<> template -void kernel_and_merge::run( +void kernel_and_merge::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel, + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *) @@ -101,14 +103,44 @@ void kernel_and_merge::run( } } +// Run a fixed-format kernel and call the separate merge step +template<> +template +void kernel_and_merge::run( +#ifdef CYCLE_PROFILING + profiler &prof, +#endif + strategy &strat, const To *a_ptr, const To *b_panel, size_t b_stride, Tri *c_panel, + Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, + unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, + const Activation &act, bool accumulate, const Nothing &, const int32_t *, Tab *) +{ + { +#ifdef CYCLE_PROFILING + const int bblocks = iceildiv(n_max - n_0, strategy::out_width()); + auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k)); +#endif + + strat.kernel(a_ptr, b_panel, b_stride, c_panel, 1, (n_max - n_0), kern_k); + } + + { +#ifdef CYCLE_PROFILING + const int bblocks = iceildiv(n_max - n_0, strategy::out_width()); + auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr))); +#endif + strat.transforms.Merge(c_ptr, c_panel, ldc, m_0, m_max, n_0, n_max, biasptr, act, accumulate); + } +} + // Run a kernel with integrated merge template<> template -void kernel_and_merge::run( +void kernel_and_merge::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, Tri *, + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *biasptr, const Activation &act, bool accumulate, const Nothing &, const int32_t *, @@ -143,11 +175,11 @@ void kernel_and_merge::run( // Run a kernel with integrated merge, quantizing template<> template -void kernel_and_merge::run( +void kernel_and_merge::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, Tri *, + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *, Tr 
*c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *, const Activation &, bool accumulate, const Requantize32 &qp, const int32_t *col_bias, @@ -170,11 +202,11 @@ void kernel_and_merge::run( // Run a kernel and call the separate quantize step template<> template -void kernel_and_merge::run( +void kernel_and_merge::run( #ifdef CYCLE_PROFILING profiler &prof, #endif - strategy &strat, const To *a_ptr, const To *b_panel, Tri *c_panel, + strategy &strat, const To *a_ptr, const To *b_panel, size_t, Tri *c_panel, Tr *c_ptr, int ldc, int kern_k, unsigned int m_0, unsigned int m_max, unsigned int n_0, unsigned int n_max, const Tr *, const Activation &, bool, const Requantize32 &qp, const int32_t *col_bias, @@ -246,9 +278,49 @@ class accumulate_buffer_type { typedef int32_t type; }; +// Stripe width is a concept only needed for FixedFormat kernels. Use an accessor to avoid issues in other scenarios. +template +struct get_stripe_width { + static unsigned int get() { + return 0; + } +}; + +template +struct get_stripe_width { + static unsigned int get() { + return strategy::stripe_width(); + } +}; + +// KernelWeightFormat is a similar story. +template +struct get_kernel_weight_format { + static KernelWeightFormat get() { + return KernelWeightFormat::NON_FIXED; + } +}; + +template +struct get_kernel_weight_format { + static KernelWeightFormat get() { + KernelWeightFormat kwf = strategy::kernel_weight_format(); + + // If we are using a BF16 kernel to do an FP32 problem (fast mode) then we need to set the BF16 flag on the + // weight format. + if (std::is_same::value && std::is_same::value) { + uint32_t kwf_i = static_cast(kwf); + kwf_i |= 0x10; + kwf = static_cast(kwf_i); + } + + return kwf; + } +}; + } // anonymous namespace -template +template class GemmInterleaved : public GemmCommon { typedef typename strategy::operand_type Toi; typedef typename strategy::result_type Tri; @@ -310,7 +382,7 @@ class GemmInterleaved : public GemmCommon { class blockwalker { private: /* Size loops, etc. based on our parent's configuration */ - const GemmInterleaved &_parent; + const GemmInterleaved &_parent; /* K, X and multi parameters for current iteration. */ unsigned int _k0=0, _x0=0, _multi=0; @@ -325,9 +397,9 @@ class GemmInterleaved : public GemmCommon { bool _newmulti=true; public: - blockwalker(const GemmInterleaved &parent) : _parent(parent) { } + blockwalker(const GemmInterleaved &parent) : _parent(parent) { } - blockwalker(const GemmInterleaved &parent, + blockwalker(const GemmInterleaved &parent, unsigned int x_start, unsigned int x_end) : _parent(parent), _x0 (_x_start), _x_start(x_start), _x_end(x_end) { } unsigned int xmax() { @@ -623,7 +695,7 @@ class GemmInterleaved : public GemmCommon { #endif /* Make sure we've been set up correctly. */ - assert(_B_transposed); + assert(FixedFormat || _B_transposed); assert(_working_space); int8_t *working_space_bytes = reinterpret_cast(_working_space); @@ -666,7 +738,11 @@ class GemmInterleaved : public GemmCommon { // Figure out how many "K" the kernel will actually process. unsigned int kern_k = roundup(kmax - k0, strategy::k_unroll()); - const Toi *b_ptr = _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k); + const Toi *b_ptr = FixedFormat ? 
+ reinterpret_cast(this->_Bptr) + (multi * this->_B_multi_stride) + + ((start_x / get_stripe_width::get()) * this->_ldb) + + (k0 * get_stripe_width::get()) : + _B_transposed + (rounded_width * _Ktotal * multi) + (k0 * rounded_width) + (start_x * kern_k); unsigned int batch = batch_0; unsigned int start_row = (start - (batch_0 * window_per_batch)) * strategy::out_height(); @@ -699,12 +775,12 @@ class GemmInterleaved : public GemmCommon { } // Perform the kernel and merge step, either separately or together as required. - kernel_and_merge::run( + kernel_and_merge::run( #ifdef CYCLE_PROFILING prof, #endif // Strategy and panel pointers - strat, a_panel, b_ptr, c_panel, + strat, a_panel, b_ptr, this->_ldb, c_panel, // Result buffer pointers this->_Cptr + (batch * this->_C_batch_stride) + (multi * this->_C_multi_stride), this->_ldc, // K size, and M/N ranges @@ -802,6 +878,13 @@ class GemmInterleaved : public GemmCommon { } } + // For FixedFormat cases, figure out the B pointer. The loop below moves through batches and vertically through the output so this will be the same throughout. + if (FixedFormat) { + b_panel = reinterpret_cast(this->_Bptr) + (current.multi() * this->_B_multi_stride) + + ((current.x0() / get_stripe_width::get()) * this->_ldb) + + (current.k0() * get_stripe_width::get()); + } + /* Do the actual work. */ for (unsigned int batch = batch_0; batch <= batch_end; batch++) { unsigned int first_m = (batch == batch_0) ? m_0 : 0; @@ -840,12 +923,12 @@ class GemmInterleaved : public GemmCommon { } // Perform the kernel and merge step, either separately or together as required. - kernel_and_merge::run( + kernel_and_merge::run( #ifdef CYCLE_PROFILING prof, #endif // Strategy and panel pointers - strat, a_ptr, b_panel, c_panel, + strat, a_ptr, b_panel, this->_ldb, c_panel, // Result buffer pointers result_ptr, this->_ldc, // K size, and M/N ranges @@ -863,7 +946,9 @@ class GemmInterleaved : public GemmCommon { } } - b_panel += (roundup(current.xmax() - current.x0(), strategy::out_width()) * kern_k); + if (FixedFormat == false) { + b_panel += (roundup(current.xmax() - current.x0(), strategy::out_width()) * kern_k); + } } } } @@ -910,14 +995,18 @@ class GemmInterleaved : public GemmCommon { // Interface implementation - pretransposed bool B_is_pretransposed() const override { - return true; + return (FixedFormat == false); } bool B_pretranspose_required() const override { - return (_B_transposed==nullptr); + return (FixedFormat == false) && (_B_transposed==nullptr); } size_t get_B_pretransposed_array_size() const override { + if (FixedFormat) { + return 0; + } + unsigned int x_size = roundup(_Nsize, strategy::out_width()); return (x_size * _Ktotal * _nmulti * sizeof(Toi)) + get_col_sum_size(); @@ -939,7 +1028,7 @@ class GemmInterleaved : public GemmCommon { void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override { requantize_bias(in_buffer, B, ldb, B_multi_stride); - // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 + // Put the transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast(in_buffer); Toi *buffer = reinterpret_cast(buffer_int + get_col_sum_size()); _B_transposed = buffer; @@ -1005,7 +1094,7 @@ class GemmInterleaved : public GemmCommon { } void set_pretransposed_B_data(void *in_buffer) override { - // Put the transposed data after the column sums - in non-transposing cases get_col_sum_size() == 0 + // Put the 
transposed data after the column sums - in non-quantized cases get_col_sum_size() == 0 uintptr_t buffer_int = reinterpret_cast(in_buffer); _B_transposed = reinterpret_cast(buffer_int + get_col_sum_size()); col_bias = reinterpret_cast(in_buffer); @@ -1039,7 +1128,7 @@ class GemmInterleaved : public GemmCommon { uint64_t total_macs = static_cast(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * get_ktotal(args); uint64_t prepare_bytes = static_cast(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * get_ktotal(args) * sizeof(Toi); - uint64_t merge_bytes = static_cast(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr); + uint64_t merge_bytes = static_cast(args._nbatches) * args._nmulti * k_blocks * args._Msize * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr); float mac_cycles = static_cast(total_macs) / params.kernel_macs_cycle; float prepare_cycles = static_cast(prepare_bytes) / params.prepare_bytes_cycle; @@ -1065,6 +1154,7 @@ class GemmInterleaved : public GemmCommon { c.inner_block_size = _k_block; c.outer_block_size = _x_block; c.filter = get_type_name(); + c.weight_format = get_weight_format(get_kernel_weight_format::get(), sizeof(To)); return c; } @@ -1074,6 +1164,9 @@ class GemmInterleaved : public GemmCommon { template using GemmInterleavedNoMerge = GemmInterleaved; +template +using GemmInterleavedFixedFormat = GemmInterleaved; + template using GemmInterleavedPretransposedNoMergeQuantizedInline = GemmInterleaved; diff --git a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp index 6b813c7974..1d7b9c5b73 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_qint8.cpp @@ -230,7 +230,7 @@ const GemmImplementation *gemm_implementation_list } template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp index 95139c2bf6..be7a4ee570 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_quint8.cpp @@ -197,7 +197,7 @@ const GemmImplementation *gemm_implementation_li } template UniqueGemmCommon gemm(const GemmArgs &args, const Requantize32 &os); -template bool has_opt_gemm(const GemmArgs &args, const Requantize32 &os); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os); template std::vector get_compatible_kernels(const GemmArgs &args, const Requantize32 &os); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp index 20cee556f0..fc836f9790 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp @@ -56,7 +56,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template 
std::vector get_compatible_kernels(const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index a2d2cc86f0..03e9cd6c1f 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -157,7 +157,7 @@ const GemmImplementation *gemm_implementation_list gemm(const GemmArgs &args, const Nothing &); -template bool has_opt_gemm(const GemmArgs &args, const Nothing &); +template bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const Nothing &); template std::vector get_compatible_kernels (const GemmArgs &args, const Nothing &); } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernel_weight_format.hpp b/src/core/NEON/kernels/arm_gemm/kernel_weight_format.hpp new file mode 100644 index 0000000000..6b89dd0d73 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernel_weight_format.hpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include "arm_gemm.hpp" + +namespace arm_gemm { + +/* Internal enum to define the weight format a kernel is expecting. + * + * This is distinct from the "external" WeightFormat defined in arm_gemm.hpp primarily to allow for SVE, where + * internally kernels are defined in terms of multiples of the SVE vector length, but externally they are converted + * to a fixed format (based on the VL of the machine we are running on). 
+ * + * Encoded as a bitfield: + * bit 0 : SVE flag + * bit 4 : BF16 convert flag (fast mode) + * bits 11-8 : block length (bytes) + * bits 15-12: vector count + */ +enum class KernelWeightFormat { + NON_FIXED = 0, + VL128_BL16 = 0x1200, + VL128_BL32 = 0x1400, + VL128_BL32_BF16 = 0x1410, + VL128_BL64 = 0x1800, + VL256_BL64 = 0x2800, + VL256_BL64_BF16 = 0x2810, + VL1VL_BL16 = 0x1201, + VL1VL_BL32 = 0x1401, + VL1VL_BL32_BF16 = 0x1411, + VL1VL_BL64 = 0x1801, + VL2VL_BL64 = 0x2801, + VL2VL_BL64_BF16 = 0x2811 +}; + +WeightFormat get_weight_format(const KernelWeightFormat, size_t); + +} // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp new file mode 100644 index 0000000000..9a871d4b88 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
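Following the encoding described above, the enumerator values decode mechanically; a sketch:

#include <cstdint>

struct DecodedKWF {
    bool     sve;           // bit 0: kernel counts in SVE vector lengths
    bool     bf16_convert;  // bit 4: weights converted to BF16 (fast mode)
    unsigned block_bytes;   // bits 11-8: interleave block length in bytes
    unsigned vector_count;  // bits 15-12: vectors per stripe
};

DecodedKWF decode(KernelWeightFormat f)
{
    const uint32_t v = static_cast<uint32_t>(f);
    return { (v & 0x1u) != 0, (v & 0x10u) != 0, (v >> 8) & 0xFu, (v >> 12) & 0xFu };
}

// e.g. VL256_BL64_BF16 == 0x2810 decodes to: two fixed 128-bit vectors (VL256),
// 8-byte (64-bit) blocks, BF16 flag set, SVE flag clear. VL2VL_BL64 == 0x2801
// is the SVE analogue: two vectors of hardware VL, 64-bit blocks.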
+ */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg, \ + size_t, size_t, \ + const bfloat16 *, \ + size_t, \ + IndirectOutputArg, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffhybrid_bf16fp32_mmla_6x16( ARGLIST ); + +class cls_a64_ffhybrid_bf16fp32_mmla_6x16 +{ +public: + typedef bfloat16 lhs_operand_type; + typedef bfloat16 rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL256_BL64; + } + + static unsigned int out_width() + { + return 16; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsFixed transforms = {}; + template + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same::value) { + switch (ci->get_cpu_model()) { + default: + return { 37.09 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffhybrid_bf16fp32_mmla_6x16; + cls_a64_ffhybrid_bf16fp32_mmla_6x16(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp new file mode 100644 index 0000000000..ec93586f57 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_bf16fp32_mmla_6x16/generic.cpp @@ -0,0 +1,3807 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
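The strategy class above fixes the kernel's tile geometry: each call produces a 6 x 16 block of fp32 results from bf16 operands, consuming K in multiples of k_unroll() == 4, with fixed-format B held in stripes of stripe_width() == 4. Illustrative blocking arithmetic for assumed problem sizes:

const unsigned int M = 100, N = 50, K = 70;
const unsigned int row_blocks = (M + 6 - 1) / 6;        // 17 passes over M
const unsigned int col_blocks = (N + 16 - 1) / 16;      // 4 passes over N (last one partial)
const unsigned int kern_k     = ((K + 4 - 1) / 4) * 4;  // 72: K rounded up to k_unroll()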
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+#include "../../utils.hpp"
+#include "../../bfloat.hpp"
+
+#include <cassert>
+#include <limits>
+
+namespace arm_gemm {
+
+void a64_ffhybrid_bf16fp32_mmla_6x16 (
+    unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<bfloat16> A_arg,
+    size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg,
+    const float *bias, Activation act, bool accumulate
+)
+{
+    struct KernelArgs {
+        float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+        float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+        unsigned int num_strings = {};
+        const unsigned int *string_lengths = {};
+        size_t N = {};
+        const bfloat16 *B_ptr = {};
+        const bfloat16 *cur_B_ptr = {};
+        size_t B_stride = {};
+        size_t output_offset = {};
+        size_t input_initial_col = {};
+        size_t input_offset = {};
+    } ka;
+
+    unsigned long flags=0;
+    void *output_ptr;
+    void *input_ptr;
+
+    if (output_arg.is_indirect) {
+        output_ptr=(void *)(output_arg.indirect.ptr);
+        ka.output_offset=output_arg.indirect.offset;
+        flags |= 0x4;
+    } else {
+        output_ptr=(void *)(output_arg.direct.base);
+        ka.output_offset=output_arg.direct.stride;
+    }
+
+    if (A_arg.is_indirect) {
+        input_ptr=(void *)(A_arg.indirect.ptr);
+        ka.input_offset=A_arg.indirect.start_row;
+        ka.input_initial_col=A_arg.indirect.start_col;
+        flags |= 0x8;
+    } else {
+        assert(num_strings==1);
+        input_ptr=(void *)(A_arg.direct.base);
+        ka.input_offset=A_arg.direct.stride;
+    }
+    if (accumulate) {
+        flags |= 0x1;
+    }
+    ka.num_strings = num_strings;
+    ka.string_lengths = string_lengths;
+    ka.N = N;
+    ka.B_ptr = B_ptr;
+    ka.B_stride = B_stride;
+    switch(act.type) {
+        default:
+        case Activation::Type::None:
+            break;
+        case Activation::Type::BoundedReLU:
+            ka.maxval = static_cast<float>(act.param1);
+            /* fall through */
+        case Activation::Type::ReLU:
+            ka.minval = 0;
+            flags |= 0x2;
+            break;
+    }
+    __asm__ __volatile__(
+      "1:"  // Row loop
+      "cmp %x[M], #0x6\n"
+      "bge 191f\n"
+      "cmp %x[M], #0x4\n"
+      "bgt 153f\n"
+      "beq 115f\n"
+      "cmp %x[M], #0x2\n"
+      "bgt 77f\n"
+      "beq 39f\n"
+      "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+      "mov x14, %x[bias]\n"
+      "ldr x13, [%x[args_ptr], %[offsetof_N]]\n"
+      "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "mov x12, %x[output_ptr]\n"
+      "2:"  // Height 1: Column loop
+      "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n"
+      "add x10, x11, x19, LSL #1\n"
+      "add x9, x10, x19, LSL #1\n"
+      "add x28, x9, x19, LSL #1\n"
+      "add x19, x28, x19, LSL #1\n"
+      "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "cmp x13, #0xc\n"
+      "bgt 3f\n"
+      "cmp x13, #0x8\n"
+      "mov x28, x11\n"
+      "bgt 3f\n"
+      "cmp x13, #0x4\n"
+      "mov x9, x11\n"
+      "bgt 3f\n"
+      "mov x10, x11\n"
+      "3:"  // Height 1: B setup done
+      "cbz x14, 4f\n"
+      "ldr q8, [x14, #0x0]\n"
+      "ldr q9, [x14, #0x10]\n"
+      "zip2 v12.2d, v8.2d, v8.2d\n"
+      "zip1 v8.2d, v8.2d, v8.2d\n"
+      "ldr q10, [x14, #0x20]\n"
+      "ldr q11, [x14, #0x30]\n"
+      "zip2 v13.2d, v9.2d, v9.2d\n"
+      "zip1 v9.2d, v9.2d, v9.2d\n"
+      "zip2 v14.2d, v10.2d, v10.2d\n"
+      "zip1 v10.2d, v10.2d, v10.2d\n"
+      "add x14, x14, #0x40\n"
+      "zip2 v15.2d, v11.2d, v11.2d\n"
+      "zip1 v11.2d, v11.2d, v11.2d\n"
+      "b 16f\n"
+      "4:"  // Height 1: no bias
+      "tbz %x[flags], #0, 15f\n"
+      "cmp x13, #0x10\n"
+      "bge 13f\n"
+      "tbz x13, #3, 8f\n"
+      "ld1 { v9.4s }, [x12], #0x10\n"
+      "ld1 { v10.4s }, [x12], #0x10\n"
+      "tbz x13, #2, 6f\n"
+      "ld1 { v11.4s }, [x12], #0x10\n"
+      "tbz x13, #1, 5f\n"
+      "ldr d16, [x12], #0x8\n"
+      "mov x19, #0x38\n"
+      "tbz x13, #0, 12f\n"
+      "ld1 { 
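The 'flags' word assembled in the prologue above is a small bitfield tested by the tbz instructions throughout the assembly; a sketch of its inferred layout (names are illustrative):

    // Inferred from the prologue: each bit gates one tbz/tbnz test in the asm.
    enum KernelFlags : unsigned long {
        FLAG_ACCUMULATE      = 0x1, // bit 0: reuse existing C ("tbz %x[flags], #0, ...")
        FLAG_CLAMP           = 0x2, // bit 1: apply min/max activation ("tbz %x[flags], #1, ...")
        FLAG_INDIRECT_OUTPUT = 0x4, // bit 2: output_arg carries pointer + offset
        FLAG_INDIRECT_INPUT  = 0x8, // bit 3: A_arg is row-pointer based ("tbz %x[flags], #3, ...")
    };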
v16.s }[2], [x12]\n" + "b 12f\n" + "5:" // Height 1: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 12f\n" + "ldr s16, [x12, #0x0]\n" + "b 12f\n" + "6:" // Height 1: Partial accumulate: partial_2_8 + "tbz x13, #1, 7f\n" + "ldr d11, [x12], #0x8\n" + "mov x19, #0x28\n" + "tbz x13, #0, 12f\n" + "ld1 { v11.s }[2], [x12]\n" + "b 12f\n" + "7:" // Height 1: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 12f\n" + "ldr s11, [x12, #0x0]\n" + "b 12f\n" + "8:" // Height 1: Partial accumulate: partial_4_0 + "tbz x13, #2, 10f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "tbz x13, #1, 9f\n" + "ldr d10, [x12], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 12f\n" + "ld1 { v10.s }[2], [x12]\n" + "b 12f\n" + "9:" // Height 1: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 12f\n" + "ldr s10, [x12, #0x0]\n" + "b 12f\n" + "10:" // Height 1: Partial accumulate: partial_2_0 + "tbz x13, #1, 11f\n" + "ldr d9, [x12], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 12f\n" + "ld1 { v9.s }[2], [x12]\n" + "b 12f\n" + "11:" // Height 1: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "mov x19, #0x0\n" + "12:" // Height 1: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 14f\n" + "13:" // Height 1: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "14:" // Height 1: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "b 16f\n" + "15:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "16:" // Height 1: setup done + "mov x27, #0x0\n" + "17:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 18f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 19f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "b 19f\n" + "18:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "19:" // Height 1: input setup done + "cmp x26, #0x8\n" + "blt 22f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q7, [x11, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q6, [x11, #0x10]\n" + "blt 21f\n" + "20:" // Height 1: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, 
v1.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "add x25, x25, #0x10\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + "add x11, x11, #0x40\n" + "ldr q7, [x11, #0x0]\n" + "add x10, x10, #0x40\n" + "ldr q6, [x11, #0x10]\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "bge 20b\n" + "21:" // Height 1: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "add x25, x25, #0x10\n" + "add x11, x11, #0x40\n" + "add x10, x10, #0x40\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "22:" // Height 1: Multiply loop: Main loop skip + "cbz x26, 27f\n" + "cmp x26, #0x4\n" + "blt 24f\n" + "23:" // Height 1: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr q6, [x11, #0x0]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x4\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "bge 23b\n" + "24:" // Height 1: Multiply loop: Skip odd blocks + "cbz x26, 27f\n" + "tbz x26, #1, 25f\n" + "ldr s1, [x25], #0x4\n" + "tbz x26, #0, 26f\n" + "ld1 { v1.h }[2], [x25]\n" + "b 26f\n" + "25:" // Height 1: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "26:" // Height 
1: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "27:" // Height 1: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 17b\n" + "uzp1 v8.2d, v8.2d, v12.2d\n" + "uzp1 v9.2d, v9.2d, v13.2d\n" + "uzp1 v10.2d, v10.2d, v14.2d\n" + "uzp1 v11.2d, v11.2d, v15.2d\n" + "tbz %x[flags], #1, 28f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "28:" // Height 1: No activation + "cmp x13, #0x10\n" + "bge 37f\n" + "tbz x13, #3, 32f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "tbz x13, #2, 30f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "tbz x13, #1, 29f\n" + "str d11, [x12], #0x8\n" + "tbz x13, #0, 36f\n" + "st1 { v11.s }[2], [x12]\n" + "b 36f\n" + "29:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x13, #0, 36f\n" + "str s11, [x12, #0x0]\n" + "b 36f\n" + "30:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x13, #1, 31f\n" + "str d10, [x12], #0x8\n" + "tbz x13, #0, 36f\n" + "st1 { v10.s }[2], [x12]\n" + "b 36f\n" + "31:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x13, #0, 36f\n" + "str s10, [x12, #0x0]\n" + "b 36f\n" + "32:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x13, #2, 34f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "tbz x13, #1, 33f\n" + "str d9, [x12], #0x8\n" + "tbz x13, #0, 36f\n" + "st1 { v9.s }[2], [x12]\n" + "b 36f\n" + "33:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x13, #0, 36f\n" + "str s9, [x12, #0x0]\n" + "b 36f\n" + "34:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x13, #1, 35f\n" + "str d8, [x12], #0x8\n" + "tbz x13, #0, 36f\n" + "st1 { v8.s }[2], [x12]\n" + "b 36f\n" + "35:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "36:" // Height 1: Partial direct writeback: Done + "b 38f\n" + "37:" // Height 1: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "38:" // Height 1: Writeback done + "subs x13, x13, #0x10\n" + "bgt 2b\n" + "b 230f\n" + "39:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "40:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, 
[%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 41f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 41f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 41f\n" + "mov x10, x11\n" + "41:" // Height 2: B setup done + "cbz x14, 42f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "add x14, x14, #0x40\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "b 54f\n" + "42:" // Height 2: no bias + "tbz %x[flags], #0, 53f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x13, #0x10\n" + "add x24, x12, x19, LSL #2\n" + "bge 51f\n" + "tbz x13, #3, 46f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "tbz x13, #2, 44f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "tbz x13, #1, 43f\n" + "ldr d16, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "tbz x13, #0, 50f\n" + "ld1 { v16.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "b 50f\n" + "43:" // Height 2: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 50f\n" + "ldr s16, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "b 50f\n" + "44:" // Height 2: Partial accumulate: partial_2_8 + "tbz x13, #1, 45f\n" + "ldr d11, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "tbz x13, #0, 50f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "b 50f\n" + "45:" // Height 2: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 50f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "b 50f\n" + "46:" // Height 2: Partial accumulate: partial_4_0 + "tbz x13, #2, 48f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "tbz x13, #1, 47f\n" + "ldr d10, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 50f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "b 50f\n" + "47:" // Height 2: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 50f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "b 50f\n" + "48:" // Height 2: Partial accumulate: partial_2_0 + "tbz x13, #1, 49f\n" + "ldr d9, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 50f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "b 50f\n" + "49:" // Height 2: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "50:" // Height 2: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 52f\n" + "51:" // Height 2: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "52:" // Height 2: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + 
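The "MMLA fixup" zips just above exist because BFMMLA accumulates 2x2 fp32 tiles rather than whole rows: one q register holds { r0c0, r0c1, r1c0, r1c1 } for a pair of output rows. Row-major C loaded for accumulation is therefore re-paired on entry, and the matching uzp1/uzp2 sequence after the multiply loop restores plain rows for writeback. A scalar model (illustrative only):

    #include <array>
    using q_reg = std::array<float, 4>; // one 128-bit q register, 4 fp32 lanes

    // zip{1,2} Vd.2d, Vn.2d, Vm.2d: pair the 64-bit halves of two registers.
    q_reg zip1_2d(const q_reg &n, const q_reg &m) { return { n[0], n[1], m[0], m[1] }; }
    q_reg zip2_2d(const q_reg &n, const q_reg &m) { return { n[2], n[3], m[2], m[3] }; }

    // Entry (v9 = row0 c0..3, v12 = row1 c0..3):
    //   zip1_2d(v9, v12) == { r0c0, r0c1, r1c0, r1c1 }  -- columns 0-1 tile
    //   zip2_2d(v9, v12) == { r0c2, r0c3, r1c2, r1c3 }  -- columns 2-3 tile
    // At .2d granularity uzp1/uzp2 compute the same shuffle, which is why the
    // writeback path can reuse it to recover whole rows from tile pairs.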
"b 54f\n" + "53:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "54:" // Height 2: setup done + "mov x27, #0x0\n" + "55:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 56f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 57f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "b 57f\n" + "56:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "57:" // Height 2: input setup done + "cmp x26, #0x8\n" + "blt 60f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "blt 59f\n" + "58:" // Height 2: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "ldr q2, [x24, #0x0]\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + "ldr q7, [x11, #0x0]\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "add x10, x10, #0x40\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "bge 58b\n" + "59:" // Height 2: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 
// bfmmla v8.4s, v1.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "add x11, x11, #0x40\n" + "add x10, x10, #0x40\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "60:" // Height 2: Multiply loop: Main loop skip + "cbz x26, 65f\n" + "cmp x26, #0x4\n" + "blt 62f\n" + "61:" // Height 2: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "sub x26, x26, #0x4\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + "ldr q6, [x9, #0x0]\n" + "ldr q7, [x9, #0x10]\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + "ldr q6, [x28, #0x0]\n" + "ldr q7, [x28, #0x10]\n" + "cmp x26, #0x4\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "bge 61b\n" + "62:" // Height 2: Multiply loop: Skip odd blocks + "cbz x26, 65f\n" + "tbz x26, #1, 63f\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "tbz x26, #0, 64f\n" + "ld1 { v1.h }[2], [x25]\n" + "ld1 { v2.h }[2], [x24]\n" + "b 64f\n" + "63:" // Height 2: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "ldr h2, [x24, #0x0]\n" + "64:" // Height 2: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "65:" // Height 2: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 55b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v7.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "add x24, x12, x19, LSL #2\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "tbz %x[flags], #1, 66f\n" + "add x19, 
%x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "66:" // Height 2: No activation + "cmp x13, #0x10\n" + "bge 75f\n" + "tbz x13, #3, 70f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "tbz x13, #2, 68f\n" + "st1 { v13.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "tbz x13, #1, 67f\n" + "str d14, [x12], #0x8\n" + "str d11, [x24], #0x8\n" + "tbz x13, #0, 74f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v11.s }[2], [x24]\n" + "b 74f\n" + "67:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x13, #0, 74f\n" + "str s14, [x12, #0x0]\n" + "str s11, [x24, #0x0]\n" + "b 74f\n" + "68:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x13, #1, 69f\n" + "str d13, [x12], #0x8\n" + "str d10, [x24], #0x8\n" + "tbz x13, #0, 74f\n" + "st1 { v13.s }[2], [x12]\n" + "st1 { v10.s }[2], [x24]\n" + "b 74f\n" + "69:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x13, #0, 74f\n" + "str s13, [x12, #0x0]\n" + "str s10, [x24, #0x0]\n" + "b 74f\n" + "70:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x13, #2, 72f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "tbz x13, #1, 71f\n" + "str d12, [x12], #0x8\n" + "str d9, [x24], #0x8\n" + "tbz x13, #0, 74f\n" + "st1 { v12.s }[2], [x12]\n" + "st1 { v9.s }[2], [x24]\n" + "b 74f\n" + "71:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x13, #0, 74f\n" + "str s12, [x12, #0x0]\n" + "str s9, [x24, #0x0]\n" + "b 74f\n" + "72:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x13, #1, 73f\n" + "str d7, [x12], #0x8\n" + "str d8, [x24], #0x8\n" + "tbz x13, #0, 74f\n" + "st1 { v7.s }[2], [x12]\n" + "st1 { v8.s }[2], [x24]\n" + "b 74f\n" + "73:" // Height 2: Partial direct writeback: partial_1_0 + "str s7, [x12, #0x0]\n" + "str s8, [x24, #0x0]\n" + "74:" // Height 2: Partial direct writeback: Done + "b 76f\n" + "75:" // Height 2: Full writeback + "str q7, [x12, #0x0]\n" + "str q12, [x12, #0x10]\n" + "str q13, [x12, #0x20]\n" + "str q14, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "76:" // Height 2: Writeback done + "subs x13, x13, #0x10\n" + "bgt 40b\n" + "b 230f\n" + "77:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "78:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 79f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 79f\n" + "cmp x13, #0x4\n" 
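The column-loop prologue here (as at labels 2, 40 and 116) lays out four 4-column panels of the fixed-format B. A sketch of the pointer setup and its tail aliasing (illustrative names; LSL #1 in the asm because bfloat16 is 2 bytes):

    #include <cstddef>
    using bf16_stub = unsigned short; // placeholder for arm_gemm's bfloat16

    // Panels sit B_stride elements apart; when fewer than 16 columns remain,
    // unused panel pointers alias the first panel so their loads stay in bounds.
    void setup_b_panels(const bf16_stub *&cur_B, size_t B_stride, size_t n_left,
                        const bf16_stub *p[4])
    {
        p[0] = cur_B;            // x11: columns 0-3
        p[1] = p[0] + B_stride;  // x10: columns 4-7
        p[2] = p[1] + B_stride;  // x9 : columns 8-11
        p[3] = p[2] + B_stride;  // x28: columns 12-15
        cur_B = p[3] + B_stride; // advance for the next 16-column block
        if (n_left <= 12) p[3] = p[0]; // "cmp x13, #0xc" ... "mov x28, x11"
        if (n_left <= 8)  p[2] = p[0]; // "cmp x13, #0x8" ... "mov x9, x11"
        if (n_left <= 4)  p[1] = p[0]; // "cmp x13, #0x4" ... "mov x10, x11"
    }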
+ "mov x9, x11\n" + "bgt 79f\n" + "mov x10, x11\n" + "79:" // Height 3: B setup done + "cbz x14, 80f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "add x14, x14, #0x40\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 92f\n" + "80:" // Height 3: no bias + "tbz %x[flags], #0, 91f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x23, x24, x19, LSL #2\n" + "bge 89f\n" + "tbz x13, #3, 84f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "tbz x13, #2, 82f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v19.4s }, [x23], #0x10\n" + "tbz x13, #1, 81f\n" + "ldr d16, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x23], #0x8\n" + "tbz x13, #0, 88f\n" + "ld1 { v16.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "b 88f\n" + "81:" // Height 3: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 88f\n" + "ldr s16, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "b 88f\n" + "82:" // Height 3: Partial accumulate: partial_2_8 + "tbz x13, #1, 83f\n" + "ldr d11, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d19, [x23], #0x8\n" + "tbz x13, #0, 88f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "b 88f\n" + "83:" // Height 3: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 88f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "b 88f\n" + "84:" // Height 3: Partial accumulate: partial_4_0 + "tbz x13, #2, 86f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "tbz x13, #1, 85f\n" + "ldr d10, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d18, [x23], #0x8\n" + "tbz x13, #0, 88f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "b 88f\n" + "85:" // Height 3: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 88f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "b 88f\n" + "86:" // Height 3: Partial accumulate: partial_2_0 + "tbz x13, #1, 87f\n" + "ldr d9, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d17, [x23], #0x8\n" + "tbz x13, #0, 88f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "b 88f\n" + "87:" // Height 3: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s17, [x23, #0x0]\n" + "88:" // Height 3: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 90f\n" + "89:" // Height 3: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "ldr q12, [x24, 
#0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q17, [x23, #0x0]\n" + "ldr q18, [x23, #0x10]\n" + "ldr q19, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "90:" // Height 3: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 92f\n" + "91:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "92:" // Height 3: setup done + "mov x27, #0x0\n" + "93:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 94f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 95f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "b 95f\n" + "94:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "95:" // Height 3: input setup done + "cmp x26, #0x8\n" + "blt 98f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q3, [x23, #0x0]\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "blt 97f\n" + "96:" // Height 3: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "sub x26, x26, #0x8\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "cmp x26, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "add x23, x23, #0x10\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + 
"ldr q6, [x11, #0x30]\n" + "ldr q2, [x24, #0x0]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + "ldr q7, [x11, #0x0]\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "ldr q3, [x23, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "bge 96b\n" + "97:" // Height 3: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "sub x26, x26, #0x8\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "add x23, x23, #0x10\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, 
#0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "98:" // Height 3: Multiply loop: Main loop skip + "cbz x26, 103f\n" + "cmp x26, #0x4\n" + "blt 100f\n" + "99:" // Height 3: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "ldr d3, [x23], #0x8\n" + "ldr q6, [x11, #0x0]\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "sub x26, x26, #0x4\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + "cmp x26, #0x4\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + "bge 99b\n" + "100:" // Height 3: Multiply loop: Skip odd blocks + "cbz x26, 103f\n" + "tbz x26, #1, 101f\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "ldr s3, [x23], #0x4\n" + "tbz x26, #0, 102f\n" + "ld1 { v1.h }[2], [x25]\n" + "ld1 { v2.h }[2], [x24]\n" + "ld1 { v3.h }[2], [x23]\n" + "b 102f\n" + "101:" // Height 3: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "ldr h2, [x24, #0x0]\n" + "ldr h3, [x23, #0x0]\n" + "102:" // Height 3: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "103:" // Height 3: Multiply loop: No odd 
multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 93b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "uzp1 v7.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v16.2d, v16.2d, v20.2d\n" + "uzp1 v17.2d, v17.2d, v21.2d\n" + "uzp1 v18.2d, v18.2d, v22.2d\n" + "uzp1 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 104f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "104:" // Height 3: No activation + "cmp x13, #0x10\n" + "bge 113f\n" + "tbz x13, #3, 108f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v17.4s }, [x23], #0x10\n" + "tbz x13, #2, 106f\n" + "st1 { v13.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v18.4s }, [x23], #0x10\n" + "tbz x13, #1, 105f\n" + "str d14, [x12], #0x8\n" + "str d11, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "tbz x13, #0, 112f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "b 112f\n" + "105:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x13, #0, 112f\n" + "str s14, [x12, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s19, [x23, #0x0]\n" + "b 112f\n" + "106:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x13, #1, 107f\n" + "str d13, [x12], #0x8\n" + "str d10, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "tbz x13, #0, 112f\n" + "st1 { v13.s }[2], [x12]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v18.s }[2], [x23]\n" + "b 112f\n" + "107:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x13, #0, 112f\n" + "str s13, [x12, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s18, [x23, #0x0]\n" + "b 112f\n" + "108:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x13, #2, 110f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "tbz x13, #1, 109f\n" + "str d12, [x12], #0x8\n" + "str d9, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "tbz x13, #0, 112f\n" + "st1 { v12.s }[2], [x12]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v17.s }[2], [x23]\n" + "b 112f\n" + "109:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x13, #0, 112f\n" + "str s12, [x12, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s17, [x23, #0x0]\n" + "b 112f\n" + "110:" // Height 3: Partial direct writeback: 
partial_2_0 + "tbz x13, #1, 111f\n" + "str d7, [x12], #0x8\n" + "str d8, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "tbz x13, #0, 112f\n" + "st1 { v7.s }[2], [x12]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v16.s }[2], [x23]\n" + "b 112f\n" + "111:" // Height 3: Partial direct writeback: partial_1_0 + "str s7, [x12, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s16, [x23, #0x0]\n" + "112:" // Height 3: Partial direct writeback: Done + "b 114f\n" + "113:" // Height 3: Full writeback + "str q7, [x12, #0x0]\n" + "str q12, [x12, #0x10]\n" + "str q13, [x12, #0x20]\n" + "str q14, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "114:" // Height 3: Writeback done + "subs x13, x13, #0x10\n" + "bgt 78b\n" + "b 230f\n" + "115:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "116:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 117f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 117f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 117f\n" + "mov x10, x11\n" + "117:" // Height 4: B setup done + "cbz x14, 118f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "add x14, x14, #0x40\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "b 130f\n" + "118:" // Height 4: no bias + "tbz %x[flags], #0, 129f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x22, x23, x19, LSL #2\n" + "bge 127f\n" + "tbz x13, #3, 122f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "tbz x13, #2, 120f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v19.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "tbz x13, #1, 119f\n" + "ldr d16, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "tbz x13, #0, 126f\n" + "ld1 { v16.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "b 126f\n" + "119:" // Height 4: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 126f\n" + "ldr s16, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "b 
126f\n" + "120:" // Height 4: Partial accumulate: partial_2_8 + "tbz x13, #1, 121f\n" + "ldr d11, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d19, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "tbz x13, #0, 126f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "b 126f\n" + "121:" // Height 4: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 126f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "b 126f\n" + "122:" // Height 4: Partial accumulate: partial_4_0 + "tbz x13, #2, 124f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "tbz x13, #1, 123f\n" + "ldr d10, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d18, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "tbz x13, #0, 126f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "b 126f\n" + "123:" // Height 4: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 126f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "b 126f\n" + "124:" // Height 4: Partial accumulate: partial_2_0 + "tbz x13, #1, 125f\n" + "ldr d9, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "tbz x13, #0, 126f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "b 126f\n" + "125:" // Height 4: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s17, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "126:" // Height 4: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 128f\n" + "127:" // Height 4: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q17, [x23, #0x0]\n" + "ldr q18, [x23, #0x10]\n" + "ldr q19, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "128:" // Height 4: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "b 130f\n" + "129:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "130:" // Height 4: setup done + "mov x27, #0x0\n" + "131:" // Height 4: String loop + 
"ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 132f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 133f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "b 133f\n" + "132:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "133:" // Height 4: input setup done + "cmp x26, #0x8\n" + "blt 136f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "blt 135f\n" + "134:" // Height 4: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "cmp x26, #0x10\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "add x23, x23, #0x10\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + "add x22, x22, #0x10\n" + "ldr q4, [x22, #0x0]\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "ldr q2, [x24, #0x0]\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + "ldr q7, [x11, #0x0]\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, 
v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "ldr q3, [x23, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "bge 134b\n" + "135:" // Height 4: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + "add x23, x23, #0x10\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "add x22, x22, #0x10\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "136:" // Height 4: Multiply loop: Main loop skip + "cbz x26, 141f\n" + "cmp x26, #0x4\n" + "blt 138f\n" + "137:" // Height 4: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "sub x26, x26, #0x4\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "cmp x26, #0x4\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e47ec0d // bfmmla 
v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + "bge 137b\n" + "138:" // Height 4: Multiply loop: Skip odd blocks + "cbz x26, 141f\n" + "tbz x26, #1, 139f\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "ldr s3, [x23], #0x4\n" + "ldr s4, [x22], #0x4\n" + "tbz x26, #0, 140f\n" + "ld1 { v1.h }[2], [x25]\n" + "ld1 { v2.h }[2], [x24]\n" + "ld1 { v3.h }[2], [x23]\n" + "ld1 { v4.h }[2], [x22]\n" + "b 140f\n" + "139:" // Height 4: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "ldr h2, [x24, #0x0]\n" + "ldr h3, [x23, #0x0]\n" + "ldr h4, [x22, #0x0]\n" + "140:" // Height 4: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "ldr q6, [x11, #0x10]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "141:" // Height 4: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 131b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 v7.2d, v8.2d, v12.2d\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "add x22, x23, x19, LSL #2\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "tbz %x[flags], #1, 142f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v7.4s, 
v7.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "142:" // Height 4: No activation + "cmp x13, #0x10\n" + "bge 151f\n" + "tbz x13, #3, 146f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "st1 { v17.4s }, [x22], #0x10\n" + "tbz x13, #2, 144f\n" + "st1 { v13.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v21.4s }, [x23], #0x10\n" + "st1 { v18.4s }, [x22], #0x10\n" + "tbz x13, #1, 143f\n" + "str d14, [x12], #0x8\n" + "str d11, [x24], #0x8\n" + "str d22, [x23], #0x8\n" + "str d19, [x22], #0x8\n" + "tbz x13, #0, 150f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v19.s }[2], [x22]\n" + "b 150f\n" + "143:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x13, #0, 150f\n" + "str s14, [x12, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s22, [x23, #0x0]\n" + "str s19, [x22, #0x0]\n" + "b 150f\n" + "144:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x13, #1, 145f\n" + "str d13, [x12], #0x8\n" + "str d10, [x24], #0x8\n" + "str d21, [x23], #0x8\n" + "str d18, [x22], #0x8\n" + "tbz x13, #0, 150f\n" + "st1 { v13.s }[2], [x12]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v21.s }[2], [x23]\n" + "st1 { v18.s }[2], [x22]\n" + "b 150f\n" + "145:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x13, #0, 150f\n" + "str s13, [x12, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s21, [x23, #0x0]\n" + "str s18, [x22, #0x0]\n" + "b 150f\n" + "146:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x13, #2, 148f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "tbz x13, #1, 147f\n" + "str d12, [x12], #0x8\n" + "str d9, [x24], #0x8\n" + "str d20, [x23], #0x8\n" + "str d17, [x22], #0x8\n" + "tbz x13, #0, 150f\n" + "st1 { v12.s }[2], [x12]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v20.s }[2], [x23]\n" + "st1 { v17.s }[2], [x22]\n" + "b 150f\n" + "147:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x13, #0, 150f\n" + "str s12, [x12, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s20, [x23, #0x0]\n" + "str s17, [x22, #0x0]\n" + "b 150f\n" + "148:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x13, #1, 149f\n" + "str d7, [x12], #0x8\n" + "str d8, [x24], #0x8\n" + "str d15, [x23], #0x8\n" + "str d16, [x22], #0x8\n" + "tbz x13, #0, 150f\n" + "st1 { 
v7.s }[2], [x12]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v15.s }[2], [x23]\n" + "st1 { v16.s }[2], [x22]\n" + "b 150f\n" + "149:" // Height 4: Partial direct writeback: partial_1_0 + "str s7, [x12, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s15, [x23, #0x0]\n" + "str s16, [x22, #0x0]\n" + "150:" // Height 4: Partial direct writeback: Done + "b 152f\n" + "151:" // Height 4: Full writeback + "str q7, [x12, #0x0]\n" + "str q12, [x12, #0x10]\n" + "str q13, [x12, #0x20]\n" + "str q14, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q15, [x23, #0x0]\n" + "str q20, [x23, #0x10]\n" + "str q21, [x23, #0x20]\n" + "str q22, [x23, #0x30]\n" + "str q16, [x22, #0x0]\n" + "str q17, [x22, #0x10]\n" + "str q18, [x22, #0x20]\n" + "str q19, [x22, #0x30]\n" + "152:" // Height 4: Writeback done + "subs x13, x13, #0x10\n" + "bgt 116b\n" + "b 230f\n" + "153:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "154:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 155f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 155f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 155f\n" + "mov x10, x11\n" + "155:" // Height 5: B setup done + "cbz x14, 156f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "add x14, x14, #0x40\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 168f\n" + "156:" // Height 5: no bias + "tbz %x[flags], #0, 167f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x21, x22, x19, LSL #2\n" + "bge 165f\n" + "tbz x13, #3, 160f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "ld1 { v26.4s }, [x21], #0x10\n" + "tbz x13, #2, 158f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v19.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "ld1 { v27.4s }, [x21], #0x10\n" + "tbz x13, #1, 157f\n" + "ldr d16, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + 
"ldr d6, [x21], #0x8\n" + "tbz x13, #0, 164f\n" + "ld1 { v16.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "ld1 { v6.s }[2], [x21]\n" + "b 164f\n" + "157:" // Height 5: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 164f\n" + "ldr s16, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "ldr s6, [x21, #0x0]\n" + "b 164f\n" + "158:" // Height 5: Partial accumulate: partial_2_8 + "tbz x13, #1, 159f\n" + "ldr d11, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d19, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "tbz x13, #0, 164f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "ld1 { v27.s }[2], [x21]\n" + "b 164f\n" + "159:" // Height 5: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 164f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "ldr s27, [x21, #0x0]\n" + "b 164f\n" + "160:" // Height 5: Partial accumulate: partial_4_0 + "tbz x13, #2, 162f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "tbz x13, #1, 161f\n" + "ldr d10, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d18, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "tbz x13, #0, 164f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "ld1 { v26.s }[2], [x21]\n" + "b 164f\n" + "161:" // Height 5: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 164f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "ldr s26, [x21, #0x0]\n" + "b 164f\n" + "162:" // Height 5: Partial accumulate: partial_2_0 + "tbz x13, #1, 163f\n" + "ldr d9, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "tbz x13, #0, 164f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "b 164f\n" + "163:" // Height 5: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s17, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "164:" // Height 5: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 166f\n" + "165:" // Height 5: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q17, [x23, #0x0]\n" + "ldr q18, [x23, #0x10]\n" + "ldr q19, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q25, [x21, #0x0]\n" + "ldr q26, [x21, #0x10]\n" + "ldr q27, [x21, #0x20]\n" + "ldr q6, [x21, #0x30]\n" + "166:" // Height 5: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, 
v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 168f\n" + "167:" // Height 5: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "168:" // Height 5: setup done + "mov x27, #0x0\n" + "169:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 170f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 171f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "b 171f\n" + "170:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "171:" // Height 5: input setup done + "cmp x26, #0x8\n" + "blt 174f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "ldr q7, [x11, #0x0]\n" + "blt 173f\n" + "172:" // Height 5: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "trn2 v5.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "cmp x26, #0x10\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + "add x24, x24, #0x10\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + ".inst 0x6e46ec9d // bfmmla 
v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x21, x21, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q2, [x24, #0x0]\n" + ".inst 0x6e46ec9f // bfmmla v31.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + "ldr q4, [x22, #0x0]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb8 // bfmmla v24.4s, v5.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + "add x11, x11, #0x40\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbc // bfmmla v28.4s, v5.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb9 // bfmmla v25.4s, v5.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbd // bfmmla v29.4s, v5.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecba // bfmmla v26.4s, v5.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbe // bfmmla v30.4s, v5.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecbb // bfmmla v27.4s, v5.8h, v7.8h\n" + "ldr q7, [x11, #0x0]\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "ldr q3, [x23, #0x0]\n" + ".inst 0x6e46ecbf // bfmmla v31.4s, v5.8h, v6.8h\n" + "ldr q5, [x21, #0x0]\n" + "bge 172b\n" + "173:" // Height 5: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "trn2 v5.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + "add x23, x23, #0x10\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + 
".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + ".inst 0x6e46ec9d // bfmmla v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9f // bfmmla v31.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb8 // bfmmla v24.4s, v5.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbc // bfmmla v28.4s, v5.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb9 // bfmmla v25.4s, v5.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbd // bfmmla v29.4s, v5.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecba // bfmmla v26.4s, v5.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbe // bfmmla v30.4s, v5.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecbb // bfmmla v27.4s, v5.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbf // bfmmla v31.4s, v5.8h, v6.8h\n" + "174:" // Height 5: Multiply loop: Main loop skip + "cbz x26, 179f\n" + "cmp x26, #0x4\n" + "blt 176f\n" + "175:" // Height 5: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "sub x26, x26, #0x4\n" + "ldr d5, [x21], #0x8\n" + "ldr q6, [x11, #0x0]\n" + "trn1 v4.2d, v5.2d, v7.2d\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + "cmp x26, #0x4\n" + "add x11, x11, #0x20\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, 
v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e46ec53 // bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n" + "bge 175b\n" + "176:" // Height 5: Multiply loop: Skip odd blocks + "cbz x26, 179f\n" + "tbz x26, #1, 177f\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "ldr s3, [x23], #0x4\n" + "ldr s4, [x22], #0x4\n" + "ldr s5, [x21], #0x4\n" + "tbz x26, #0, 178f\n" + "ld1 { v1.h }[2], [x25]\n" + "ld1 { v2.h }[2], [x24]\n" + "ld1 { v3.h }[2], [x23]\n" + "ld1 { v4.h }[2], [x22]\n" + "ld1 { v5.h }[2], [x21]\n" + "b 178f\n" + "177:" // Height 5: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "ldr h2, [x24, #0x0]\n" + "ldr h3, [x23, #0x0]\n" + "ldr h4, [x22, #0x0]\n" + "ldr h5, [x21, #0x0]\n" + "178:" // Height 5: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9d // bfmmla v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9f // bfmmla v31.4s, v4.8h, v6.8h\n" + "179:" // Height 5: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 169b\n" + "ldr x19, [%x[args_ptr], 
%[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 v7.2d, v8.2d, v12.2d\n" + "add x22, x23, x19, LSL #2\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "add x21, x22, x19, LSL #2\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v24.2d, v24.2d, v28.2d\n" + "uzp1 v25.2d, v25.2d, v29.2d\n" + "uzp1 v26.2d, v26.2d, v30.2d\n" + "uzp1 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 180f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "180:" // Height 5: No activation + "cmp x13, #0x10\n" + "bge 189f\n" + "tbz x13, #3, 184f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "st1 { v17.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "st1 { v25.4s }, [x21], #0x10\n" + "tbz x13, #2, 182f\n" + "st1 { v13.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v21.4s }, [x23], #0x10\n" + "st1 { v18.4s }, [x22], #0x10\n" + "st1 { v26.4s }, [x21], #0x10\n" + "tbz x13, #1, 181f\n" + "str d14, [x12], #0x8\n" + "str d11, [x24], #0x8\n" + "str d22, [x23], #0x8\n" + "str d19, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "tbz x13, #0, 188f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v19.s }[2], [x22]\n" + "st1 { v27.s }[2], [x21]\n" + "b 188f\n" + "181:" // Height 5: Partial direct writeback: partial_1_12 + "tbz x13, #0, 188f\n" + "str s14, [x12, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s22, [x23, #0x0]\n" + "str s19, [x22, #0x0]\n" + "str s27, [x21, #0x0]\n" + "b 188f\n" + "182:" // Height 5: 
Partial direct writeback: partial_2_8 + "tbz x13, #1, 183f\n" + "str d13, [x12], #0x8\n" + "str d10, [x24], #0x8\n" + "str d21, [x23], #0x8\n" + "str d18, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "tbz x13, #0, 188f\n" + "st1 { v13.s }[2], [x12]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v21.s }[2], [x23]\n" + "st1 { v18.s }[2], [x22]\n" + "st1 { v26.s }[2], [x21]\n" + "b 188f\n" + "183:" // Height 5: Partial direct writeback: partial_1_8 + "tbz x13, #0, 188f\n" + "str s13, [x12, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s21, [x23, #0x0]\n" + "str s18, [x22, #0x0]\n" + "str s26, [x21, #0x0]\n" + "b 188f\n" + "184:" // Height 5: Partial direct writeback: partial_4_0 + "tbz x13, #2, 186f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "tbz x13, #1, 185f\n" + "str d12, [x12], #0x8\n" + "str d9, [x24], #0x8\n" + "str d20, [x23], #0x8\n" + "str d17, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "tbz x13, #0, 188f\n" + "st1 { v12.s }[2], [x12]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v20.s }[2], [x23]\n" + "st1 { v17.s }[2], [x22]\n" + "st1 { v25.s }[2], [x21]\n" + "b 188f\n" + "185:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x13, #0, 188f\n" + "str s12, [x12, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s20, [x23, #0x0]\n" + "str s17, [x22, #0x0]\n" + "str s25, [x21, #0x0]\n" + "b 188f\n" + "186:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x13, #1, 187f\n" + "str d7, [x12], #0x8\n" + "str d8, [x24], #0x8\n" + "str d15, [x23], #0x8\n" + "str d16, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "tbz x13, #0, 188f\n" + "st1 { v7.s }[2], [x12]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v15.s }[2], [x23]\n" + "st1 { v16.s }[2], [x22]\n" + "st1 { v24.s }[2], [x21]\n" + "b 188f\n" + "187:" // Height 5: Partial direct writeback: partial_1_0 + "str s7, [x12, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s15, [x23, #0x0]\n" + "str s16, [x22, #0x0]\n" + "str s24, [x21, #0x0]\n" + "188:" // Height 5: Partial direct writeback: Done + "b 190f\n" + "189:" // Height 5: Full writeback + "str q7, [x12, #0x0]\n" + "str q12, [x12, #0x10]\n" + "str q13, [x12, #0x20]\n" + "str q14, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q15, [x23, #0x0]\n" + "str q20, [x23, #0x10]\n" + "str q21, [x23, #0x20]\n" + "str q22, [x23, #0x30]\n" + "str q16, [x22, #0x0]\n" + "str q17, [x22, #0x10]\n" + "str q18, [x22, #0x20]\n" + "str q19, [x22, #0x30]\n" + "str q24, [x21, #0x0]\n" + "str q25, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "190:" // Height 5: Writeback done + "subs x13, x13, #0x10\n" + "bgt 154b\n" + "b 230f\n" + "191:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x18\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "192:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 193f\n" + "cmp x13, #0x8\n" + "mov x28, 
x11\n" + "bgt 193f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 193f\n" + "mov x10, x11\n" + "193:" // Height 6: B setup done + "cbz x14, 194f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v12.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v13.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "zip2 v14.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "add x14, x14, #0x40\n" + "zip2 v15.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "mov v16.16b, v8.16b\n" + "mov v20.16b, v12.16b\n" + "mov v17.16b, v9.16b\n" + "mov v21.16b, v13.16b\n" + "mov v18.16b, v10.16b\n" + "mov v22.16b, v14.16b\n" + "mov v19.16b, v11.16b\n" + "mov v23.16b, v15.16b\n" + "mov v24.16b, v8.16b\n" + "mov v28.16b, v12.16b\n" + "mov v25.16b, v9.16b\n" + "mov v29.16b, v13.16b\n" + "mov v26.16b, v10.16b\n" + "mov v30.16b, v14.16b\n" + "mov v27.16b, v11.16b\n" + "mov v31.16b, v15.16b\n" + "b 206f\n" + "194:" // Height 6: no bias + "tbz %x[flags], #0, 205f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x20, x21, x19, LSL #2\n" + "bge 203f\n" + "tbz x13, #3, 198f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "ld1 { v26.4s }, [x21], #0x10\n" + "ld1 { v29.4s }, [x20], #0x10\n" + "tbz x13, #2, 196f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v19.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "ld1 { v27.4s }, [x21], #0x10\n" + "ld1 { v30.4s }, [x20], #0x10\n" + "tbz x13, #1, 195f\n" + "ldr d16, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "ldr d6, [x21], #0x8\n" + "ldr d31, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "ld1 { v16.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v24.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "ld1 { v6.s }[2], [x21]\n" + "ld1 { v31.s }[2], [x20]\n" + "b 202f\n" + "195:" // Height 6: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 202f\n" + "ldr s16, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s24, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "ldr s6, [x21, #0x0]\n" + "ldr s31, [x20, #0x0]\n" + "b 202f\n" + "196:" // Height 6: Partial accumulate: partial_2_8 + "tbz x13, #1, 197f\n" + "ldr d11, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d19, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "ldr d30, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "ld1 { v27.s }[2], [x21]\n" + "ld1 { v30.s }[2], [x20]\n" + "b 202f\n" + "197:" // Height 6: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 202f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "ldr s27, [x21, #0x0]\n" + "ldr s30, [x20, #0x0]\n" + "b 202f\n" + "198:" // Height 6: Partial accumulate: partial_4_0 + "tbz x13, #2, 200f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v12.4s 
}, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "tbz x13, #1, 199f\n" + "ldr d10, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d18, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "ldr d29, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "ld1 { v26.s }[2], [x21]\n" + "ld1 { v29.s }[2], [x20]\n" + "b 202f\n" + "199:" // Height 6: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 202f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "ldr s26, [x21, #0x0]\n" + "ldr s29, [x20, #0x0]\n" + "b 202f\n" + "200:" // Height 6: Partial accumulate: partial_2_0 + "tbz x13, #1, 201f\n" + "ldr d9, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "ldr d28, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "ld1 { v28.s }[2], [x20]\n" + "b 202f\n" + "201:" // Height 6: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s17, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "ldr s28, [x20, #0x0]\n" + "202:" // Height 6: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 204f\n" + "203:" // Height 6: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q16, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q17, [x23, #0x0]\n" + "ldr q18, [x23, #0x10]\n" + "ldr q19, [x23, #0x20]\n" + "ldr q24, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q25, [x21, #0x0]\n" + "ldr q26, [x21, #0x10]\n" + "ldr q27, [x21, #0x20]\n" + "ldr q6, [x21, #0x30]\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "ldr q30, [x20, #0x20]\n" + "ldr q31, [x20, #0x30]\n" + "204:" // Height 6: MMLA fixup + "zip1 v8.2d, v9.2d, v12.2d\n" + "zip2 v12.2d, v9.2d, v12.2d\n" + "zip1 v9.2d, v10.2d, v13.2d\n" + "zip2 v13.2d, v10.2d, v13.2d\n" + "zip1 v10.2d, v11.2d, v14.2d\n" + "zip2 v14.2d, v11.2d, v14.2d\n" + "zip1 v11.2d, v16.2d, v15.2d\n" + "zip2 v15.2d, v16.2d, v15.2d\n" + "zip1 v16.2d, v17.2d, v20.2d\n" + "zip2 v20.2d, v17.2d, v20.2d\n" + "zip1 v17.2d, v18.2d, v21.2d\n" + "zip2 v21.2d, v18.2d, v21.2d\n" + "zip1 v18.2d, v19.2d, v22.2d\n" + "zip2 v22.2d, v19.2d, v22.2d\n" + "zip1 v19.2d, v24.2d, v23.2d\n" + "zip2 v23.2d, v24.2d, v23.2d\n" + "zip1 v24.2d, v25.2d, v28.2d\n" + "zip2 v28.2d, v25.2d, v28.2d\n" + "zip1 v25.2d, v26.2d, v29.2d\n" + "zip2 v29.2d, v26.2d, v29.2d\n" + "zip1 v26.2d, v27.2d, v30.2d\n" + "zip2 v30.2d, v27.2d, v30.2d\n" + "zip1 v27.2d, v6.2d, v31.2d\n" + "zip2 v31.2d, v6.2d, v31.2d\n" + "b 206f\n" + "205:" // Height 6: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi 
v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "206:" // Height 6: setup done + "mov x27, #0x0\n" + "207:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 208f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 209f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "add x20, x20, x19, LSL #1\n" + "b 209f\n" + "208:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "209:" // Height 6: input setup done + "cmp x26, #0x8\n" + "blt 212f\n" + "ldr q1, [x25, #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q3, [x23, #0x0]\n" + "ldr q4, [x22, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "ldr q6, [x20, #0x0]\n" + "ldr q7, [x11, #0x0]\n" + "blt 211f\n" + "210:" // Height 6: Multiply loop: Main loop head + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "cmp x26, #0x10\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "trn2 v5.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "add x25, x25, #0x10\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + "add x23, x23, #0x10\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + ".inst 0x6e46ec9d // bfmmla v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x20, x20, #0x10\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q2, [x24, #0x0]\n" + ".inst 0x6e46ec9f // 
bfmmla v31.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + "ldr q4, [x22, #0x0]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb8 // bfmmla v24.4s, v5.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + "add x11, x11, #0x40\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbc // bfmmla v28.4s, v5.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb9 // bfmmla v25.4s, v5.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbd // bfmmla v29.4s, v5.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecba // bfmmla v26.4s, v5.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbe // bfmmla v30.4s, v5.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecbb // bfmmla v27.4s, v5.8h, v7.8h\n" + "ldr q7, [x11, #0x0]\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "ldr q1, [x25, #0x0]\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + "ldr q3, [x23, #0x0]\n" + ".inst 0x6e46ecbf // bfmmla v31.4s, v5.8h, v6.8h\n" + "ldr q5, [x21, #0x0]\n" + "ldr q6, [x20, #0x0]\n" + "bge 210b\n" + "211:" // Height 6: Multiply loop: Single iteration only + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn2 v1.2d, v1.2d, v2.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "sub x26, x26, #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "trn2 v3.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + "add x25, x25, #0x10\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "trn2 v5.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "add x24, x24, #0x10\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "add x23, x23, #0x10\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + "add x22, x22, #0x10\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + ".inst 0x6e46ec9d // bfmmla v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + "ldr q7, [x11, #0x20]\n" 
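+      // Note on the scheme used throughout this kernel: "bfmmla vd.4s, vn.8h, vm.8h"
+      // (Armv8.6-A BF16 extension) multiplies the 2x4 BF16 matrix held
+      // row-interleaved in vn by the 4x2 BF16 matrix in vm, accumulating a
+      // 2x2 FP32 tile into vd. The trn1/trn2 pairs above interleave two A
+      // rows into that 2x4 layout; the matching uzp1/uzp2 sequences before
+      // writeback de-interleave the accumulator tiles back into per-row order.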
+ ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9f // bfmmla v31.4s, v4.8h, v6.8h\n" + "ldr q6, [x11, #0x30]\n" + ".inst 0x6e47ec28 // bfmmla v8.4s, v1.8h, v7.8h\n" + "add x11, x11, #0x40\n" + ".inst 0x6e47ec70 // bfmmla v16.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb8 // bfmmla v24.4s, v5.8h, v7.8h\n" + "ldr q7, [x10, #0x20]\n" + ".inst 0x6e46ec2c // bfmmla v12.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec74 // bfmmla v20.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbc // bfmmla v28.4s, v5.8h, v6.8h\n" + "ldr q6, [x10, #0x30]\n" + ".inst 0x6e47ec29 // bfmmla v9.4s, v1.8h, v7.8h\n" + "add x10, x10, #0x40\n" + ".inst 0x6e47ec71 // bfmmla v17.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecb9 // bfmmla v25.4s, v5.8h, v7.8h\n" + "ldr q7, [x9, #0x20]\n" + ".inst 0x6e46ec2d // bfmmla v13.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec75 // bfmmla v21.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbd // bfmmla v29.4s, v5.8h, v6.8h\n" + "ldr q6, [x9, #0x30]\n" + ".inst 0x6e47ec2a // bfmmla v10.4s, v1.8h, v7.8h\n" + "add x9, x9, #0x40\n" + ".inst 0x6e47ec72 // bfmmla v18.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecba // bfmmla v26.4s, v5.8h, v7.8h\n" + "ldr q7, [x28, #0x20]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec76 // bfmmla v22.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbe // bfmmla v30.4s, v5.8h, v6.8h\n" + "ldr q6, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + ".inst 0x6e47ec2b // bfmmla v11.4s, v1.8h, v7.8h\n" + ".inst 0x6e47ec73 // bfmmla v19.4s, v3.8h, v7.8h\n" + ".inst 0x6e47ecbb // bfmmla v27.4s, v5.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + ".inst 0x6e46ec77 // bfmmla v23.4s, v3.8h, v6.8h\n" + ".inst 0x6e46ecbf // bfmmla v31.4s, v5.8h, v6.8h\n" + "212:" // Height 6: Multiply loop: Main loop skip + "cbz x26, 217f\n" + "cmp x26, #0x4\n" + "blt 214f\n" + "213:" // Height 6: Multiply loop: Odd block loop + "ldr d1, [x25], #0x8\n" + "ldr d2, [x24], #0x8\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "sub x26, x26, #0x4\n" + "ldr d3, [x23], #0x8\n" + "ldr d4, [x22], #0x8\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + "cmp x26, #0x4\n" + "ldr d5, [x21], #0x8\n" + "ldr d7, [x20], #0x8\n" + "trn1 v4.2d, v5.2d, v7.2d\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x11, #0x10]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec50 // bfmmla v16.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec98 // bfmmla v24.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e47ec54 // bfmmla v20.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9c // bfmmla v28.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec51 // bfmmla v17.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec99 // bfmmla v25.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x0]\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec55 // bfmmla v21.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9d // bfmmla v29.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x10]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec52 // bfmmla v18.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9a // bfmmla v26.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0e // bfmmla v14.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec56 // bfmmla v22.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9e // bfmmla v30.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec53 
// bfmmla v19.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9b // bfmmla v27.4s, v4.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9f // bfmmla v31.4s, v4.8h, v7.8h\n" + "bge 213b\n" + "214:" // Height 6: Multiply loop: Skip odd blocks + "cbz x26, 217f\n" + "tbz x26, #1, 215f\n" + "ldr s1, [x25], #0x4\n" + "ldr s2, [x24], #0x4\n" + "ldr s3, [x23], #0x4\n" + "ldr s4, [x22], #0x4\n" + "ldr s5, [x21], #0x4\n" + "ldr s6, [x20], #0x4\n" + "tbz x26, #0, 216f\n" + "ld1 { v1.h }[2], [x25]\n" + "ld1 { v2.h }[2], [x24]\n" + "ld1 { v3.h }[2], [x23]\n" + "ld1 { v4.h }[2], [x22]\n" + "ld1 { v5.h }[2], [x21]\n" + "ld1 { v6.h }[2], [x20]\n" + "b 216f\n" + "215:" // Height 6: Multiply loop: Ragged operand read: partial_1_0 + "ldr h1, [x25, #0x0]\n" + "ldr h2, [x24, #0x0]\n" + "ldr h3, [x23, #0x0]\n" + "ldr h4, [x22, #0x0]\n" + "ldr h5, [x21, #0x0]\n" + "ldr h6, [x20, #0x0]\n" + "216:" // Height 6: Multiply loop: Ragged operand read: Done + "ldr q7, [x11, #0x0]\n" + "trn1 v0.2d, v1.2d, v2.2d\n" + "trn1 v2.2d, v3.2d, v4.2d\n" + ".inst 0x6e47ec08 // bfmmla v8.4s, v0.8h, v7.8h\n" + "trn1 v4.2d, v5.2d, v6.2d\n" + "ldr q6, [x11, #0x10]\n" + ".inst 0x6e47ec50 // bfmmla v16.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec98 // bfmmla v24.4s, v4.8h, v7.8h\n" + "ldr q7, [x10, #0x0]\n" + ".inst 0x6e46ec0c // bfmmla v12.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec9c // bfmmla v28.4s, v4.8h, v6.8h\n" + "ldr q6, [x10, #0x10]\n" + ".inst 0x6e47ec09 // bfmmla v9.4s, v0.8h, v7.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e47ec51 // bfmmla v17.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec99 // bfmmla v25.4s, v4.8h, v7.8h\n" + "ldr q7, [x9, #0x0]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9d // bfmmla v29.4s, v4.8h, v6.8h\n" + "ldr q6, [x9, #0x10]\n" + ".inst 0x6e47ec0a // bfmmla v10.4s, v0.8h, v7.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec52 // bfmmla v18.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9a // bfmmla v26.4s, v4.8h, v7.8h\n" + "ldr q7, [x28, #0x0]\n" + ".inst 0x6e46ec0e // bfmmla v14.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9e // bfmmla v30.4s, v4.8h, v6.8h\n" + "ldr q6, [x28, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec53 // bfmmla v19.4s, v2.8h, v7.8h\n" + ".inst 0x6e47ec9b // bfmmla v27.4s, v4.8h, v7.8h\n" + ".inst 0x6e46ec0f // bfmmla v15.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + ".inst 0x6e46ec9f // bfmmla v31.4s, v4.8h, v6.8h\n" + "217:" // Height 6: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 207b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 v7.2d, v8.2d, v12.2d\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "uzp2 v8.2d, v8.2d, v12.2d\n" + "uzp1 v12.2d, v9.2d, v13.2d\n" + "uzp2 v9.2d, v9.2d, v13.2d\n" + "uzp1 v13.2d, v10.2d, v14.2d\n" + "add x20, x21, x19, LSL #2\n" + "uzp2 v10.2d, v10.2d, v14.2d\n" + "uzp1 v14.2d, v11.2d, v15.2d\n" + "uzp2 v11.2d, v11.2d, v15.2d\n" + "uzp1 v15.2d, v16.2d, v20.2d\n" + "uzp2 v16.2d, v16.2d, v20.2d\n" + "uzp1 v20.2d, v17.2d, v21.2d\n" + "uzp2 v17.2d, v17.2d, v21.2d\n" + "uzp1 v21.2d, v18.2d, v22.2d\n" + "uzp2 v18.2d, 
v18.2d, v22.2d\n" + "uzp1 v22.2d, v19.2d, v23.2d\n" + "uzp2 v19.2d, v19.2d, v23.2d\n" + "uzp1 v23.2d, v24.2d, v28.2d\n" + "uzp2 v24.2d, v24.2d, v28.2d\n" + "uzp1 v28.2d, v25.2d, v29.2d\n" + "uzp2 v25.2d, v25.2d, v29.2d\n" + "uzp1 v29.2d, v26.2d, v30.2d\n" + "uzp2 v26.2d, v26.2d, v30.2d\n" + "uzp1 v30.2d, v27.2d, v31.2d\n" + "uzp2 v27.2d, v27.2d, v31.2d\n" + "tbz %x[flags], #1, 218f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v7.4s, v7.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v7.4s, v7.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "218:" // Height 6: No activation + "cmp x13, #0x10\n" + "bge 227f\n" + "tbz x13, #3, 222f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v9.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "st1 { v17.4s }, [x22], #0x10\n" + "st1 { v23.4s }, [x21], #0x10\n" + "st1 { v28.4s }, [x21], #0x10\n" + "st1 { v24.4s }, [x20], #0x10\n" + "st1 { v25.4s }, [x20], #0x10\n" + "tbz x13, #2, 220f\n" + "st1 { v13.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x24], #0x10\n" + "st1 { v21.4s }, [x23], #0x10\n" + "st1 { v18.4s }, [x22], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v26.4s }, [x20], #0x10\n" + "tbz x13, #1, 219f\n" + "str d14, [x12], #0x8\n" + "str d11, [x24], #0x8\n" + "str d22, [x23], #0x8\n" + "str d19, [x22], #0x8\n" + "str d30, [x21], #0x8\n" + "str d27, [x20], #0x8\n" + "tbz x13, #0, 226f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v11.s }[2], [x24]\n" + "st1 { v22.s }[2], [x23]\n" + "st1 { v19.s }[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v27.s }[2], [x20]\n" + "b 226f\n" + "219:" // Height 6: Partial direct writeback: partial_1_12 + "tbz x13, #0, 226f\n" + "str s14, [x12, #0x0]\n" + "str s11, [x24, #0x0]\n" + "str s22, [x23, #0x0]\n" + "str s19, [x22, #0x0]\n" + "str s30, [x21, #0x0]\n" + "str s27, [x20, #0x0]\n" + "b 226f\n" + "220:" // Height 6: Partial direct 
writeback: partial_2_8 + "tbz x13, #1, 221f\n" + "str d13, [x12], #0x8\n" + "str d10, [x24], #0x8\n" + "str d21, [x23], #0x8\n" + "str d18, [x22], #0x8\n" + "str d29, [x21], #0x8\n" + "str d26, [x20], #0x8\n" + "tbz x13, #0, 226f\n" + "st1 { v13.s }[2], [x12]\n" + "st1 { v10.s }[2], [x24]\n" + "st1 { v21.s }[2], [x23]\n" + "st1 { v18.s }[2], [x22]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v26.s }[2], [x20]\n" + "b 226f\n" + "221:" // Height 6: Partial direct writeback: partial_1_8 + "tbz x13, #0, 226f\n" + "str s13, [x12, #0x0]\n" + "str s10, [x24, #0x0]\n" + "str s21, [x23, #0x0]\n" + "str s18, [x22, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s26, [x20, #0x0]\n" + "b 226f\n" + "222:" // Height 6: Partial direct writeback: partial_4_0 + "tbz x13, #2, 224f\n" + "st1 { v7.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x24], #0x10\n" + "st1 { v15.4s }, [x23], #0x10\n" + "st1 { v16.4s }, [x22], #0x10\n" + "st1 { v23.4s }, [x21], #0x10\n" + "st1 { v24.4s }, [x20], #0x10\n" + "tbz x13, #1, 223f\n" + "str d12, [x12], #0x8\n" + "str d9, [x24], #0x8\n" + "str d20, [x23], #0x8\n" + "str d17, [x22], #0x8\n" + "str d28, [x21], #0x8\n" + "str d25, [x20], #0x8\n" + "tbz x13, #0, 226f\n" + "st1 { v12.s }[2], [x12]\n" + "st1 { v9.s }[2], [x24]\n" + "st1 { v20.s }[2], [x23]\n" + "st1 { v17.s }[2], [x22]\n" + "st1 { v28.s }[2], [x21]\n" + "st1 { v25.s }[2], [x20]\n" + "b 226f\n" + "223:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x13, #0, 226f\n" + "str s12, [x12, #0x0]\n" + "str s9, [x24, #0x0]\n" + "str s20, [x23, #0x0]\n" + "str s17, [x22, #0x0]\n" + "str s28, [x21, #0x0]\n" + "str s25, [x20, #0x0]\n" + "b 226f\n" + "224:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x13, #1, 225f\n" + "str d7, [x12], #0x8\n" + "str d8, [x24], #0x8\n" + "str d15, [x23], #0x8\n" + "str d16, [x22], #0x8\n" + "str d23, [x21], #0x8\n" + "str d24, [x20], #0x8\n" + "tbz x13, #0, 226f\n" + "st1 { v7.s }[2], [x12]\n" + "st1 { v8.s }[2], [x24]\n" + "st1 { v15.s }[2], [x23]\n" + "st1 { v16.s }[2], [x22]\n" + "st1 { v23.s }[2], [x21]\n" + "st1 { v24.s }[2], [x20]\n" + "b 226f\n" + "225:" // Height 6: Partial direct writeback: partial_1_0 + "str s7, [x12, #0x0]\n" + "str s8, [x24, #0x0]\n" + "str s15, [x23, #0x0]\n" + "str s16, [x22, #0x0]\n" + "str s23, [x21, #0x0]\n" + "str s24, [x20, #0x0]\n" + "226:" // Height 6: Partial direct writeback: Done + "b 228f\n" + "227:" // Height 6: Full writeback + "str q7, [x12, #0x0]\n" + "str q12, [x12, #0x10]\n" + "str q13, [x12, #0x20]\n" + "str q14, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q8, [x24, #0x0]\n" + "str q9, [x24, #0x10]\n" + "str q10, [x24, #0x20]\n" + "str q11, [x24, #0x30]\n" + "str q15, [x23, #0x0]\n" + "str q20, [x23, #0x10]\n" + "str q21, [x23, #0x20]\n" + "str q22, [x23, #0x30]\n" + "str q16, [x22, #0x0]\n" + "str q17, [x22, #0x10]\n" + "str q18, [x22, #0x20]\n" + "str q19, [x22, #0x30]\n" + "str q23, [x21, #0x0]\n" + "str q28, [x21, #0x10]\n" + "str q29, [x21, #0x20]\n" + "str q30, [x21, #0x30]\n" + "str q24, [x20, #0x0]\n" + "str q25, [x20, #0x10]\n" + "str q26, [x20, #0x20]\n" + "str q27, [x20, #0x30]\n" + "228:" // Height 6: Writeback done + "subs x13, x13, #0x10\n" + "bgt 192b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 230f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 229f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "229:" // Update direct input + "mov x19, #0xc\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "230:" // Exit + : [M] "+&r" 
(M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32.hpp new file mode 100644 index 0000000000..f7506e5123 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<__fp16>, \ + size_t, size_t, \ + const __fp16 *, \ + size_t, \ + IndirectOutputArg<__fp16>, \ + const __fp16 *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffhybrid_fp16_mla_6x32( ARGLIST ); + +class cls_a64_ffhybrid_fp16_mla_6x32 +{ +public: + typedef __fp16 lhs_operand_type; + typedef __fp16 rhs_operand_type; + typedef __fp16 result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return 8; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL128_BL16; + } + + static unsigned int out_width() + { + return 32; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsFixed<rhs_operand_type, result_type, 6, 32, 1> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, __fp16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 29.14 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffhybrid_fp16_mla_6x32; + cls_a64_ffhybrid_fp16_mla_6x32(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp new file mode 100644 index 0000000000..e1458b39ab --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp16_mla_6x32/generic.cpp @@ -0,0 +1,5429 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
+ */ +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) + +#include "arm_gemm.hpp" +#include "../../utils.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void a64_ffhybrid_fp16_mla_6x32 ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<__fp16> A_arg, + size_t M, size_t N, const __fp16 *B_ptr, size_t B_stride, IndirectOutputArg<__fp16> output_arg, + const __fp16 *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + __fp16 maxval = static_cast<__fp16>(std::numeric_limits<float>::infinity()); + __fp16 minval = - static_cast<__fp16>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const __fp16 *B_ptr = {}; + const __fp16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<__fp16>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 251f\n" + "cmp %x[M], #0x4\n" + "bgt 201f\n" + "beq 151f\n" + "cmp %x[M], #0x2\n" + "bgt 101f\n" + "beq 51f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 3f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 3f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "cbz x14, 4f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "add x14, x14, #0x40\n" + "b 23f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 22f\n" + "cmp x13, #0x20\n" + "bge 21f\n" + "tbz x13, #4, 12f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "tbz x13, #3, 8f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "tbz x13, #2, 6f\n" + "ldr d11, [x12], #0x8\n" + "tbz x13, #1, 5f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "mov x19, #0x3c\n" + "tbz x13, #0, 20f\n" + "ld1 { v11.h }[6], [x12]\n" + "b 20f\n" + "5:" // Height 1: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz
x13, #0, 20f\n" + "ld1 { v11.h }[4], [x12]\n" + "b 20f\n" + "6:" // Height 1: Partial accumulate: partial_2_24 + "tbz x13, #1, 7f\n" + "ldr s11, [x12], #0x4\n" + "mov x19, #0x34\n" + "tbz x13, #0, 20f\n" + "ld1 { v11.h }[2], [x12]\n" + "b 20f\n" + "7:" // Height 1: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 20f\n" + "ldr h11, [x12, #0x0]\n" + "b 20f\n" + "8:" // Height 1: Partial accumulate: partial_4_16 + "tbz x13, #2, 10f\n" + "ldr d10, [x12], #0x8\n" + "tbz x13, #1, 9f\n" + "ld1 { v10.s }[2], [x12], #0x4\n" + "mov x19, #0x2c\n" + "tbz x13, #0, 20f\n" + "ld1 { v10.h }[6], [x12]\n" + "b 20f\n" + "9:" // Height 1: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 20f\n" + "ld1 { v10.h }[4], [x12]\n" + "b 20f\n" + "10:" // Height 1: Partial accumulate: partial_2_16 + "tbz x13, #1, 11f\n" + "ldr s10, [x12], #0x4\n" + "mov x19, #0x24\n" + "tbz x13, #0, 20f\n" + "ld1 { v10.h }[2], [x12]\n" + "b 20f\n" + "11:" // Height 1: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 20f\n" + "ldr h10, [x12, #0x0]\n" + "b 20f\n" + "12:" // Height 1: Partial accumulate: partial_8_0 + "tbz x13, #3, 16f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "tbz x13, #2, 14f\n" + "ldr d9, [x12], #0x8\n" + "tbz x13, #1, 13f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "mov x19, #0x1c\n" + "tbz x13, #0, 20f\n" + "ld1 { v9.h }[6], [x12]\n" + "b 20f\n" + "13:" // Height 1: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 20f\n" + "ld1 { v9.h }[4], [x12]\n" + "b 20f\n" + "14:" // Height 1: Partial accumulate: partial_2_8 + "tbz x13, #1, 15f\n" + "ldr s9, [x12], #0x4\n" + "mov x19, #0x14\n" + "tbz x13, #0, 20f\n" + "ld1 { v9.h }[2], [x12]\n" + "b 20f\n" + "15:" // Height 1: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 20f\n" + "ldr h9, [x12, #0x0]\n" + "b 20f\n" + "16:" // Height 1: Partial accumulate: partial_4_0 + "tbz x13, #2, 18f\n" + "ldr d8, [x12], #0x8\n" + "tbz x13, #1, 17f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "mov x19, #0xc\n" + "tbz x13, #0, 20f\n" + "ld1 { v8.h }[6], [x12]\n" + "b 20f\n" + "17:" // Height 1: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 20f\n" + "ld1 { v8.h }[4], [x12]\n" + "b 20f\n" + "18:" // Height 1: Partial accumulate: partial_2_0 + "tbz x13, #1, 19f\n" + "ldr s8, [x12], #0x4\n" + "mov x19, #0x4\n" + "tbz x13, #0, 20f\n" + "ld1 { v8.h }[2], [x12]\n" + "b 20f\n" + "19:" // Height 1: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "mov x19, #0x0\n" + "20:" // Height 1: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 23f\n" + "21:" // Height 1: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "b 23f\n" + "22:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "23:" // Height 1: setup done + "mov x27, #0x0\n" + "24:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 25f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 26f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "b 26f\n" + "25:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "26:" // Height 1: input setup done + "cmp x26, #0x8\n" + "blt 29f\n" + "ldr q0, [x25, 
#0x0]\n" + "ldr q6, [x11, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q7, [x10, #0x0]\n" + "blt 28f\n" + "27:" // Height 1: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "add x25, x25, #0x10\n" + "ldr q0, [x25, #0x0]\n" + "add x11, x11, #0x80\n" + "ldr q6, [x11, #0x0]\n" + "add x10, x10, #0x80\n" + "ldr q7, [x10, #0x0]\n" + "add x9, x9, #0x80\n" + "add x28, x28, #0x80\n" + "bge 27b\n" + "28:" // Height 1: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla 
v8.8h, v6.8h, v0.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "sub x26, x26, #0x8\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "add x25, x25, #0x10\n" + "add x11, x11, #0x80\n" + "add x10, x10, #0x80\n" + "add x9, x9, #0x80\n" + "add x28, x28, #0x80\n" + "29:" // Height 1: Multiply loop: Main loop skip + "cbz x26, 31f\n" + "30:" // Height 1: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "sub x26, x26, #0x1\n" + "ldr q7, [x10, #0x0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "add x28, x28, #0x10\n" + "cbnz x26, 30b\n" + "31:" // Height 1: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 24b\n" + "tbz %x[flags], #1, 32f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + "fmax v11.8h, v11.8h, v0.8h\n" + "32:" // Height 1: No activation + "cmp x13, #0x20\n" + "bge 49f\n" + "tbz x13, #4, 40f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "tbz x13, #3, 36f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "tbz x13, #2, 34f\n" + "str d11, [x12], #0x8\n" + "tbz x13, #1, 33f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v11.h }[6], [x12]\n" + "b 48f\n" + "33:" // Height 1: Partial direct writeback: partial_1_28 + "tbz x13, #0, 48f\n" + "st1 { v11.h }[4], [x12]\n" + "b 48f\n" + "34:" // Height 1: Partial direct writeback: partial_2_24 + "tbz x13, #1, 35f\n" + "str s11, [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v11.h }[2], [x12]\n" + "b 48f\n" + "35:" // Height 1: Partial direct writeback: partial_1_24 + "tbz x13, #0, 48f\n" + "str h11, [x12, #0x0]\n" + "b 48f\n" + "36:" // Height 1: Partial direct writeback: partial_4_16 + "tbz x13, #2, 38f\n" + "str d10, [x12], #0x8\n" + "tbz x13, #1, 37f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v10.h }[6], [x12]\n" + "b 48f\n" + "37:" // Height 1: Partial direct writeback: partial_1_20 + "tbz x13, #0, 48f\n" + "st1 { v10.h }[4], [x12]\n" + "b 48f\n" + "38:" // Height 1: Partial direct writeback: partial_2_16 + "tbz x13, #1, 39f\n" + "str s10, [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v10.h }[2], [x12]\n" + "b 48f\n" + "39:" // Height 1: Partial direct writeback: partial_1_16 + "tbz x13, #0, 48f\n" + "str h10, [x12, #0x0]\n" + "b 48f\n" + "40:" // Height 1: Partial direct writeback: partial_8_0 + "tbz x13, #3, 44f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "tbz x13, #2, 42f\n" + 
"str d9, [x12], #0x8\n" + "tbz x13, #1, 41f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v9.h }[6], [x12]\n" + "b 48f\n" + "41:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x13, #0, 48f\n" + "st1 { v9.h }[4], [x12]\n" + "b 48f\n" + "42:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x13, #1, 43f\n" + "str s9, [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v9.h }[2], [x12]\n" + "b 48f\n" + "43:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x13, #0, 48f\n" + "str h9, [x12, #0x0]\n" + "b 48f\n" + "44:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x13, #2, 46f\n" + "str d8, [x12], #0x8\n" + "tbz x13, #1, 45f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v8.h }[6], [x12]\n" + "b 48f\n" + "45:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x13, #0, 48f\n" + "st1 { v8.h }[4], [x12]\n" + "b 48f\n" + "46:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x13, #1, 47f\n" + "str s8, [x12], #0x4\n" + "tbz x13, #0, 48f\n" + "st1 { v8.h }[2], [x12]\n" + "b 48f\n" + "47:" // Height 1: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "48:" // Height 1: Partial direct writeback: Done + "b 50f\n" + "49:" // Height 1: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "50:" // Height 1: Writeback done + "subs x13, x13, #0x20\n" + "bgt 2b\n" + "b 302f\n" + "51:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "52:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 53f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 53f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 53f\n" + "mov x10, x11\n" + "53:" // Height 2: B setup done + "cbz x14, 54f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "add x14, x14, #0x40\n" + "b 73f\n" + "54:" // Height 2: no bias + "tbz %x[flags], #0, 72f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x13, #0x20\n" + "add x24, x12, x19, LSL #1\n" + "bge 71f\n" + "tbz x13, #4, 62f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "ld1 { v13.8h }, [x24], #0x10\n" + "tbz x13, #3, 58f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "ld1 { v14.8h }, [x24], #0x10\n" + "tbz x13, #2, 56f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "tbz x13, #1, 55f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "ld1 { v15.s }[2], [x24], #0x4\n" + "mov x19, #0x3c\n" + "tbz x13, #0, 70f\n" + "ld1 { v11.h }[6], [x12]\n" + "ld1 { v15.h }[6], [x24]\n" + "b 70f\n" + "55:" // Height 2: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz x13, #0, 70f\n" + "ld1 { v11.h }[4], [x12]\n" + "ld1 { v15.h }[4], [x24]\n" + "b 70f\n" + "56:" // Height 2: Partial accumulate: partial_2_24 + "tbz x13, #1, 57f\n" + "ldr s11, [x12], #0x4\n" + "ldr s15, [x24], #0x4\n" + "mov x19, #0x34\n" + "tbz 
x13, #0, 70f\n" + "ld1 { v11.h }[2], [x12]\n" + "ld1 { v15.h }[2], [x24]\n" + "b 70f\n" + "57:" // Height 2: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 70f\n" + "ldr h11, [x12, #0x0]\n" + "ldr h15, [x24, #0x0]\n" + "b 70f\n" + "58:" // Height 2: Partial accumulate: partial_4_16 + "tbz x13, #2, 60f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "tbz x13, #1, 59f\n" + "ld1 { v10.s }[2], [x12], #0x4\n" + "ld1 { v14.s }[2], [x24], #0x4\n" + "mov x19, #0x2c\n" + "tbz x13, #0, 70f\n" + "ld1 { v10.h }[6], [x12]\n" + "ld1 { v14.h }[6], [x24]\n" + "b 70f\n" + "59:" // Height 2: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 70f\n" + "ld1 { v10.h }[4], [x12]\n" + "ld1 { v14.h }[4], [x24]\n" + "b 70f\n" + "60:" // Height 2: Partial accumulate: partial_2_16 + "tbz x13, #1, 61f\n" + "ldr s10, [x12], #0x4\n" + "ldr s14, [x24], #0x4\n" + "mov x19, #0x24\n" + "tbz x13, #0, 70f\n" + "ld1 { v10.h }[2], [x12]\n" + "ld1 { v14.h }[2], [x24]\n" + "b 70f\n" + "61:" // Height 2: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 70f\n" + "ldr h10, [x12, #0x0]\n" + "ldr h14, [x24, #0x0]\n" + "b 70f\n" + "62:" // Height 2: Partial accumulate: partial_8_0 + "tbz x13, #3, 66f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "tbz x13, #2, 64f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "tbz x13, #1, 63f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "ld1 { v13.s }[2], [x24], #0x4\n" + "mov x19, #0x1c\n" + "tbz x13, #0, 70f\n" + "ld1 { v9.h }[6], [x12]\n" + "ld1 { v13.h }[6], [x24]\n" + "b 70f\n" + "63:" // Height 2: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 70f\n" + "ld1 { v9.h }[4], [x12]\n" + "ld1 { v13.h }[4], [x24]\n" + "b 70f\n" + "64:" // Height 2: Partial accumulate: partial_2_8 + "tbz x13, #1, 65f\n" + "ldr s9, [x12], #0x4\n" + "ldr s13, [x24], #0x4\n" + "mov x19, #0x14\n" + "tbz x13, #0, 70f\n" + "ld1 { v9.h }[2], [x12]\n" + "ld1 { v13.h }[2], [x24]\n" + "b 70f\n" + "65:" // Height 2: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 70f\n" + "ldr h9, [x12, #0x0]\n" + "ldr h13, [x24, #0x0]\n" + "b 70f\n" + "66:" // Height 2: Partial accumulate: partial_4_0 + "tbz x13, #2, 68f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "tbz x13, #1, 67f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "ld1 { v12.s }[2], [x24], #0x4\n" + "mov x19, #0xc\n" + "tbz x13, #0, 70f\n" + "ld1 { v8.h }[6], [x12]\n" + "ld1 { v12.h }[6], [x24]\n" + "b 70f\n" + "67:" // Height 2: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 70f\n" + "ld1 { v8.h }[4], [x12]\n" + "ld1 { v12.h }[4], [x24]\n" + "b 70f\n" + "68:" // Height 2: Partial accumulate: partial_2_0 + "tbz x13, #1, 69f\n" + "ldr s8, [x12], #0x4\n" + "ldr s12, [x24], #0x4\n" + "mov x19, #0x4\n" + "tbz x13, #0, 70f\n" + "ld1 { v8.h }[2], [x12]\n" + "ld1 { v12.h }[2], [x24]\n" + "b 70f\n" + "69:" // Height 2: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "ldr h12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "70:" // Height 2: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 73f\n" + "71:" // Height 2: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "b 73f\n" + "72:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, 
#0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "73:" // Height 2: setup done + "mov x27, #0x0\n" + "74:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 75f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 76f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "b 76f\n" + "75:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "76:" // Height 2: input setup done + "cmp x26, #0x8\n" + "blt 79f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 78f\n" + "77:" // Height 2: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "sub x26, x26, #0x8\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "cmp x26, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "add x25, x25, #0x10\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "add x24, x24, #0x10\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla 
v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "ldr q1, [x24, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 77b\n" + "78:" // Height 2: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "sub x26, x26, #0x8\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x25, x25, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "add x24, x24, #0x10\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + 
"fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "79:" // Height 2: Multiply loop: Main loop skip + "cbz x26, 81f\n" + "80:" // Height 2: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr h1, [x24], #0x2\n" + "sub x26, x26, #0x1\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "add x9, x9, #0x10\n" + "add x28, x28, #0x10\n" + "cbnz x26, 80b\n" + "81:" // Height 2: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 74b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "tbz %x[flags], #1, 82f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmin v12.8h, v12.8h, v1.8h\n" + "fmin v13.8h, v13.8h, v1.8h\n" + "fmin v14.8h, v14.8h, v1.8h\n" + "fmin v15.8h, v15.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + "fmax v11.8h, v11.8h, v0.8h\n" + "fmax v12.8h, v12.8h, v0.8h\n" + "fmax v13.8h, v13.8h, v0.8h\n" + "fmax v14.8h, v14.8h, v0.8h\n" + "fmax v15.8h, v15.8h, v0.8h\n" + "82:" // Height 2: No activation + "cmp x13, #0x20\n" + "bge 99f\n" + "tbz x13, #4, 90f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v13.8h }, [x24], #0x10\n" + "tbz x13, #3, 86f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "st1 { v14.8h }, [x24], #0x10\n" + "tbz x13, #2, 84f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "tbz x13, #1, 83f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "st1 { v15.s }[2], [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v11.h }[6], [x12]\n" + "st1 { v15.h }[6], [x24]\n" + "b 98f\n" + "83:" // Height 2: Partial direct writeback: partial_1_28 + "tbz x13, #0, 98f\n" + "st1 { v11.h }[4], [x12]\n" + "st1 { v15.h }[4], [x24]\n" + "b 98f\n" + "84:" // Height 2: Partial direct writeback: partial_2_24 + "tbz x13, #1, 85f\n" + "str s11, [x12], #0x4\n" + "str s15, [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v11.h }[2], [x12]\n" + "st1 { v15.h }[2], [x24]\n" + "b 98f\n" + "85:" // Height 2: Partial direct writeback: partial_1_24 + "tbz x13, #0, 98f\n" + "str h11, [x12, #0x0]\n" + "str h15, [x24, #0x0]\n" + "b 98f\n" + "86:" // Height 2: Partial direct writeback: partial_4_16 + "tbz x13, #2, 88f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "tbz x13, #1, 87f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "st1 { v14.s }[2], [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v10.h }[6], [x12]\n" + "st1 { v14.h }[6], [x24]\n" + "b 98f\n" + "87:" // Height 2: Partial direct writeback: partial_1_20 + "tbz x13, #0, 98f\n" + "st1 
{ v10.h }[4], [x12]\n" + "st1 { v14.h }[4], [x24]\n" + "b 98f\n" + "88:" // Height 2: Partial direct writeback: partial_2_16 + "tbz x13, #1, 89f\n" + "str s10, [x12], #0x4\n" + "str s14, [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v10.h }[2], [x12]\n" + "st1 { v14.h }[2], [x24]\n" + "b 98f\n" + "89:" // Height 2: Partial direct writeback: partial_1_16 + "tbz x13, #0, 98f\n" + "str h10, [x12, #0x0]\n" + "str h14, [x24, #0x0]\n" + "b 98f\n" + "90:" // Height 2: Partial direct writeback: partial_8_0 + "tbz x13, #3, 94f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "tbz x13, #2, 92f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "tbz x13, #1, 91f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "st1 { v13.s }[2], [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v9.h }[6], [x12]\n" + "st1 { v13.h }[6], [x24]\n" + "b 98f\n" + "91:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x13, #0, 98f\n" + "st1 { v9.h }[4], [x12]\n" + "st1 { v13.h }[4], [x24]\n" + "b 98f\n" + "92:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x13, #1, 93f\n" + "str s9, [x12], #0x4\n" + "str s13, [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v9.h }[2], [x12]\n" + "st1 { v13.h }[2], [x24]\n" + "b 98f\n" + "93:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x13, #0, 98f\n" + "str h9, [x12, #0x0]\n" + "str h13, [x24, #0x0]\n" + "b 98f\n" + "94:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x13, #2, 96f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "tbz x13, #1, 95f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "st1 { v12.s }[2], [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v8.h }[6], [x12]\n" + "st1 { v12.h }[6], [x24]\n" + "b 98f\n" + "95:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x13, #0, 98f\n" + "st1 { v8.h }[4], [x12]\n" + "st1 { v12.h }[4], [x24]\n" + "b 98f\n" + "96:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x13, #1, 97f\n" + "str s8, [x12], #0x4\n" + "str s12, [x24], #0x4\n" + "tbz x13, #0, 98f\n" + "st1 { v8.h }[2], [x12]\n" + "st1 { v12.h }[2], [x24]\n" + "b 98f\n" + "97:" // Height 2: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "str h12, [x24, #0x0]\n" + "98:" // Height 2: Partial direct writeback: Done + "b 100f\n" + "99:" // Height 2: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "100:" // Height 2: Writeback done + "subs x13, x13, #0x20\n" + "bgt 52b\n" + "b 302f\n" + "101:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "102:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 103f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 103f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 103f\n" + "mov x10, x11\n" + "103:" // Height 3: B setup done + "cbz x14, 104f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, 
#0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "b 123f\n" + "104:" // Height 3: no bias + "tbz %x[flags], #0, 122f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "cmp x13, #0x20\n" + "add x23, x24, x19, LSL #1\n" + "bge 121f\n" + "tbz x13, #4, 112f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "ld1 { v13.8h }, [x24], #0x10\n" + "ld1 { v17.8h }, [x23], #0x10\n" + "tbz x13, #3, 108f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "ld1 { v14.8h }, [x24], #0x10\n" + "ld1 { v18.8h }, [x23], #0x10\n" + "tbz x13, #2, 106f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "ldr d19, [x23], #0x8\n" + "tbz x13, #1, 105f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "ld1 { v15.s }[2], [x24], #0x4\n" + "mov x19, #0x3c\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v11.h }[6], [x12]\n" + "ld1 { v15.h }[6], [x24]\n" + "ld1 { v19.h }[6], [x23]\n" + "b 120f\n" + "105:" // Height 3: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz x13, #0, 120f\n" + "ld1 { v11.h }[4], [x12]\n" + "ld1 { v15.h }[4], [x24]\n" + "ld1 { v19.h }[4], [x23]\n" + "b 120f\n" + "106:" // Height 3: Partial accumulate: partial_2_24 + "tbz x13, #1, 107f\n" + "ldr s11, [x12], #0x4\n" + "ldr s15, [x24], #0x4\n" + "mov x19, #0x34\n" + "ldr s19, [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v11.h }[2], [x12]\n" + "ld1 { v15.h }[2], [x24]\n" + "ld1 { v19.h }[2], [x23]\n" + "b 120f\n" + "107:" // Height 3: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 120f\n" + "ldr h11, [x12, #0x0]\n" + "ldr h15, [x24, #0x0]\n" + "ldr h19, [x23, #0x0]\n" + "b 120f\n" + "108:" // Height 3: Partial accumulate: partial_4_16 + "tbz x13, #2, 110f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "ldr d18, [x23], #0x8\n" + "tbz x13, #1, 109f\n" + "ld1 { v10.s }[2], [x12], #0x4\n" + "ld1 { v14.s }[2], [x24], #0x4\n" + "mov x19, #0x2c\n" + "ld1 { v18.s }[2], [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v10.h }[6], [x12]\n" + "ld1 { v14.h }[6], [x24]\n" + "ld1 { v18.h }[6], [x23]\n" + "b 120f\n" + "109:" // Height 3: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 120f\n" + "ld1 { v10.h }[4], [x12]\n" + "ld1 { v14.h }[4], [x24]\n" + "ld1 { v18.h }[4], [x23]\n" + "b 120f\n" + "110:" // Height 3: Partial accumulate: partial_2_16 + "tbz x13, #1, 111f\n" + "ldr s10, [x12], #0x4\n" + "ldr s14, [x24], #0x4\n" + "mov x19, #0x24\n" + "ldr s18, [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v10.h }[2], [x12]\n" + "ld1 { v14.h }[2], [x24]\n" + "ld1 { v18.h }[2], [x23]\n" + "b 120f\n" + "111:" // Height 3: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 120f\n" + "ldr h10, [x12, #0x0]\n" + "ldr h14, [x24, #0x0]\n" + "ldr h18, [x23, #0x0]\n" + "b 120f\n" + "112:" // Height 3: Partial accumulate: partial_8_0 + "tbz x13, #3, 116f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "tbz x13, #2, 114f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "ldr d17, [x23], #0x8\n" + "tbz x13, #1, 113f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "ld1 { v13.s }[2], [x24], #0x4\n" + "mov x19, #0x1c\n" + "ld1 { v17.s }[2], [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v9.h }[6], [x12]\n" + "ld1 { v13.h }[6], [x24]\n" + "ld1 { 
v17.h }[6], [x23]\n" + "b 120f\n" + "113:" // Height 3: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 120f\n" + "ld1 { v9.h }[4], [x12]\n" + "ld1 { v13.h }[4], [x24]\n" + "ld1 { v17.h }[4], [x23]\n" + "b 120f\n" + "114:" // Height 3: Partial accumulate: partial_2_8 + "tbz x13, #1, 115f\n" + "ldr s9, [x12], #0x4\n" + "ldr s13, [x24], #0x4\n" + "mov x19, #0x14\n" + "ldr s17, [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v9.h }[2], [x12]\n" + "ld1 { v13.h }[2], [x24]\n" + "ld1 { v17.h }[2], [x23]\n" + "b 120f\n" + "115:" // Height 3: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 120f\n" + "ldr h9, [x12, #0x0]\n" + "ldr h13, [x24, #0x0]\n" + "ldr h17, [x23, #0x0]\n" + "b 120f\n" + "116:" // Height 3: Partial accumulate: partial_4_0 + "tbz x13, #2, 118f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "ldr d16, [x23], #0x8\n" + "tbz x13, #1, 117f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "ld1 { v12.s }[2], [x24], #0x4\n" + "mov x19, #0xc\n" + "ld1 { v16.s }[2], [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v8.h }[6], [x12]\n" + "ld1 { v12.h }[6], [x24]\n" + "ld1 { v16.h }[6], [x23]\n" + "b 120f\n" + "117:" // Height 3: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 120f\n" + "ld1 { v8.h }[4], [x12]\n" + "ld1 { v12.h }[4], [x24]\n" + "ld1 { v16.h }[4], [x23]\n" + "b 120f\n" + "118:" // Height 3: Partial accumulate: partial_2_0 + "tbz x13, #1, 119f\n" + "ldr s8, [x12], #0x4\n" + "ldr s12, [x24], #0x4\n" + "mov x19, #0x4\n" + "ldr s16, [x23], #0x4\n" + "tbz x13, #0, 120f\n" + "ld1 { v8.h }[2], [x12]\n" + "ld1 { v12.h }[2], [x24]\n" + "ld1 { v16.h }[2], [x23]\n" + "b 120f\n" + "119:" // Height 3: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "ldr h12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr h16, [x23, #0x0]\n" + "120:" // Height 3: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 123f\n" + "121:" // Height 3: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "b 123f\n" + "122:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "123:" // Height 3: setup done + "mov x27, #0x0\n" + "124:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 125f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 126f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "b 126f\n" + "125:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "126:" // Height 3: input setup done + "cmp x26, #0x8\n" + "blt 129f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q2, [x23, #0x0]\n" + 
"ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 128f\n" + "127:" // Height 3: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "add x25, x25, #0x10\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x24, x24, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "add x23, x23, #0x10\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "ldr 
q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "add x11, x11, #0x80\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "add x9, x9, #0x80\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "ldr q2, [x23, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 127b\n" + "128:" // Height 3: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "add x25, x25, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "add x24, x24, #0x10\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x23, x23, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + 
"ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "add x11, x11, #0x80\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "add x9, x9, #0x80\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "129:" // Height 3: Multiply loop: Main loop skip + "cbz x26, 131f\n" + "130:" // Height 3: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr h1, [x24], #0x2\n" + "sub x26, x26, #0x1\n" + "ldr h2, [x23], #0x2\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x11, x11, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "add x28, x28, #0x10\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "cbnz x26, 130b\n" + "131:" // Height 3: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 124b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "tbz %x[flags], #1, 132f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmin v12.8h, v12.8h, v1.8h\n" + "fmin v13.8h, v13.8h, v1.8h\n" + "fmin v14.8h, v14.8h, v1.8h\n" + "fmin v15.8h, v15.8h, v1.8h\n" + "fmin v16.8h, v16.8h, v1.8h\n" + "fmin v17.8h, v17.8h, v1.8h\n" + "fmin v18.8h, v18.8h, v1.8h\n" + "fmin v19.8h, v19.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + 
"fmax v11.8h, v11.8h, v0.8h\n" + "fmax v12.8h, v12.8h, v0.8h\n" + "fmax v13.8h, v13.8h, v0.8h\n" + "fmax v14.8h, v14.8h, v0.8h\n" + "fmax v15.8h, v15.8h, v0.8h\n" + "fmax v16.8h, v16.8h, v0.8h\n" + "fmax v17.8h, v17.8h, v0.8h\n" + "fmax v18.8h, v18.8h, v0.8h\n" + "fmax v19.8h, v19.8h, v0.8h\n" + "132:" // Height 3: No activation + "cmp x13, #0x20\n" + "bge 149f\n" + "tbz x13, #4, 140f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v13.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v17.8h }, [x23], #0x10\n" + "tbz x13, #3, 136f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "st1 { v14.8h }, [x24], #0x10\n" + "st1 { v18.8h }, [x23], #0x10\n" + "tbz x13, #2, 134f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "tbz x13, #1, 133f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "st1 { v15.s }[2], [x24], #0x4\n" + "st1 { v19.s }[2], [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v11.h }[6], [x12]\n" + "st1 { v15.h }[6], [x24]\n" + "st1 { v19.h }[6], [x23]\n" + "b 148f\n" + "133:" // Height 3: Partial direct writeback: partial_1_28 + "tbz x13, #0, 148f\n" + "st1 { v11.h }[4], [x12]\n" + "st1 { v15.h }[4], [x24]\n" + "st1 { v19.h }[4], [x23]\n" + "b 148f\n" + "134:" // Height 3: Partial direct writeback: partial_2_24 + "tbz x13, #1, 135f\n" + "str s11, [x12], #0x4\n" + "str s15, [x24], #0x4\n" + "str s19, [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v11.h }[2], [x12]\n" + "st1 { v15.h }[2], [x24]\n" + "st1 { v19.h }[2], [x23]\n" + "b 148f\n" + "135:" // Height 3: Partial direct writeback: partial_1_24 + "tbz x13, #0, 148f\n" + "str h11, [x12, #0x0]\n" + "str h15, [x24, #0x0]\n" + "str h19, [x23, #0x0]\n" + "b 148f\n" + "136:" // Height 3: Partial direct writeback: partial_4_16 + "tbz x13, #2, 138f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "tbz x13, #1, 137f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "st1 { v14.s }[2], [x24], #0x4\n" + "st1 { v18.s }[2], [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v10.h }[6], [x12]\n" + "st1 { v14.h }[6], [x24]\n" + "st1 { v18.h }[6], [x23]\n" + "b 148f\n" + "137:" // Height 3: Partial direct writeback: partial_1_20 + "tbz x13, #0, 148f\n" + "st1 { v10.h }[4], [x12]\n" + "st1 { v14.h }[4], [x24]\n" + "st1 { v18.h }[4], [x23]\n" + "b 148f\n" + "138:" // Height 3: Partial direct writeback: partial_2_16 + "tbz x13, #1, 139f\n" + "str s10, [x12], #0x4\n" + "str s14, [x24], #0x4\n" + "str s18, [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v10.h }[2], [x12]\n" + "st1 { v14.h }[2], [x24]\n" + "st1 { v18.h }[2], [x23]\n" + "b 148f\n" + "139:" // Height 3: Partial direct writeback: partial_1_16 + "tbz x13, #0, 148f\n" + "str h10, [x12, #0x0]\n" + "str h14, [x24, #0x0]\n" + "str h18, [x23, #0x0]\n" + "b 148f\n" + "140:" // Height 3: Partial direct writeback: partial_8_0 + "tbz x13, #3, 144f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "tbz x13, #2, 142f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "tbz x13, #1, 141f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "st1 { v13.s }[2], [x24], #0x4\n" + "st1 { v17.s }[2], [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v9.h }[6], [x12]\n" + "st1 { v13.h }[6], [x24]\n" + "st1 { v17.h }[6], [x23]\n" + "b 148f\n" + "141:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x13, #0, 148f\n" + "st1 { v9.h }[4], [x12]\n" + "st1 { v13.h }[4], [x24]\n" + "st1 { v17.h }[4], 
[x23]\n" + "b 148f\n" + "142:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x13, #1, 143f\n" + "str s9, [x12], #0x4\n" + "str s13, [x24], #0x4\n" + "str s17, [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v9.h }[2], [x12]\n" + "st1 { v13.h }[2], [x24]\n" + "st1 { v17.h }[2], [x23]\n" + "b 148f\n" + "143:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x13, #0, 148f\n" + "str h9, [x12, #0x0]\n" + "str h13, [x24, #0x0]\n" + "str h17, [x23, #0x0]\n" + "b 148f\n" + "144:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x13, #2, 146f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "tbz x13, #1, 145f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "st1 { v12.s }[2], [x24], #0x4\n" + "st1 { v16.s }[2], [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v8.h }[6], [x12]\n" + "st1 { v12.h }[6], [x24]\n" + "st1 { v16.h }[6], [x23]\n" + "b 148f\n" + "145:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x13, #0, 148f\n" + "st1 { v8.h }[4], [x12]\n" + "st1 { v12.h }[4], [x24]\n" + "st1 { v16.h }[4], [x23]\n" + "b 148f\n" + "146:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x13, #1, 147f\n" + "str s8, [x12], #0x4\n" + "str s12, [x24], #0x4\n" + "str s16, [x23], #0x4\n" + "tbz x13, #0, 148f\n" + "st1 { v8.h }[2], [x12]\n" + "st1 { v12.h }[2], [x24]\n" + "st1 { v16.h }[2], [x23]\n" + "b 148f\n" + "147:" // Height 3: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "str h12, [x24, #0x0]\n" + "str h16, [x23, #0x0]\n" + "148:" // Height 3: Partial direct writeback: Done + "b 150f\n" + "149:" // Height 3: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "150:" // Height 3: Writeback done + "subs x13, x13, #0x20\n" + "bgt 102b\n" + "b 302f\n" + "151:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "152:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 153f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 153f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 153f\n" + "mov x10, x11\n" + "153:" // Height 4: B setup done + "cbz x14, 154f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "b 173f\n" + "154:" // Height 4: no bias + "tbz %x[flags], #0, 172f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "cmp x13, #0x20\n" + "add x22, x23, x19, LSL #1\n" + "bge 171f\n" + 
"tbz x13, #4, 162f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "ld1 { v13.8h }, [x24], #0x10\n" + "ld1 { v17.8h }, [x23], #0x10\n" + "ld1 { v21.8h }, [x22], #0x10\n" + "tbz x13, #3, 158f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "ld1 { v14.8h }, [x24], #0x10\n" + "ld1 { v18.8h }, [x23], #0x10\n" + "ld1 { v22.8h }, [x22], #0x10\n" + "tbz x13, #2, 156f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "tbz x13, #1, 155f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "ld1 { v15.s }[2], [x24], #0x4\n" + "mov x19, #0x3c\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "ld1 { v23.s }[2], [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v11.h }[6], [x12]\n" + "ld1 { v15.h }[6], [x24]\n" + "ld1 { v19.h }[6], [x23]\n" + "ld1 { v23.h }[6], [x22]\n" + "b 170f\n" + "155:" // Height 4: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz x13, #0, 170f\n" + "ld1 { v11.h }[4], [x12]\n" + "ld1 { v15.h }[4], [x24]\n" + "ld1 { v19.h }[4], [x23]\n" + "ld1 { v23.h }[4], [x22]\n" + "b 170f\n" + "156:" // Height 4: Partial accumulate: partial_2_24 + "tbz x13, #1, 157f\n" + "ldr s11, [x12], #0x4\n" + "ldr s15, [x24], #0x4\n" + "mov x19, #0x34\n" + "ldr s19, [x23], #0x4\n" + "ldr s23, [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v11.h }[2], [x12]\n" + "ld1 { v15.h }[2], [x24]\n" + "ld1 { v19.h }[2], [x23]\n" + "ld1 { v23.h }[2], [x22]\n" + "b 170f\n" + "157:" // Height 4: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 170f\n" + "ldr h11, [x12, #0x0]\n" + "ldr h15, [x24, #0x0]\n" + "ldr h19, [x23, #0x0]\n" + "ldr h23, [x22, #0x0]\n" + "b 170f\n" + "158:" // Height 4: Partial accumulate: partial_4_16 + "tbz x13, #2, 160f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "tbz x13, #1, 159f\n" + "ld1 { v10.s }[2], [x12], #0x4\n" + "ld1 { v14.s }[2], [x24], #0x4\n" + "mov x19, #0x2c\n" + "ld1 { v18.s }[2], [x23], #0x4\n" + "ld1 { v22.s }[2], [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v10.h }[6], [x12]\n" + "ld1 { v14.h }[6], [x24]\n" + "ld1 { v18.h }[6], [x23]\n" + "ld1 { v22.h }[6], [x22]\n" + "b 170f\n" + "159:" // Height 4: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 170f\n" + "ld1 { v10.h }[4], [x12]\n" + "ld1 { v14.h }[4], [x24]\n" + "ld1 { v18.h }[4], [x23]\n" + "ld1 { v22.h }[4], [x22]\n" + "b 170f\n" + "160:" // Height 4: Partial accumulate: partial_2_16 + "tbz x13, #1, 161f\n" + "ldr s10, [x12], #0x4\n" + "ldr s14, [x24], #0x4\n" + "mov x19, #0x24\n" + "ldr s18, [x23], #0x4\n" + "ldr s22, [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v10.h }[2], [x12]\n" + "ld1 { v14.h }[2], [x24]\n" + "ld1 { v18.h }[2], [x23]\n" + "ld1 { v22.h }[2], [x22]\n" + "b 170f\n" + "161:" // Height 4: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 170f\n" + "ldr h10, [x12, #0x0]\n" + "ldr h14, [x24, #0x0]\n" + "ldr h18, [x23, #0x0]\n" + "ldr h22, [x22, #0x0]\n" + "b 170f\n" + "162:" // Height 4: Partial accumulate: partial_8_0 + "tbz x13, #3, 166f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "tbz x13, #2, 164f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "tbz x13, #1, 163f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "ld1 { v13.s }[2], [x24], 
#0x4\n" + "mov x19, #0x1c\n" + "ld1 { v17.s }[2], [x23], #0x4\n" + "ld1 { v21.s }[2], [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v9.h }[6], [x12]\n" + "ld1 { v13.h }[6], [x24]\n" + "ld1 { v17.h }[6], [x23]\n" + "ld1 { v21.h }[6], [x22]\n" + "b 170f\n" + "163:" // Height 4: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 170f\n" + "ld1 { v9.h }[4], [x12]\n" + "ld1 { v13.h }[4], [x24]\n" + "ld1 { v17.h }[4], [x23]\n" + "ld1 { v21.h }[4], [x22]\n" + "b 170f\n" + "164:" // Height 4: Partial accumulate: partial_2_8 + "tbz x13, #1, 165f\n" + "ldr s9, [x12], #0x4\n" + "ldr s13, [x24], #0x4\n" + "mov x19, #0x14\n" + "ldr s17, [x23], #0x4\n" + "ldr s21, [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v9.h }[2], [x12]\n" + "ld1 { v13.h }[2], [x24]\n" + "ld1 { v17.h }[2], [x23]\n" + "ld1 { v21.h }[2], [x22]\n" + "b 170f\n" + "165:" // Height 4: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 170f\n" + "ldr h9, [x12, #0x0]\n" + "ldr h13, [x24, #0x0]\n" + "ldr h17, [x23, #0x0]\n" + "ldr h21, [x22, #0x0]\n" + "b 170f\n" + "166:" // Height 4: Partial accumulate: partial_4_0 + "tbz x13, #2, 168f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "tbz x13, #1, 167f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "ld1 { v12.s }[2], [x24], #0x4\n" + "mov x19, #0xc\n" + "ld1 { v16.s }[2], [x23], #0x4\n" + "ld1 { v20.s }[2], [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v8.h }[6], [x12]\n" + "ld1 { v12.h }[6], [x24]\n" + "ld1 { v16.h }[6], [x23]\n" + "ld1 { v20.h }[6], [x22]\n" + "b 170f\n" + "167:" // Height 4: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 170f\n" + "ld1 { v8.h }[4], [x12]\n" + "ld1 { v12.h }[4], [x24]\n" + "ld1 { v16.h }[4], [x23]\n" + "ld1 { v20.h }[4], [x22]\n" + "b 170f\n" + "168:" // Height 4: Partial accumulate: partial_2_0 + "tbz x13, #1, 169f\n" + "ldr s8, [x12], #0x4\n" + "ldr s12, [x24], #0x4\n" + "mov x19, #0x4\n" + "ldr s16, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "tbz x13, #0, 170f\n" + "ld1 { v8.h }[2], [x12]\n" + "ld1 { v12.h }[2], [x24]\n" + "ld1 { v16.h }[2], [x23]\n" + "ld1 { v20.h }[2], [x22]\n" + "b 170f\n" + "169:" // Height 4: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "ldr h12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr h16, [x23, #0x0]\n" + "ldr h20, [x22, #0x0]\n" + "170:" // Height 4: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 173f\n" + "171:" // Height 4: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "b 173f\n" + "172:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "173:" // Height 4: setup done + "mov x27, #0x0\n" + "174:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL 
#0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 175f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 176f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "b 176f\n" + "175:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "176:" // Height 4: input setup done + "cmp x26, #0x8\n" + "blt 179f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 178f\n" + "177:" // Height 4: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x25, x25, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x22, x22, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "ldr q6, 
[x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 177b\n" + "178:" // Height 4: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "add x25, x25, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x24, x24, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, 
v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" 
+ "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "179:" // Height 4: Multiply loop: Main loop skip + "cbz x26, 181f\n" + "180:" // Height 4: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr h1, [x24], #0x2\n" + "sub x26, x26, #0x1\n" + "ldr h2, [x23], #0x2\n" + "ldr h3, [x22], #0x2\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x11, x11, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "cbnz x26, 180b\n" + "181:" // Height 4: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 174b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "tbz %x[flags], #1, 182f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmin v12.8h, v12.8h, v1.8h\n" + "fmin v13.8h, v13.8h, v1.8h\n" + "fmin v14.8h, v14.8h, v1.8h\n" + "fmin v15.8h, v15.8h, v1.8h\n" + "fmin v16.8h, v16.8h, v1.8h\n" + "fmin v17.8h, v17.8h, v1.8h\n" + "fmin v18.8h, v18.8h, v1.8h\n" + "fmin v19.8h, v19.8h, v1.8h\n" + "fmin v20.8h, v20.8h, v1.8h\n" + "fmin v21.8h, v21.8h, v1.8h\n" + "fmin v22.8h, v22.8h, v1.8h\n" + "fmin v23.8h, v23.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + "fmax v11.8h, v11.8h, v0.8h\n" + "fmax v12.8h, v12.8h, v0.8h\n" + "fmax v13.8h, v13.8h, v0.8h\n" + "fmax 
v14.8h, v14.8h, v0.8h\n" + "fmax v15.8h, v15.8h, v0.8h\n" + "fmax v16.8h, v16.8h, v0.8h\n" + "fmax v17.8h, v17.8h, v0.8h\n" + "fmax v18.8h, v18.8h, v0.8h\n" + "fmax v19.8h, v19.8h, v0.8h\n" + "fmax v20.8h, v20.8h, v0.8h\n" + "fmax v21.8h, v21.8h, v0.8h\n" + "fmax v22.8h, v22.8h, v0.8h\n" + "fmax v23.8h, v23.8h, v0.8h\n" + "182:" // Height 4: No activation + "cmp x13, #0x20\n" + "bge 199f\n" + "tbz x13, #4, 190f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v13.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v17.8h }, [x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "st1 { v21.8h }, [x22], #0x10\n" + "tbz x13, #3, 186f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "st1 { v14.8h }, [x24], #0x10\n" + "st1 { v18.8h }, [x23], #0x10\n" + "st1 { v22.8h }, [x22], #0x10\n" + "tbz x13, #2, 184f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "tbz x13, #1, 183f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "st1 { v15.s }[2], [x24], #0x4\n" + "st1 { v19.s }[2], [x23], #0x4\n" + "st1 { v23.s }[2], [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v11.h }[6], [x12]\n" + "st1 { v15.h }[6], [x24]\n" + "st1 { v19.h }[6], [x23]\n" + "st1 { v23.h }[6], [x22]\n" + "b 198f\n" + "183:" // Height 4: Partial direct writeback: partial_1_28 + "tbz x13, #0, 198f\n" + "st1 { v11.h }[4], [x12]\n" + "st1 { v15.h }[4], [x24]\n" + "st1 { v19.h }[4], [x23]\n" + "st1 { v23.h }[4], [x22]\n" + "b 198f\n" + "184:" // Height 4: Partial direct writeback: partial_2_24 + "tbz x13, #1, 185f\n" + "str s11, [x12], #0x4\n" + "str s15, [x24], #0x4\n" + "str s19, [x23], #0x4\n" + "str s23, [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v11.h }[2], [x12]\n" + "st1 { v15.h }[2], [x24]\n" + "st1 { v19.h }[2], [x23]\n" + "st1 { v23.h }[2], [x22]\n" + "b 198f\n" + "185:" // Height 4: Partial direct writeback: partial_1_24 + "tbz x13, #0, 198f\n" + "str h11, [x12, #0x0]\n" + "str h15, [x24, #0x0]\n" + "str h19, [x23, #0x0]\n" + "str h23, [x22, #0x0]\n" + "b 198f\n" + "186:" // Height 4: Partial direct writeback: partial_4_16 + "tbz x13, #2, 188f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "tbz x13, #1, 187f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "st1 { v14.s }[2], [x24], #0x4\n" + "st1 { v18.s }[2], [x23], #0x4\n" + "st1 { v22.s }[2], [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v10.h }[6], [x12]\n" + "st1 { v14.h }[6], [x24]\n" + "st1 { v18.h }[6], [x23]\n" + "st1 { v22.h }[6], [x22]\n" + "b 198f\n" + "187:" // Height 4: Partial direct writeback: partial_1_20 + "tbz x13, #0, 198f\n" + "st1 { v10.h }[4], [x12]\n" + "st1 { v14.h }[4], [x24]\n" + "st1 { v18.h }[4], [x23]\n" + "st1 { v22.h }[4], [x22]\n" + "b 198f\n" + "188:" // Height 4: Partial direct writeback: partial_2_16 + "tbz x13, #1, 189f\n" + "str s10, [x12], #0x4\n" + "str s14, [x24], #0x4\n" + "str s18, [x23], #0x4\n" + "str s22, [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v10.h }[2], [x12]\n" + "st1 { v14.h }[2], [x24]\n" + "st1 { v18.h }[2], [x23]\n" + "st1 { v22.h }[2], [x22]\n" + "b 198f\n" + "189:" // Height 4: Partial direct writeback: partial_1_16 + "tbz x13, #0, 198f\n" + "str h10, [x12, #0x0]\n" + "str h14, [x24, #0x0]\n" + "str h18, [x23, #0x0]\n" + "str h22, [x22, #0x0]\n" + "b 198f\n" + "190:" // Height 4: Partial direct writeback: partial_8_0 + "tbz x13, #3, 194f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v16.8h }, 
[x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "tbz x13, #2, 192f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "tbz x13, #1, 191f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "st1 { v13.s }[2], [x24], #0x4\n" + "st1 { v17.s }[2], [x23], #0x4\n" + "st1 { v21.s }[2], [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v9.h }[6], [x12]\n" + "st1 { v13.h }[6], [x24]\n" + "st1 { v17.h }[6], [x23]\n" + "st1 { v21.h }[6], [x22]\n" + "b 198f\n" + "191:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x13, #0, 198f\n" + "st1 { v9.h }[4], [x12]\n" + "st1 { v13.h }[4], [x24]\n" + "st1 { v17.h }[4], [x23]\n" + "st1 { v21.h }[4], [x22]\n" + "b 198f\n" + "192:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x13, #1, 193f\n" + "str s9, [x12], #0x4\n" + "str s13, [x24], #0x4\n" + "str s17, [x23], #0x4\n" + "str s21, [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v9.h }[2], [x12]\n" + "st1 { v13.h }[2], [x24]\n" + "st1 { v17.h }[2], [x23]\n" + "st1 { v21.h }[2], [x22]\n" + "b 198f\n" + "193:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x13, #0, 198f\n" + "str h9, [x12, #0x0]\n" + "str h13, [x24, #0x0]\n" + "str h17, [x23, #0x0]\n" + "str h21, [x22, #0x0]\n" + "b 198f\n" + "194:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x13, #2, 196f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "tbz x13, #1, 195f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "st1 { v12.s }[2], [x24], #0x4\n" + "st1 { v16.s }[2], [x23], #0x4\n" + "st1 { v20.s }[2], [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v8.h }[6], [x12]\n" + "st1 { v12.h }[6], [x24]\n" + "st1 { v16.h }[6], [x23]\n" + "st1 { v20.h }[6], [x22]\n" + "b 198f\n" + "195:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x13, #0, 198f\n" + "st1 { v8.h }[4], [x12]\n" + "st1 { v12.h }[4], [x24]\n" + "st1 { v16.h }[4], [x23]\n" + "st1 { v20.h }[4], [x22]\n" + "b 198f\n" + "196:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x13, #1, 197f\n" + "str s8, [x12], #0x4\n" + "str s12, [x24], #0x4\n" + "str s16, [x23], #0x4\n" + "str s20, [x22], #0x4\n" + "tbz x13, #0, 198f\n" + "st1 { v8.h }[2], [x12]\n" + "st1 { v12.h }[2], [x24]\n" + "st1 { v16.h }[2], [x23]\n" + "st1 { v20.h }[2], [x22]\n" + "b 198f\n" + "197:" // Height 4: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "str h12, [x24, #0x0]\n" + "str h16, [x23, #0x0]\n" + "str h20, [x22, #0x0]\n" + "198:" // Height 4: Partial direct writeback: Done + "b 200f\n" + "199:" // Height 4: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "200:" // Height 4: Writeback done + "subs x13, x13, #0x20\n" + "bgt 152b\n" + "b 302f\n" + "201:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "202:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, 
x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 203f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 203f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 203f\n" + "mov x10, x11\n" + "203:" // Height 5: B setup done + "cbz x14, 204f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "mov v24.16b, v8.16b\n" + "mov v25.16b, v9.16b\n" + "mov v26.16b, v10.16b\n" + "mov v27.16b, v11.16b\n" + "b 223f\n" + "204:" // Height 5: no bias + "tbz %x[flags], #0, 222f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "cmp x13, #0x20\n" + "add x21, x22, x19, LSL #1\n" + "bge 221f\n" + "tbz x13, #4, 212f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "ld1 { v24.8h }, [x21], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "ld1 { v13.8h }, [x24], #0x10\n" + "ld1 { v17.8h }, [x23], #0x10\n" + "ld1 { v21.8h }, [x22], #0x10\n" + "ld1 { v25.8h }, [x21], #0x10\n" + "tbz x13, #3, 208f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "ld1 { v14.8h }, [x24], #0x10\n" + "ld1 { v18.8h }, [x23], #0x10\n" + "ld1 { v22.8h }, [x22], #0x10\n" + "ld1 { v26.8h }, [x21], #0x10\n" + "tbz x13, #2, 206f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "tbz x13, #1, 205f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "ld1 { v15.s }[2], [x24], #0x4\n" + "mov x19, #0x3c\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "ld1 { v23.s }[2], [x22], #0x4\n" + "ld1 { v27.s }[2], [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v11.h }[6], [x12]\n" + "ld1 { v15.h }[6], [x24]\n" + "ld1 { v19.h }[6], [x23]\n" + "ld1 { v23.h }[6], [x22]\n" + "ld1 { v27.h }[6], [x21]\n" + "b 220f\n" + "205:" // Height 5: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz x13, #0, 220f\n" + "ld1 { v11.h }[4], [x12]\n" + "ld1 { v15.h }[4], [x24]\n" + "ld1 { v19.h }[4], [x23]\n" + "ld1 { v23.h }[4], [x22]\n" + "ld1 { v27.h }[4], [x21]\n" + "b 220f\n" + "206:" // Height 5: Partial accumulate: partial_2_24 + "tbz x13, #1, 207f\n" + "ldr s11, [x12], #0x4\n" + "ldr s15, [x24], #0x4\n" + "mov x19, #0x34\n" + "ldr s19, [x23], #0x4\n" + "ldr s23, [x22], #0x4\n" + "ldr s27, [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v11.h }[2], [x12]\n" + "ld1 { v15.h }[2], [x24]\n" + "ld1 { v19.h }[2], [x23]\n" + "ld1 { v23.h }[2], [x22]\n" + "ld1 { v27.h }[2], [x21]\n" + "b 220f\n" + "207:" // Height 5: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 220f\n" + "ldr h11, [x12, #0x0]\n" + "ldr h15, [x24, #0x0]\n" + "ldr h19, [x23, #0x0]\n" + "ldr h23, [x22, #0x0]\n" + "ldr h27, [x21, #0x0]\n" + "b 220f\n" + "208:" // Height 5: Partial accumulate: partial_4_16 + "tbz x13, #2, 210f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "tbz x13, #1, 209f\n" + 
"ld1 { v10.s }[2], [x12], #0x4\n" + "ld1 { v14.s }[2], [x24], #0x4\n" + "mov x19, #0x2c\n" + "ld1 { v18.s }[2], [x23], #0x4\n" + "ld1 { v22.s }[2], [x22], #0x4\n" + "ld1 { v26.s }[2], [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v10.h }[6], [x12]\n" + "ld1 { v14.h }[6], [x24]\n" + "ld1 { v18.h }[6], [x23]\n" + "ld1 { v22.h }[6], [x22]\n" + "ld1 { v26.h }[6], [x21]\n" + "b 220f\n" + "209:" // Height 5: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 220f\n" + "ld1 { v10.h }[4], [x12]\n" + "ld1 { v14.h }[4], [x24]\n" + "ld1 { v18.h }[4], [x23]\n" + "ld1 { v22.h }[4], [x22]\n" + "ld1 { v26.h }[4], [x21]\n" + "b 220f\n" + "210:" // Height 5: Partial accumulate: partial_2_16 + "tbz x13, #1, 211f\n" + "ldr s10, [x12], #0x4\n" + "ldr s14, [x24], #0x4\n" + "mov x19, #0x24\n" + "ldr s18, [x23], #0x4\n" + "ldr s22, [x22], #0x4\n" + "ldr s26, [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v10.h }[2], [x12]\n" + "ld1 { v14.h }[2], [x24]\n" + "ld1 { v18.h }[2], [x23]\n" + "ld1 { v22.h }[2], [x22]\n" + "ld1 { v26.h }[2], [x21]\n" + "b 220f\n" + "211:" // Height 5: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 220f\n" + "ldr h10, [x12, #0x0]\n" + "ldr h14, [x24, #0x0]\n" + "ldr h18, [x23, #0x0]\n" + "ldr h22, [x22, #0x0]\n" + "ldr h26, [x21, #0x0]\n" + "b 220f\n" + "212:" // Height 5: Partial accumulate: partial_8_0 + "tbz x13, #3, 216f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "ld1 { v24.8h }, [x21], #0x10\n" + "tbz x13, #2, 214f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "tbz x13, #1, 213f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "ld1 { v13.s }[2], [x24], #0x4\n" + "mov x19, #0x1c\n" + "ld1 { v17.s }[2], [x23], #0x4\n" + "ld1 { v21.s }[2], [x22], #0x4\n" + "ld1 { v25.s }[2], [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v9.h }[6], [x12]\n" + "ld1 { v13.h }[6], [x24]\n" + "ld1 { v17.h }[6], [x23]\n" + "ld1 { v21.h }[6], [x22]\n" + "ld1 { v25.h }[6], [x21]\n" + "b 220f\n" + "213:" // Height 5: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 220f\n" + "ld1 { v9.h }[4], [x12]\n" + "ld1 { v13.h }[4], [x24]\n" + "ld1 { v17.h }[4], [x23]\n" + "ld1 { v21.h }[4], [x22]\n" + "ld1 { v25.h }[4], [x21]\n" + "b 220f\n" + "214:" // Height 5: Partial accumulate: partial_2_8 + "tbz x13, #1, 215f\n" + "ldr s9, [x12], #0x4\n" + "ldr s13, [x24], #0x4\n" + "mov x19, #0x14\n" + "ldr s17, [x23], #0x4\n" + "ldr s21, [x22], #0x4\n" + "ldr s25, [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v9.h }[2], [x12]\n" + "ld1 { v13.h }[2], [x24]\n" + "ld1 { v17.h }[2], [x23]\n" + "ld1 { v21.h }[2], [x22]\n" + "ld1 { v25.h }[2], [x21]\n" + "b 220f\n" + "215:" // Height 5: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 220f\n" + "ldr h9, [x12, #0x0]\n" + "ldr h13, [x24, #0x0]\n" + "ldr h17, [x23, #0x0]\n" + "ldr h21, [x22, #0x0]\n" + "ldr h25, [x21, #0x0]\n" + "b 220f\n" + "216:" // Height 5: Partial accumulate: partial_4_0 + "tbz x13, #2, 218f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d24, [x21], #0x8\n" + "tbz x13, #1, 217f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "ld1 { v12.s }[2], [x24], #0x4\n" + "mov x19, #0xc\n" + "ld1 { v16.s }[2], [x23], #0x4\n" + "ld1 { v20.s }[2], [x22], #0x4\n" + "ld1 { v24.s }[2], [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v8.h }[6], [x12]\n" 
+ "ld1 { v12.h }[6], [x24]\n" + "ld1 { v16.h }[6], [x23]\n" + "ld1 { v20.h }[6], [x22]\n" + "ld1 { v24.h }[6], [x21]\n" + "b 220f\n" + "217:" // Height 5: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 220f\n" + "ld1 { v8.h }[4], [x12]\n" + "ld1 { v12.h }[4], [x24]\n" + "ld1 { v16.h }[4], [x23]\n" + "ld1 { v20.h }[4], [x22]\n" + "ld1 { v24.h }[4], [x21]\n" + "b 220f\n" + "218:" // Height 5: Partial accumulate: partial_2_0 + "tbz x13, #1, 219f\n" + "ldr s8, [x12], #0x4\n" + "ldr s12, [x24], #0x4\n" + "mov x19, #0x4\n" + "ldr s16, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s24, [x21], #0x4\n" + "tbz x13, #0, 220f\n" + "ld1 { v8.h }[2], [x12]\n" + "ld1 { v12.h }[2], [x24]\n" + "ld1 { v16.h }[2], [x23]\n" + "ld1 { v20.h }[2], [x22]\n" + "ld1 { v24.h }[2], [x21]\n" + "b 220f\n" + "219:" // Height 5: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "ldr h12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr h16, [x23, #0x0]\n" + "ldr h20, [x22, #0x0]\n" + "ldr h24, [x21, #0x0]\n" + "220:" // Height 5: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 223f\n" + "221:" // Height 5: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q24, [x21, #0x0]\n" + "ldr q25, [x21, #0x10]\n" + "ldr q26, [x21, #0x20]\n" + "ldr q27, [x21, #0x30]\n" + "b 223f\n" + "222:" // Height 5: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "223:" // Height 5: setup done + "mov x27, #0x0\n" + "224:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 225f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 226f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "b 226f\n" + "225:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "226:" // Height 5: input setup done + "cmp x26, #0x8\n" + "blt 229f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 228f\n" + "227:" // Height 5: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" 
+ "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "add x23, x23, #0x10\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "fmla v24.8h, v6.8h, v4.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "fmla v25.8h, v7.8h, v4.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "fmla v26.8h, v6.8h, v4.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "fmla v27.8h, v7.8h, v4.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "fmla v24.8h, v6.8h, v4.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "fmla v25.8h, v7.8h, v4.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "fmla v26.8h, v6.8h, v4.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "fmla v27.8h, v7.8h, v4.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "fmla v24.8h, v6.8h, v4.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "fmla v25.8h, v7.8h, v4.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "fmla v26.8h, v6.8h, v4.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "fmla v27.8h, v7.8h, v4.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "fmla v24.8h, v6.8h, v4.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla 
v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "fmla v25.8h, v7.8h, v4.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "fmla v26.8h, v6.8h, v4.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "fmla v27.8h, v7.8h, v4.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "fmla v24.8h, v6.8h, v4.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "fmla v25.8h, v7.8h, v4.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "fmla v26.8h, v6.8h, v4.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "fmla v27.8h, v7.8h, v4.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "fmla v24.8h, v6.8h, v4.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "fmla v25.8h, v7.8h, v4.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "fmla v26.8h, v6.8h, v4.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "add x11, x11, #0x80\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "fmla v27.8h, v7.8h, v4.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "fmla v24.8h, v6.8h, v4.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "add x9, x9, #0x80\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "fmla v25.8h, v7.8h, v4.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "fmla v26.8h, v6.8h, v4.h[7]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "ldr q3, [x22, #0x0]\n" + "fmla v27.8h, v7.8h, v4.h[7]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 227b\n" + "228:" // Height 5: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "add x25, x25, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "add x22, x22, #0x10\n" + "fmla v13.8h, v7.8h, 
v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "add x21, x21, #0x10\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "fmla v24.8h, v6.8h, v4.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "fmla v25.8h, v7.8h, v4.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "fmla v26.8h, v6.8h, v4.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "fmla v27.8h, v7.8h, v4.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "fmla v24.8h, v6.8h, v4.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "fmla v25.8h, v7.8h, v4.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "fmla v26.8h, v6.8h, v4.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "fmla v27.8h, v7.8h, v4.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "fmla v24.8h, v6.8h, v4.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "fmla v25.8h, v7.8h, v4.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "fmla v26.8h, v6.8h, v4.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "fmla v27.8h, v7.8h, v4.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "fmla v24.8h, v6.8h, v4.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "fmla v25.8h, v7.8h, v4.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "fmla v26.8h, v6.8h, v4.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, 
v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "fmla v27.8h, v7.8h, v4.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "fmla v24.8h, v6.8h, v4.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "fmla v25.8h, v7.8h, v4.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "fmla v26.8h, v6.8h, v4.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "fmla v27.8h, v7.8h, v4.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "fmla v24.8h, v6.8h, v4.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "fmla v25.8h, v7.8h, v4.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "fmla v26.8h, v6.8h, v4.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "add x11, x11, #0x80\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "fmla v27.8h, v7.8h, v4.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "fmla v24.8h, v6.8h, v4.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "add x9, x9, #0x80\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "fmla v25.8h, v7.8h, v4.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "fmla v26.8h, v6.8h, v4.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v7.8h, v4.h[7]\n" + "229:" // Height 5: Multiply loop: Main loop skip + "cbz x26, 231f\n" + "230:" // Height 5: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr h1, [x24], #0x2\n" + "sub x26, x26, #0x1\n" + "ldr h2, [x23], #0x2\n" + "ldr h3, [x22], #0x2\n" + "ldr h4, [x21], #0x2\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x11, x11, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "add x10, x10, #0x10\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "add x9, x9, #0x10\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "fmla 
v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "cbnz x26, 230b\n" + "231:" // Height 5: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 224b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "tbz %x[flags], #1, 232f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmin v12.8h, v12.8h, v1.8h\n" + "fmin v13.8h, v13.8h, v1.8h\n" + "fmin v14.8h, v14.8h, v1.8h\n" + "fmin v15.8h, v15.8h, v1.8h\n" + "fmin v16.8h, v16.8h, v1.8h\n" + "fmin v17.8h, v17.8h, v1.8h\n" + "fmin v18.8h, v18.8h, v1.8h\n" + "fmin v19.8h, v19.8h, v1.8h\n" + "fmin v20.8h, v20.8h, v1.8h\n" + "fmin v21.8h, v21.8h, v1.8h\n" + "fmin v22.8h, v22.8h, v1.8h\n" + "fmin v23.8h, v23.8h, v1.8h\n" + "fmin v24.8h, v24.8h, v1.8h\n" + "fmin v25.8h, v25.8h, v1.8h\n" + "fmin v26.8h, v26.8h, v1.8h\n" + "fmin v27.8h, v27.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + "fmax v11.8h, v11.8h, v0.8h\n" + "fmax v12.8h, v12.8h, v0.8h\n" + "fmax v13.8h, v13.8h, v0.8h\n" + "fmax v14.8h, v14.8h, v0.8h\n" + "fmax v15.8h, v15.8h, v0.8h\n" + "fmax v16.8h, v16.8h, v0.8h\n" + "fmax v17.8h, v17.8h, v0.8h\n" + "fmax v18.8h, v18.8h, v0.8h\n" + "fmax v19.8h, v19.8h, v0.8h\n" + "fmax v20.8h, v20.8h, v0.8h\n" + "fmax v21.8h, v21.8h, v0.8h\n" + "fmax v22.8h, v22.8h, v0.8h\n" + "fmax v23.8h, v23.8h, v0.8h\n" + "fmax v24.8h, v24.8h, v0.8h\n" + "fmax v25.8h, v25.8h, v0.8h\n" + "fmax v26.8h, v26.8h, v0.8h\n" + "fmax v27.8h, v27.8h, v0.8h\n" + "232:" // Height 5: No activation + "cmp x13, #0x20\n" + "bge 249f\n" + "tbz x13, #4, 240f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v13.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v17.8h }, [x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "st1 { v21.8h }, [x22], #0x10\n" + "st1 { v24.8h }, [x21], #0x10\n" + "st1 { v25.8h }, [x21], #0x10\n" + "tbz x13, #3, 236f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "st1 { v14.8h }, [x24], #0x10\n" + "st1 { v18.8h }, [x23], #0x10\n" + "st1 { v22.8h }, [x22], #0x10\n" + "st1 { v26.8h }, [x21], #0x10\n" + "tbz x13, #2, 234f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "tbz x13, #1, 233f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "st1 { v15.s }[2], [x24], #0x4\n" + "st1 { v19.s }[2], [x23], #0x4\n" + "st1 { v23.s }[2], [x22], #0x4\n" + "st1 { v27.s }[2], [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v11.h }[6], [x12]\n" + "st1 { v15.h }[6], [x24]\n" + "st1 { v19.h }[6], [x23]\n" + "st1 { v23.h }[6], [x22]\n" + "st1 { v27.h }[6], [x21]\n" + "b 248f\n" + "233:" // Height 5: Partial direct writeback: partial_1_28 + "tbz x13, #0, 248f\n" + "st1 { v11.h }[4], [x12]\n" + "st1 { v15.h }[4], [x24]\n" + "st1 { v19.h }[4], [x23]\n" + "st1 { v23.h }[4], [x22]\n" + "st1 { v27.h }[4], [x21]\n" + "b 248f\n" + "234:" // Height 5: Partial direct writeback: partial_2_24 + "tbz x13, #1, 235f\n" + "str 
s11, [x12], #0x4\n" + "str s15, [x24], #0x4\n" + "str s19, [x23], #0x4\n" + "str s23, [x22], #0x4\n" + "str s27, [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v11.h }[2], [x12]\n" + "st1 { v15.h }[2], [x24]\n" + "st1 { v19.h }[2], [x23]\n" + "st1 { v23.h }[2], [x22]\n" + "st1 { v27.h }[2], [x21]\n" + "b 248f\n" + "235:" // Height 5: Partial direct writeback: partial_1_24 + "tbz x13, #0, 248f\n" + "str h11, [x12, #0x0]\n" + "str h15, [x24, #0x0]\n" + "str h19, [x23, #0x0]\n" + "str h23, [x22, #0x0]\n" + "str h27, [x21, #0x0]\n" + "b 248f\n" + "236:" // Height 5: Partial direct writeback: partial_4_16 + "tbz x13, #2, 238f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "tbz x13, #1, 237f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "st1 { v14.s }[2], [x24], #0x4\n" + "st1 { v18.s }[2], [x23], #0x4\n" + "st1 { v22.s }[2], [x22], #0x4\n" + "st1 { v26.s }[2], [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v10.h }[6], [x12]\n" + "st1 { v14.h }[6], [x24]\n" + "st1 { v18.h }[6], [x23]\n" + "st1 { v22.h }[6], [x22]\n" + "st1 { v26.h }[6], [x21]\n" + "b 248f\n" + "237:" // Height 5: Partial direct writeback: partial_1_20 + "tbz x13, #0, 248f\n" + "st1 { v10.h }[4], [x12]\n" + "st1 { v14.h }[4], [x24]\n" + "st1 { v18.h }[4], [x23]\n" + "st1 { v22.h }[4], [x22]\n" + "st1 { v26.h }[4], [x21]\n" + "b 248f\n" + "238:" // Height 5: Partial direct writeback: partial_2_16 + "tbz x13, #1, 239f\n" + "str s10, [x12], #0x4\n" + "str s14, [x24], #0x4\n" + "str s18, [x23], #0x4\n" + "str s22, [x22], #0x4\n" + "str s26, [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v10.h }[2], [x12]\n" + "st1 { v14.h }[2], [x24]\n" + "st1 { v18.h }[2], [x23]\n" + "st1 { v22.h }[2], [x22]\n" + "st1 { v26.h }[2], [x21]\n" + "b 248f\n" + "239:" // Height 5: Partial direct writeback: partial_1_16 + "tbz x13, #0, 248f\n" + "str h10, [x12, #0x0]\n" + "str h14, [x24, #0x0]\n" + "str h18, [x23, #0x0]\n" + "str h22, [x22, #0x0]\n" + "str h26, [x21, #0x0]\n" + "b 248f\n" + "240:" // Height 5: Partial direct writeback: partial_8_0 + "tbz x13, #3, 244f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "st1 { v24.8h }, [x21], #0x10\n" + "tbz x13, #2, 242f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "tbz x13, #1, 241f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "st1 { v13.s }[2], [x24], #0x4\n" + "st1 { v17.s }[2], [x23], #0x4\n" + "st1 { v21.s }[2], [x22], #0x4\n" + "st1 { v25.s }[2], [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v9.h }[6], [x12]\n" + "st1 { v13.h }[6], [x24]\n" + "st1 { v17.h }[6], [x23]\n" + "st1 { v21.h }[6], [x22]\n" + "st1 { v25.h }[6], [x21]\n" + "b 248f\n" + "241:" // Height 5: Partial direct writeback: partial_1_12 + "tbz x13, #0, 248f\n" + "st1 { v9.h }[4], [x12]\n" + "st1 { v13.h }[4], [x24]\n" + "st1 { v17.h }[4], [x23]\n" + "st1 { v21.h }[4], [x22]\n" + "st1 { v25.h }[4], [x21]\n" + "b 248f\n" + "242:" // Height 5: Partial direct writeback: partial_2_8 + "tbz x13, #1, 243f\n" + "str s9, [x12], #0x4\n" + "str s13, [x24], #0x4\n" + "str s17, [x23], #0x4\n" + "str s21, [x22], #0x4\n" + "str s25, [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v9.h }[2], [x12]\n" + "st1 { v13.h }[2], [x24]\n" + "st1 { v17.h }[2], [x23]\n" + "st1 { v21.h }[2], [x22]\n" + "st1 { v25.h }[2], [x21]\n" + "b 248f\n" + "243:" // Height 5: Partial direct writeback: 
partial_1_8 + "tbz x13, #0, 248f\n" + "str h9, [x12, #0x0]\n" + "str h13, [x24, #0x0]\n" + "str h17, [x23, #0x0]\n" + "str h21, [x22, #0x0]\n" + "str h25, [x21, #0x0]\n" + "b 248f\n" + "244:" // Height 5: Partial direct writeback: partial_4_0 + "tbz x13, #2, 246f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "tbz x13, #1, 245f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "st1 { v12.s }[2], [x24], #0x4\n" + "st1 { v16.s }[2], [x23], #0x4\n" + "st1 { v20.s }[2], [x22], #0x4\n" + "st1 { v24.s }[2], [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v8.h }[6], [x12]\n" + "st1 { v12.h }[6], [x24]\n" + "st1 { v16.h }[6], [x23]\n" + "st1 { v20.h }[6], [x22]\n" + "st1 { v24.h }[6], [x21]\n" + "b 248f\n" + "245:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x13, #0, 248f\n" + "st1 { v8.h }[4], [x12]\n" + "st1 { v12.h }[4], [x24]\n" + "st1 { v16.h }[4], [x23]\n" + "st1 { v20.h }[4], [x22]\n" + "st1 { v24.h }[4], [x21]\n" + "b 248f\n" + "246:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x13, #1, 247f\n" + "str s8, [x12], #0x4\n" + "str s12, [x24], #0x4\n" + "str s16, [x23], #0x4\n" + "str s20, [x22], #0x4\n" + "str s24, [x21], #0x4\n" + "tbz x13, #0, 248f\n" + "st1 { v8.h }[2], [x12]\n" + "st1 { v12.h }[2], [x24]\n" + "st1 { v16.h }[2], [x23]\n" + "st1 { v20.h }[2], [x22]\n" + "st1 { v24.h }[2], [x21]\n" + "b 248f\n" + "247:" // Height 5: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "str h12, [x24, #0x0]\n" + "str h16, [x23, #0x0]\n" + "str h20, [x22, #0x0]\n" + "str h24, [x21, #0x0]\n" + "248:" // Height 5: Partial direct writeback: Done + "b 250f\n" + "249:" // Height 5: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "str q24, [x21, #0x0]\n" + "str q25, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "250:" // Height 5: Writeback done + "subs x13, x13, #0x20\n" + "bgt 202b\n" + "b 302f\n" + "251:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0xc\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "252:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x18\n" + "bgt 253f\n" + "cmp x13, #0x10\n" + "mov x28, x11\n" + "bgt 253f\n" + "cmp x13, #0x8\n" + "mov x9, x11\n" + "bgt 253f\n" + "mov x10, x11\n" + "253:" // Height 6: B setup done + "cbz x14, 254f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, 
v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "mov v24.16b, v8.16b\n" + "mov v25.16b, v9.16b\n" + "mov v26.16b, v10.16b\n" + "mov v27.16b, v11.16b\n" + "mov v28.16b, v8.16b\n" + "mov v29.16b, v9.16b\n" + "mov v30.16b, v10.16b\n" + "mov v31.16b, v11.16b\n" + "b 273f\n" + "254:" // Height 6: no bias + "tbz %x[flags], #0, 272f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "cmp x13, #0x20\n" + "add x20, x21, x19, LSL #1\n" + "bge 271f\n" + "tbz x13, #4, 262f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "ld1 { v24.8h }, [x21], #0x10\n" + "ld1 { v28.8h }, [x20], #0x10\n" + "ld1 { v9.8h }, [x12], #0x10\n" + "ld1 { v13.8h }, [x24], #0x10\n" + "ld1 { v17.8h }, [x23], #0x10\n" + "ld1 { v21.8h }, [x22], #0x10\n" + "ld1 { v25.8h }, [x21], #0x10\n" + "ld1 { v29.8h }, [x20], #0x10\n" + "tbz x13, #3, 258f\n" + "ld1 { v10.8h }, [x12], #0x10\n" + "ld1 { v14.8h }, [x24], #0x10\n" + "ld1 { v18.8h }, [x23], #0x10\n" + "ld1 { v22.8h }, [x22], #0x10\n" + "ld1 { v26.8h }, [x21], #0x10\n" + "ld1 { v30.8h }, [x20], #0x10\n" + "tbz x13, #2, 256f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "ldr d31, [x20], #0x8\n" + "tbz x13, #1, 255f\n" + "ld1 { v11.s }[2], [x12], #0x4\n" + "ld1 { v15.s }[2], [x24], #0x4\n" + "mov x19, #0x3c\n" + "ld1 { v19.s }[2], [x23], #0x4\n" + "ld1 { v23.s }[2], [x22], #0x4\n" + "ld1 { v27.s }[2], [x21], #0x4\n" + "ld1 { v31.s }[2], [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v11.h }[6], [x12]\n" + "ld1 { v15.h }[6], [x24]\n" + "ld1 { v19.h }[6], [x23]\n" + "ld1 { v23.h }[6], [x22]\n" + "ld1 { v27.h }[6], [x21]\n" + "ld1 { v31.h }[6], [x20]\n" + "b 270f\n" + "255:" // Height 6: Partial accumulate: partial_1_28 + "mov x19, #0x38\n" + "tbz x13, #0, 270f\n" + "ld1 { v11.h }[4], [x12]\n" + "ld1 { v15.h }[4], [x24]\n" + "ld1 { v19.h }[4], [x23]\n" + "ld1 { v23.h }[4], [x22]\n" + "ld1 { v27.h }[4], [x21]\n" + "ld1 { v31.h }[4], [x20]\n" + "b 270f\n" + "256:" // Height 6: Partial accumulate: partial_2_24 + "tbz x13, #1, 257f\n" + "ldr s11, [x12], #0x4\n" + "ldr s15, [x24], #0x4\n" + "mov x19, #0x34\n" + "ldr s19, [x23], #0x4\n" + "ldr s23, [x22], #0x4\n" + "ldr s27, [x21], #0x4\n" + "ldr s31, [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v11.h }[2], [x12]\n" + "ld1 { v15.h }[2], [x24]\n" + "ld1 { v19.h }[2], [x23]\n" + "ld1 { v23.h }[2], [x22]\n" + "ld1 { v27.h }[2], [x21]\n" + "ld1 { v31.h }[2], [x20]\n" + "b 270f\n" + "257:" // Height 6: Partial accumulate: partial_1_24 + "mov x19, #0x30\n" + "tbz x13, #0, 270f\n" + "ldr h11, [x12, #0x0]\n" + "ldr h15, [x24, #0x0]\n" + "ldr h19, [x23, #0x0]\n" + "ldr h23, [x22, #0x0]\n" + "ldr h27, [x21, #0x0]\n" + "ldr h31, [x20, #0x0]\n" + "b 270f\n" + "258:" // Height 6: Partial accumulate: partial_4_16 + "tbz x13, #2, 260f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "ldr d30, [x20], #0x8\n" + "tbz x13, #1, 259f\n" + "ld1 { v10.s }[2], [x12], #0x4\n" + "ld1 { v14.s }[2], [x24], #0x4\n" + "mov x19, #0x2c\n" + "ld1 { v18.s }[2], [x23], #0x4\n" + "ld1 { v22.s }[2], [x22], #0x4\n" + 
"ld1 { v26.s }[2], [x21], #0x4\n" + "ld1 { v30.s }[2], [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v10.h }[6], [x12]\n" + "ld1 { v14.h }[6], [x24]\n" + "ld1 { v18.h }[6], [x23]\n" + "ld1 { v22.h }[6], [x22]\n" + "ld1 { v26.h }[6], [x21]\n" + "ld1 { v30.h }[6], [x20]\n" + "b 270f\n" + "259:" // Height 6: Partial accumulate: partial_1_20 + "mov x19, #0x28\n" + "tbz x13, #0, 270f\n" + "ld1 { v10.h }[4], [x12]\n" + "ld1 { v14.h }[4], [x24]\n" + "ld1 { v18.h }[4], [x23]\n" + "ld1 { v22.h }[4], [x22]\n" + "ld1 { v26.h }[4], [x21]\n" + "ld1 { v30.h }[4], [x20]\n" + "b 270f\n" + "260:" // Height 6: Partial accumulate: partial_2_16 + "tbz x13, #1, 261f\n" + "ldr s10, [x12], #0x4\n" + "ldr s14, [x24], #0x4\n" + "mov x19, #0x24\n" + "ldr s18, [x23], #0x4\n" + "ldr s22, [x22], #0x4\n" + "ldr s26, [x21], #0x4\n" + "ldr s30, [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v10.h }[2], [x12]\n" + "ld1 { v14.h }[2], [x24]\n" + "ld1 { v18.h }[2], [x23]\n" + "ld1 { v22.h }[2], [x22]\n" + "ld1 { v26.h }[2], [x21]\n" + "ld1 { v30.h }[2], [x20]\n" + "b 270f\n" + "261:" // Height 6: Partial accumulate: partial_1_16 + "mov x19, #0x20\n" + "tbz x13, #0, 270f\n" + "ldr h10, [x12, #0x0]\n" + "ldr h14, [x24, #0x0]\n" + "ldr h18, [x23, #0x0]\n" + "ldr h22, [x22, #0x0]\n" + "ldr h26, [x21, #0x0]\n" + "ldr h30, [x20, #0x0]\n" + "b 270f\n" + "262:" // Height 6: Partial accumulate: partial_8_0 + "tbz x13, #3, 266f\n" + "ld1 { v8.8h }, [x12], #0x10\n" + "ld1 { v12.8h }, [x24], #0x10\n" + "ld1 { v16.8h }, [x23], #0x10\n" + "ld1 { v20.8h }, [x22], #0x10\n" + "ld1 { v24.8h }, [x21], #0x10\n" + "ld1 { v28.8h }, [x20], #0x10\n" + "tbz x13, #2, 264f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "ldr d29, [x20], #0x8\n" + "tbz x13, #1, 263f\n" + "ld1 { v9.s }[2], [x12], #0x4\n" + "ld1 { v13.s }[2], [x24], #0x4\n" + "mov x19, #0x1c\n" + "ld1 { v17.s }[2], [x23], #0x4\n" + "ld1 { v21.s }[2], [x22], #0x4\n" + "ld1 { v25.s }[2], [x21], #0x4\n" + "ld1 { v29.s }[2], [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v9.h }[6], [x12]\n" + "ld1 { v13.h }[6], [x24]\n" + "ld1 { v17.h }[6], [x23]\n" + "ld1 { v21.h }[6], [x22]\n" + "ld1 { v25.h }[6], [x21]\n" + "ld1 { v29.h }[6], [x20]\n" + "b 270f\n" + "263:" // Height 6: Partial accumulate: partial_1_12 + "mov x19, #0x18\n" + "tbz x13, #0, 270f\n" + "ld1 { v9.h }[4], [x12]\n" + "ld1 { v13.h }[4], [x24]\n" + "ld1 { v17.h }[4], [x23]\n" + "ld1 { v21.h }[4], [x22]\n" + "ld1 { v25.h }[4], [x21]\n" + "ld1 { v29.h }[4], [x20]\n" + "b 270f\n" + "264:" // Height 6: Partial accumulate: partial_2_8 + "tbz x13, #1, 265f\n" + "ldr s9, [x12], #0x4\n" + "ldr s13, [x24], #0x4\n" + "mov x19, #0x14\n" + "ldr s17, [x23], #0x4\n" + "ldr s21, [x22], #0x4\n" + "ldr s25, [x21], #0x4\n" + "ldr s29, [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v9.h }[2], [x12]\n" + "ld1 { v13.h }[2], [x24]\n" + "ld1 { v17.h }[2], [x23]\n" + "ld1 { v21.h }[2], [x22]\n" + "ld1 { v25.h }[2], [x21]\n" + "ld1 { v29.h }[2], [x20]\n" + "b 270f\n" + "265:" // Height 6: Partial accumulate: partial_1_8 + "mov x19, #0x10\n" + "tbz x13, #0, 270f\n" + "ldr h9, [x12, #0x0]\n" + "ldr h13, [x24, #0x0]\n" + "ldr h17, [x23, #0x0]\n" + "ldr h21, [x22, #0x0]\n" + "ldr h25, [x21, #0x0]\n" + "ldr h29, [x20, #0x0]\n" + "b 270f\n" + "266:" // Height 6: Partial accumulate: partial_4_0 + "tbz x13, #2, 268f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d24, [x21], #0x8\n" + "ldr d28, [x20], 
#0x8\n" + "tbz x13, #1, 267f\n" + "ld1 { v8.s }[2], [x12], #0x4\n" + "ld1 { v12.s }[2], [x24], #0x4\n" + "mov x19, #0xc\n" + "ld1 { v16.s }[2], [x23], #0x4\n" + "ld1 { v20.s }[2], [x22], #0x4\n" + "ld1 { v24.s }[2], [x21], #0x4\n" + "ld1 { v28.s }[2], [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v8.h }[6], [x12]\n" + "ld1 { v12.h }[6], [x24]\n" + "ld1 { v16.h }[6], [x23]\n" + "ld1 { v20.h }[6], [x22]\n" + "ld1 { v24.h }[6], [x21]\n" + "ld1 { v28.h }[6], [x20]\n" + "b 270f\n" + "267:" // Height 6: Partial accumulate: partial_1_4 + "mov x19, #0x8\n" + "tbz x13, #0, 270f\n" + "ld1 { v8.h }[4], [x12]\n" + "ld1 { v12.h }[4], [x24]\n" + "ld1 { v16.h }[4], [x23]\n" + "ld1 { v20.h }[4], [x22]\n" + "ld1 { v24.h }[4], [x21]\n" + "ld1 { v28.h }[4], [x20]\n" + "b 270f\n" + "268:" // Height 6: Partial accumulate: partial_2_0 + "tbz x13, #1, 269f\n" + "ldr s8, [x12], #0x4\n" + "ldr s12, [x24], #0x4\n" + "mov x19, #0x4\n" + "ldr s16, [x23], #0x4\n" + "ldr s20, [x22], #0x4\n" + "ldr s24, [x21], #0x4\n" + "ldr s28, [x20], #0x4\n" + "tbz x13, #0, 270f\n" + "ld1 { v8.h }[2], [x12]\n" + "ld1 { v12.h }[2], [x24]\n" + "ld1 { v16.h }[2], [x23]\n" + "ld1 { v20.h }[2], [x22]\n" + "ld1 { v24.h }[2], [x21]\n" + "ld1 { v28.h }[2], [x20]\n" + "b 270f\n" + "269:" // Height 6: Partial accumulate: partial_1_0 + "ldr h8, [x12, #0x0]\n" + "ldr h12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr h16, [x23, #0x0]\n" + "ldr h20, [x22, #0x0]\n" + "ldr h24, [x21, #0x0]\n" + "ldr h28, [x20, #0x0]\n" + "270:" // Height 6: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 273f\n" + "271:" // Height 6: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q24, [x21, #0x0]\n" + "ldr q25, [x21, #0x10]\n" + "ldr q26, [x21, #0x20]\n" + "ldr q27, [x21, #0x30]\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "ldr q30, [x20, #0x20]\n" + "ldr q31, [x20, #0x30]\n" + "b 273f\n" + "272:" // Height 6: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "273:" // Height 6: setup done + "mov x27, #0x0\n" + "274:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 275f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 276f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, 
x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "add x20, x20, x19, LSL #1\n" + "b 276f\n" + "275:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "276:" // Height 6: input setup done + "cmp x26, #0x8\n" + "blt 279f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x10\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q5, [x20, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 278f\n" + "277:" // Height 6: Multiply loop: Main loop head + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "fmla v28.8h, v6.8h, v5.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x23, x23, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "add x20, x20, #0x10\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "fmla v29.8h, v7.8h, v5.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "fmla v30.8h, v6.8h, v5.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "fmla v31.8h, v7.8h, v5.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "fmla v24.8h, v6.8h, v4.h[1]\n" + "fmla v28.8h, v6.8h, v5.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "fmla v25.8h, v7.8h, v4.h[1]\n" + "fmla v29.8h, v7.8h, v5.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "fmla v26.8h, v6.8h, v4.h[1]\n" + "fmla v30.8h, v6.8h, v5.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "fmla v27.8h, v7.8h, v4.h[1]\n" + "fmla v31.8h, v7.8h, v5.h[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "fmla v24.8h, v6.8h, v4.h[2]\n" + "fmla v28.8h, v6.8h, v5.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "fmla v25.8h, v7.8h, v4.h[2]\n" + "fmla v29.8h, v7.8h, v5.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "fmla v26.8h, v6.8h, v4.h[2]\n" + "fmla v30.8h, v6.8h, v5.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, 
v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "fmla v27.8h, v7.8h, v4.h[2]\n" + "fmla v31.8h, v7.8h, v5.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "fmla v24.8h, v6.8h, v4.h[3]\n" + "fmla v28.8h, v6.8h, v5.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "fmla v25.8h, v7.8h, v4.h[3]\n" + "fmla v29.8h, v7.8h, v5.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "fmla v26.8h, v6.8h, v4.h[3]\n" + "fmla v30.8h, v6.8h, v5.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "fmla v27.8h, v7.8h, v4.h[3]\n" + "fmla v31.8h, v7.8h, v5.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "fmla v24.8h, v6.8h, v4.h[4]\n" + "fmla v28.8h, v6.8h, v5.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "fmla v25.8h, v7.8h, v4.h[4]\n" + "fmla v29.8h, v7.8h, v5.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "fmla v26.8h, v6.8h, v4.h[4]\n" + "fmla v30.8h, v6.8h, v5.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "fmla v27.8h, v7.8h, v4.h[4]\n" + "fmla v31.8h, v7.8h, v5.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "fmla v24.8h, v6.8h, v4.h[5]\n" + "fmla v28.8h, v6.8h, v5.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "fmla v25.8h, v7.8h, v4.h[5]\n" + "fmla v29.8h, v7.8h, v5.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "fmla v26.8h, v6.8h, v4.h[5]\n" + "fmla v30.8h, v6.8h, v5.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "fmla v27.8h, v7.8h, v4.h[5]\n" + "fmla v31.8h, v7.8h, v5.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "fmla v24.8h, v6.8h, v4.h[6]\n" + "fmla v28.8h, v6.8h, v5.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "fmla v25.8h, v7.8h, v4.h[6]\n" + "fmla v29.8h, v7.8h, v5.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "fmla v26.8h, v6.8h, 
v4.h[6]\n" + "fmla v30.8h, v6.8h, v5.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "fmla v27.8h, v7.8h, v4.h[6]\n" + "fmla v31.8h, v7.8h, v5.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "fmla v24.8h, v6.8h, v4.h[7]\n" + "fmla v28.8h, v6.8h, v5.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "fmla v25.8h, v7.8h, v4.h[7]\n" + "fmla v29.8h, v7.8h, v5.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "fmla v26.8h, v6.8h, v4.h[7]\n" + "fmla v30.8h, v6.8h, v5.h[7]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "ldr q3, [x22, #0x0]\n" + "fmla v27.8h, v7.8h, v4.h[7]\n" + "ldr q4, [x21, #0x0]\n" + "fmla v31.8h, v7.8h, v5.h[7]\n" + "ldr q5, [x20, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 277b\n" + "278:" // Height 6: Multiply loop: Single iteration only + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "sub x26, x26, #0x8\n" + "add x25, x25, #0x10\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "fmla v28.8h, v6.8h, v5.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x22, x22, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "fmla v29.8h, v7.8h, v5.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "fmla v30.8h, v6.8h, v5.h[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "fmla v31.8h, v7.8h, v5.h[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.8h, v6.8h, v0.h[1]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla v16.8h, v6.8h, v2.h[1]\n" + "fmla v20.8h, v6.8h, v3.h[1]\n" + "fmla v24.8h, v6.8h, v4.h[1]\n" + "fmla v28.8h, v6.8h, v5.h[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.8h, v7.8h, v0.h[1]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v17.8h, v7.8h, v2.h[1]\n" + "fmla v21.8h, v7.8h, v3.h[1]\n" + "fmla v25.8h, v7.8h, v4.h[1]\n" + "fmla v29.8h, v7.8h, v5.h[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.8h, v6.8h, v0.h[1]\n" + "fmla v14.8h, v6.8h, v1.h[1]\n" + "fmla v18.8h, v6.8h, v2.h[1]\n" + "fmla v22.8h, v6.8h, v3.h[1]\n" + "fmla v26.8h, v6.8h, v4.h[1]\n" + "fmla v30.8h, v6.8h, v5.h[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.8h, v7.8h, v0.h[1]\n" + "fmla v15.8h, v7.8h, v1.h[1]\n" + "fmla v19.8h, v7.8h, v2.h[1]\n" + "fmla v23.8h, v7.8h, v3.h[1]\n" + "fmla v27.8h, v7.8h, v4.h[1]\n" + "fmla v31.8h, v7.8h, v5.h[1]\n" + "ldr 
q7, [x10, #0x20]\n" + "fmla v8.8h, v6.8h, v0.h[2]\n" + "fmla v12.8h, v6.8h, v1.h[2]\n" + "fmla v16.8h, v6.8h, v2.h[2]\n" + "fmla v20.8h, v6.8h, v3.h[2]\n" + "fmla v24.8h, v6.8h, v4.h[2]\n" + "fmla v28.8h, v6.8h, v5.h[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.8h, v7.8h, v0.h[2]\n" + "fmla v13.8h, v7.8h, v1.h[2]\n" + "fmla v17.8h, v7.8h, v2.h[2]\n" + "fmla v21.8h, v7.8h, v3.h[2]\n" + "fmla v25.8h, v7.8h, v4.h[2]\n" + "fmla v29.8h, v7.8h, v5.h[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.8h, v6.8h, v0.h[2]\n" + "fmla v14.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v2.h[2]\n" + "fmla v22.8h, v6.8h, v3.h[2]\n" + "fmla v26.8h, v6.8h, v4.h[2]\n" + "fmla v30.8h, v6.8h, v5.h[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.8h, v7.8h, v0.h[2]\n" + "fmla v15.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v2.h[2]\n" + "fmla v23.8h, v7.8h, v3.h[2]\n" + "fmla v27.8h, v7.8h, v4.h[2]\n" + "fmla v31.8h, v7.8h, v5.h[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.8h, v6.8h, v0.h[3]\n" + "fmla v12.8h, v6.8h, v1.h[3]\n" + "fmla v16.8h, v6.8h, v2.h[3]\n" + "fmla v20.8h, v6.8h, v3.h[3]\n" + "fmla v24.8h, v6.8h, v4.h[3]\n" + "fmla v28.8h, v6.8h, v5.h[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.8h, v7.8h, v0.h[3]\n" + "fmla v13.8h, v7.8h, v1.h[3]\n" + "fmla v17.8h, v7.8h, v2.h[3]\n" + "fmla v21.8h, v7.8h, v3.h[3]\n" + "fmla v25.8h, v7.8h, v4.h[3]\n" + "fmla v29.8h, v7.8h, v5.h[3]\n" + "ldr q7, [x28, #0x30]\n" + "fmla v10.8h, v6.8h, v0.h[3]\n" + "fmla v14.8h, v6.8h, v1.h[3]\n" + "fmla v18.8h, v6.8h, v2.h[3]\n" + "fmla v22.8h, v6.8h, v3.h[3]\n" + "fmla v26.8h, v6.8h, v4.h[3]\n" + "fmla v30.8h, v6.8h, v5.h[3]\n" + "ldr q6, [x11, #0x40]\n" + "fmla v11.8h, v7.8h, v0.h[3]\n" + "fmla v15.8h, v7.8h, v1.h[3]\n" + "fmla v19.8h, v7.8h, v2.h[3]\n" + "fmla v23.8h, v7.8h, v3.h[3]\n" + "fmla v27.8h, v7.8h, v4.h[3]\n" + "fmla v31.8h, v7.8h, v5.h[3]\n" + "ldr q7, [x10, #0x40]\n" + "fmla v8.8h, v6.8h, v0.h[4]\n" + "fmla v12.8h, v6.8h, v1.h[4]\n" + "fmla v16.8h, v6.8h, v2.h[4]\n" + "fmla v20.8h, v6.8h, v3.h[4]\n" + "fmla v24.8h, v6.8h, v4.h[4]\n" + "fmla v28.8h, v6.8h, v5.h[4]\n" + "ldr q6, [x9, #0x40]\n" + "fmla v9.8h, v7.8h, v0.h[4]\n" + "fmla v13.8h, v7.8h, v1.h[4]\n" + "fmla v17.8h, v7.8h, v2.h[4]\n" + "fmla v21.8h, v7.8h, v3.h[4]\n" + "fmla v25.8h, v7.8h, v4.h[4]\n" + "fmla v29.8h, v7.8h, v5.h[4]\n" + "ldr q7, [x28, #0x40]\n" + "fmla v10.8h, v6.8h, v0.h[4]\n" + "fmla v14.8h, v6.8h, v1.h[4]\n" + "fmla v18.8h, v6.8h, v2.h[4]\n" + "fmla v22.8h, v6.8h, v3.h[4]\n" + "fmla v26.8h, v6.8h, v4.h[4]\n" + "fmla v30.8h, v6.8h, v5.h[4]\n" + "ldr q6, [x11, #0x50]\n" + "fmla v11.8h, v7.8h, v0.h[4]\n" + "fmla v15.8h, v7.8h, v1.h[4]\n" + "fmla v19.8h, v7.8h, v2.h[4]\n" + "fmla v23.8h, v7.8h, v3.h[4]\n" + "fmla v27.8h, v7.8h, v4.h[4]\n" + "fmla v31.8h, v7.8h, v5.h[4]\n" + "ldr q7, [x10, #0x50]\n" + "fmla v8.8h, v6.8h, v0.h[5]\n" + "fmla v12.8h, v6.8h, v1.h[5]\n" + "fmla v16.8h, v6.8h, v2.h[5]\n" + "fmla v20.8h, v6.8h, v3.h[5]\n" + "fmla v24.8h, v6.8h, v4.h[5]\n" + "fmla v28.8h, v6.8h, v5.h[5]\n" + "ldr q6, [x9, #0x50]\n" + "fmla v9.8h, v7.8h, v0.h[5]\n" + "fmla v13.8h, v7.8h, v1.h[5]\n" + "fmla v17.8h, v7.8h, v2.h[5]\n" + "fmla v21.8h, v7.8h, v3.h[5]\n" + "fmla v25.8h, v7.8h, v4.h[5]\n" + "fmla v29.8h, v7.8h, v5.h[5]\n" + "ldr q7, [x28, #0x50]\n" + "fmla v10.8h, v6.8h, v0.h[5]\n" + "fmla v14.8h, v6.8h, v1.h[5]\n" + "fmla v18.8h, v6.8h, v2.h[5]\n" + "fmla v22.8h, v6.8h, v3.h[5]\n" + "fmla v26.8h, v6.8h, v4.h[5]\n" + "fmla v30.8h, v6.8h, v5.h[5]\n" + "ldr q6, [x11, #0x60]\n" + "fmla v11.8h, v7.8h, v0.h[5]\n" + "fmla v15.8h, v7.8h, v1.h[5]\n" + "fmla 
v19.8h, v7.8h, v2.h[5]\n" + "fmla v23.8h, v7.8h, v3.h[5]\n" + "fmla v27.8h, v7.8h, v4.h[5]\n" + "fmla v31.8h, v7.8h, v5.h[5]\n" + "ldr q7, [x10, #0x60]\n" + "fmla v8.8h, v6.8h, v0.h[6]\n" + "fmla v12.8h, v6.8h, v1.h[6]\n" + "fmla v16.8h, v6.8h, v2.h[6]\n" + "fmla v20.8h, v6.8h, v3.h[6]\n" + "fmla v24.8h, v6.8h, v4.h[6]\n" + "fmla v28.8h, v6.8h, v5.h[6]\n" + "ldr q6, [x9, #0x60]\n" + "fmla v9.8h, v7.8h, v0.h[6]\n" + "fmla v13.8h, v7.8h, v1.h[6]\n" + "fmla v17.8h, v7.8h, v2.h[6]\n" + "fmla v21.8h, v7.8h, v3.h[6]\n" + "fmla v25.8h, v7.8h, v4.h[6]\n" + "fmla v29.8h, v7.8h, v5.h[6]\n" + "ldr q7, [x28, #0x60]\n" + "fmla v10.8h, v6.8h, v0.h[6]\n" + "fmla v14.8h, v6.8h, v1.h[6]\n" + "fmla v18.8h, v6.8h, v2.h[6]\n" + "fmla v22.8h, v6.8h, v3.h[6]\n" + "fmla v26.8h, v6.8h, v4.h[6]\n" + "fmla v30.8h, v6.8h, v5.h[6]\n" + "ldr q6, [x11, #0x70]\n" + "add x11, x11, #0x80\n" + "fmla v11.8h, v7.8h, v0.h[6]\n" + "fmla v15.8h, v7.8h, v1.h[6]\n" + "fmla v19.8h, v7.8h, v2.h[6]\n" + "fmla v23.8h, v7.8h, v3.h[6]\n" + "fmla v27.8h, v7.8h, v4.h[6]\n" + "fmla v31.8h, v7.8h, v5.h[6]\n" + "ldr q7, [x10, #0x70]\n" + "add x10, x10, #0x80\n" + "fmla v8.8h, v6.8h, v0.h[7]\n" + "fmla v12.8h, v6.8h, v1.h[7]\n" + "fmla v16.8h, v6.8h, v2.h[7]\n" + "fmla v20.8h, v6.8h, v3.h[7]\n" + "fmla v24.8h, v6.8h, v4.h[7]\n" + "fmla v28.8h, v6.8h, v5.h[7]\n" + "ldr q6, [x9, #0x70]\n" + "add x9, x9, #0x80\n" + "fmla v9.8h, v7.8h, v0.h[7]\n" + "fmla v13.8h, v7.8h, v1.h[7]\n" + "fmla v17.8h, v7.8h, v2.h[7]\n" + "fmla v21.8h, v7.8h, v3.h[7]\n" + "fmla v25.8h, v7.8h, v4.h[7]\n" + "fmla v29.8h, v7.8h, v5.h[7]\n" + "ldr q7, [x28, #0x70]\n" + "add x28, x28, #0x80\n" + "fmla v10.8h, v6.8h, v0.h[7]\n" + "fmla v14.8h, v6.8h, v1.h[7]\n" + "fmla v18.8h, v6.8h, v2.h[7]\n" + "fmla v22.8h, v6.8h, v3.h[7]\n" + "fmla v26.8h, v6.8h, v4.h[7]\n" + "fmla v30.8h, v6.8h, v5.h[7]\n" + "fmla v11.8h, v7.8h, v0.h[7]\n" + "fmla v15.8h, v7.8h, v1.h[7]\n" + "fmla v19.8h, v7.8h, v2.h[7]\n" + "fmla v23.8h, v7.8h, v3.h[7]\n" + "fmla v27.8h, v7.8h, v4.h[7]\n" + "fmla v31.8h, v7.8h, v5.h[7]\n" + "279:" // Height 6: Multiply loop: Main loop skip + "cbz x26, 281f\n" + "280:" // Height 6: Multiply loop: Odd block loop + "ldr h0, [x25], #0x2\n" + "ldr h1, [x24], #0x2\n" + "sub x26, x26, #0x1\n" + "ldr h2, [x23], #0x2\n" + "ldr h3, [x22], #0x2\n" + "ldr h4, [x21], #0x2\n" + "ldr h5, [x20], #0x2\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[0]\n" + "fmla v16.8h, v6.8h, v2.h[0]\n" + "fmla v20.8h, v6.8h, v3.h[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "fmla v24.8h, v6.8h, v4.h[0]\n" + "fmla v28.8h, v6.8h, v5.h[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x9, x9, #0x10\n" + "fmla v9.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[0]\n" + "fmla v17.8h, v7.8h, v2.h[0]\n" + "fmla v21.8h, v7.8h, v3.h[0]\n" + "fmla v25.8h, v7.8h, v4.h[0]\n" + "fmla v29.8h, v7.8h, v5.h[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.8h, v6.8h, v0.h[0]\n" + "fmla v14.8h, v6.8h, v1.h[0]\n" + "fmla v18.8h, v6.8h, v2.h[0]\n" + "fmla v22.8h, v6.8h, v3.h[0]\n" + "fmla v26.8h, v6.8h, v4.h[0]\n" + "fmla v30.8h, v6.8h, v5.h[0]\n" + "fmla v11.8h, v7.8h, v0.h[0]\n" + "fmla v15.8h, v7.8h, v1.h[0]\n" + "fmla v19.8h, v7.8h, v2.h[0]\n" + "fmla v23.8h, v7.8h, v3.h[0]\n" + "fmla v27.8h, v7.8h, v4.h[0]\n" + "fmla v31.8h, v7.8h, v5.h[0]\n" + "cbnz x26, 280b\n" + "281:" // Height 6: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 
274b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "tbz %x[flags], #1, 282f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.8h }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.8h }, [x19]\n" + "fmin v8.8h, v8.8h, v1.8h\n" + "fmin v9.8h, v9.8h, v1.8h\n" + "fmin v10.8h, v10.8h, v1.8h\n" + "fmin v11.8h, v11.8h, v1.8h\n" + "fmin v12.8h, v12.8h, v1.8h\n" + "fmin v13.8h, v13.8h, v1.8h\n" + "fmin v14.8h, v14.8h, v1.8h\n" + "fmin v15.8h, v15.8h, v1.8h\n" + "fmin v16.8h, v16.8h, v1.8h\n" + "fmin v17.8h, v17.8h, v1.8h\n" + "fmin v18.8h, v18.8h, v1.8h\n" + "fmin v19.8h, v19.8h, v1.8h\n" + "fmin v20.8h, v20.8h, v1.8h\n" + "fmin v21.8h, v21.8h, v1.8h\n" + "fmin v22.8h, v22.8h, v1.8h\n" + "fmin v23.8h, v23.8h, v1.8h\n" + "fmin v24.8h, v24.8h, v1.8h\n" + "fmin v25.8h, v25.8h, v1.8h\n" + "fmin v26.8h, v26.8h, v1.8h\n" + "fmin v27.8h, v27.8h, v1.8h\n" + "fmin v28.8h, v28.8h, v1.8h\n" + "fmin v29.8h, v29.8h, v1.8h\n" + "fmin v30.8h, v30.8h, v1.8h\n" + "fmin v31.8h, v31.8h, v1.8h\n" + "fmax v8.8h, v8.8h, v0.8h\n" + "fmax v9.8h, v9.8h, v0.8h\n" + "fmax v10.8h, v10.8h, v0.8h\n" + "fmax v11.8h, v11.8h, v0.8h\n" + "fmax v12.8h, v12.8h, v0.8h\n" + "fmax v13.8h, v13.8h, v0.8h\n" + "fmax v14.8h, v14.8h, v0.8h\n" + "fmax v15.8h, v15.8h, v0.8h\n" + "fmax v16.8h, v16.8h, v0.8h\n" + "fmax v17.8h, v17.8h, v0.8h\n" + "fmax v18.8h, v18.8h, v0.8h\n" + "fmax v19.8h, v19.8h, v0.8h\n" + "fmax v20.8h, v20.8h, v0.8h\n" + "fmax v21.8h, v21.8h, v0.8h\n" + "fmax v22.8h, v22.8h, v0.8h\n" + "fmax v23.8h, v23.8h, v0.8h\n" + "fmax v24.8h, v24.8h, v0.8h\n" + "fmax v25.8h, v25.8h, v0.8h\n" + "fmax v26.8h, v26.8h, v0.8h\n" + "fmax v27.8h, v27.8h, v0.8h\n" + "fmax v28.8h, v28.8h, v0.8h\n" + "fmax v29.8h, v29.8h, v0.8h\n" + "fmax v30.8h, v30.8h, v0.8h\n" + "fmax v31.8h, v31.8h, v0.8h\n" + "282:" // Height 6: No activation + "cmp x13, #0x20\n" + "bge 299f\n" + "tbz x13, #4, 290f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v9.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v13.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v17.8h }, [x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "st1 { v21.8h }, [x22], #0x10\n" + "st1 { v24.8h }, [x21], #0x10\n" + "st1 { v25.8h }, [x21], #0x10\n" + "st1 { v28.8h }, [x20], #0x10\n" + "st1 { v29.8h }, [x20], #0x10\n" + "tbz x13, #3, 286f\n" + "st1 { v10.8h }, [x12], #0x10\n" + "st1 { v14.8h }, [x24], #0x10\n" + "st1 { v18.8h }, [x23], #0x10\n" + "st1 { v22.8h }, [x22], #0x10\n" + "st1 { v26.8h }, [x21], #0x10\n" + "st1 { v30.8h }, [x20], #0x10\n" + "tbz x13, #2, 284f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "str d31, [x20], #0x8\n" + "tbz x13, #1, 283f\n" + "st1 { v11.s }[2], [x12], #0x4\n" + "st1 { v15.s }[2], [x24], #0x4\n" + "st1 { v19.s }[2], [x23], #0x4\n" + "st1 { v23.s }[2], [x22], #0x4\n" + "st1 { v27.s }[2], [x21], #0x4\n" + "st1 { v31.s }[2], [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v11.h }[6], [x12]\n" + "st1 { v15.h }[6], [x24]\n" + "st1 { v19.h }[6], [x23]\n" + "st1 { v23.h }[6], [x22]\n" + "st1 { v27.h }[6], [x21]\n" + "st1 { v31.h }[6], [x20]\n" + "b 298f\n" + "283:" // Height 6: Partial direct writeback: partial_1_28 + "tbz x13, #0, 298f\n" + "st1 { v11.h }[4], [x12]\n" + "st1 { v15.h }[4], [x24]\n" + "st1 { v19.h }[4], [x23]\n" + "st1 { v23.h 
}[4], [x22]\n" + "st1 { v27.h }[4], [x21]\n" + "st1 { v31.h }[4], [x20]\n" + "b 298f\n" + "284:" // Height 6: Partial direct writeback: partial_2_24 + "tbz x13, #1, 285f\n" + "str s11, [x12], #0x4\n" + "str s15, [x24], #0x4\n" + "str s19, [x23], #0x4\n" + "str s23, [x22], #0x4\n" + "str s27, [x21], #0x4\n" + "str s31, [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v11.h }[2], [x12]\n" + "st1 { v15.h }[2], [x24]\n" + "st1 { v19.h }[2], [x23]\n" + "st1 { v23.h }[2], [x22]\n" + "st1 { v27.h }[2], [x21]\n" + "st1 { v31.h }[2], [x20]\n" + "b 298f\n" + "285:" // Height 6: Partial direct writeback: partial_1_24 + "tbz x13, #0, 298f\n" + "str h11, [x12, #0x0]\n" + "str h15, [x24, #0x0]\n" + "str h19, [x23, #0x0]\n" + "str h23, [x22, #0x0]\n" + "str h27, [x21, #0x0]\n" + "str h31, [x20, #0x0]\n" + "b 298f\n" + "286:" // Height 6: Partial direct writeback: partial_4_16 + "tbz x13, #2, 288f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "str d30, [x20], #0x8\n" + "tbz x13, #1, 287f\n" + "st1 { v10.s }[2], [x12], #0x4\n" + "st1 { v14.s }[2], [x24], #0x4\n" + "st1 { v18.s }[2], [x23], #0x4\n" + "st1 { v22.s }[2], [x22], #0x4\n" + "st1 { v26.s }[2], [x21], #0x4\n" + "st1 { v30.s }[2], [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v10.h }[6], [x12]\n" + "st1 { v14.h }[6], [x24]\n" + "st1 { v18.h }[6], [x23]\n" + "st1 { v22.h }[6], [x22]\n" + "st1 { v26.h }[6], [x21]\n" + "st1 { v30.h }[6], [x20]\n" + "b 298f\n" + "287:" // Height 6: Partial direct writeback: partial_1_20 + "tbz x13, #0, 298f\n" + "st1 { v10.h }[4], [x12]\n" + "st1 { v14.h }[4], [x24]\n" + "st1 { v18.h }[4], [x23]\n" + "st1 { v22.h }[4], [x22]\n" + "st1 { v26.h }[4], [x21]\n" + "st1 { v30.h }[4], [x20]\n" + "b 298f\n" + "288:" // Height 6: Partial direct writeback: partial_2_16 + "tbz x13, #1, 289f\n" + "str s10, [x12], #0x4\n" + "str s14, [x24], #0x4\n" + "str s18, [x23], #0x4\n" + "str s22, [x22], #0x4\n" + "str s26, [x21], #0x4\n" + "str s30, [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v10.h }[2], [x12]\n" + "st1 { v14.h }[2], [x24]\n" + "st1 { v18.h }[2], [x23]\n" + "st1 { v22.h }[2], [x22]\n" + "st1 { v26.h }[2], [x21]\n" + "st1 { v30.h }[2], [x20]\n" + "b 298f\n" + "289:" // Height 6: Partial direct writeback: partial_1_16 + "tbz x13, #0, 298f\n" + "str h10, [x12, #0x0]\n" + "str h14, [x24, #0x0]\n" + "str h18, [x23, #0x0]\n" + "str h22, [x22, #0x0]\n" + "str h26, [x21, #0x0]\n" + "str h30, [x20, #0x0]\n" + "b 298f\n" + "290:" // Height 6: Partial direct writeback: partial_8_0 + "tbz x13, #3, 294f\n" + "st1 { v8.8h }, [x12], #0x10\n" + "st1 { v12.8h }, [x24], #0x10\n" + "st1 { v16.8h }, [x23], #0x10\n" + "st1 { v20.8h }, [x22], #0x10\n" + "st1 { v24.8h }, [x21], #0x10\n" + "st1 { v28.8h }, [x20], #0x10\n" + "tbz x13, #2, 292f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "str d29, [x20], #0x8\n" + "tbz x13, #1, 291f\n" + "st1 { v9.s }[2], [x12], #0x4\n" + "st1 { v13.s }[2], [x24], #0x4\n" + "st1 { v17.s }[2], [x23], #0x4\n" + "st1 { v21.s }[2], [x22], #0x4\n" + "st1 { v25.s }[2], [x21], #0x4\n" + "st1 { v29.s }[2], [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v9.h }[6], [x12]\n" + "st1 { v13.h }[6], [x24]\n" + "st1 { v17.h }[6], [x23]\n" + "st1 { v21.h }[6], [x22]\n" + "st1 { v25.h }[6], [x21]\n" + "st1 { v29.h }[6], [x20]\n" + "b 298f\n" + "291:" // Height 6: Partial direct writeback: partial_1_12 + "tbz x13, #0, 298f\n" + "st1 { v9.h 
}[4], [x12]\n" + "st1 { v13.h }[4], [x24]\n" + "st1 { v17.h }[4], [x23]\n" + "st1 { v21.h }[4], [x22]\n" + "st1 { v25.h }[4], [x21]\n" + "st1 { v29.h }[4], [x20]\n" + "b 298f\n" + "292:" // Height 6: Partial direct writeback: partial_2_8 + "tbz x13, #1, 293f\n" + "str s9, [x12], #0x4\n" + "str s13, [x24], #0x4\n" + "str s17, [x23], #0x4\n" + "str s21, [x22], #0x4\n" + "str s25, [x21], #0x4\n" + "str s29, [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v9.h }[2], [x12]\n" + "st1 { v13.h }[2], [x24]\n" + "st1 { v17.h }[2], [x23]\n" + "st1 { v21.h }[2], [x22]\n" + "st1 { v25.h }[2], [x21]\n" + "st1 { v29.h }[2], [x20]\n" + "b 298f\n" + "293:" // Height 6: Partial direct writeback: partial_1_8 + "tbz x13, #0, 298f\n" + "str h9, [x12, #0x0]\n" + "str h13, [x24, #0x0]\n" + "str h17, [x23, #0x0]\n" + "str h21, [x22, #0x0]\n" + "str h25, [x21, #0x0]\n" + "str h29, [x20, #0x0]\n" + "b 298f\n" + "294:" // Height 6: Partial direct writeback: partial_4_0 + "tbz x13, #2, 296f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "str d28, [x20], #0x8\n" + "tbz x13, #1, 295f\n" + "st1 { v8.s }[2], [x12], #0x4\n" + "st1 { v12.s }[2], [x24], #0x4\n" + "st1 { v16.s }[2], [x23], #0x4\n" + "st1 { v20.s }[2], [x22], #0x4\n" + "st1 { v24.s }[2], [x21], #0x4\n" + "st1 { v28.s }[2], [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v8.h }[6], [x12]\n" + "st1 { v12.h }[6], [x24]\n" + "st1 { v16.h }[6], [x23]\n" + "st1 { v20.h }[6], [x22]\n" + "st1 { v24.h }[6], [x21]\n" + "st1 { v28.h }[6], [x20]\n" + "b 298f\n" + "295:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x13, #0, 298f\n" + "st1 { v8.h }[4], [x12]\n" + "st1 { v12.h }[4], [x24]\n" + "st1 { v16.h }[4], [x23]\n" + "st1 { v20.h }[4], [x22]\n" + "st1 { v24.h }[4], [x21]\n" + "st1 { v28.h }[4], [x20]\n" + "b 298f\n" + "296:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x13, #1, 297f\n" + "str s8, [x12], #0x4\n" + "str s12, [x24], #0x4\n" + "str s16, [x23], #0x4\n" + "str s20, [x22], #0x4\n" + "str s24, [x21], #0x4\n" + "str s28, [x20], #0x4\n" + "tbz x13, #0, 298f\n" + "st1 { v8.h }[2], [x12]\n" + "st1 { v12.h }[2], [x24]\n" + "st1 { v16.h }[2], [x23]\n" + "st1 { v20.h }[2], [x22]\n" + "st1 { v24.h }[2], [x21]\n" + "st1 { v28.h }[2], [x20]\n" + "b 298f\n" + "297:" // Height 6: Partial direct writeback: partial_1_0 + "str h8, [x12, #0x0]\n" + "str h12, [x24, #0x0]\n" + "str h16, [x23, #0x0]\n" + "str h20, [x22, #0x0]\n" + "str h24, [x21, #0x0]\n" + "str h28, [x20, #0x0]\n" + "298:" // Height 6: Partial direct writeback: Done + "b 300f\n" + "299:" // Height 6: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "str q24, [x21, #0x0]\n" + "str q25, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "str q28, [x20, #0x0]\n" + "str q29, [x20, #0x10]\n" + "str q30, [x20, #0x20]\n" + "str q31, [x20, #0x30]\n" + "300:" // Height 6: Writeback done + "subs x13, x13, #0x20\n" + "bgt 252b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 302f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 
301f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "301:" // Update direct input + "mov x19, #0xc\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "302:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16.hpp new file mode 100644 index 0000000000..08f5aeb2d8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<float>, \ + size_t, size_t, \ + const float *, \ + size_t, \ + IndirectOutputArg<float>, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffhybrid_fp32_mla_6x16( ARGLIST ); + +class cls_a64_ffhybrid_fp32_mla_6x16 +{ +public: + typedef float lhs_operand_type; + typedef float rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL128_BL32; + } + + static unsigned int out_width() + { + return 16; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsFixed<rhs_operand_type, result_type, 6, 16, 1> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + default: + return { 13.16 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffhybrid_fp32_mla_6x16; + cls_a64_ffhybrid_fp32_mla_6x16(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp new file mode 100644 index 0000000000..f811116a06 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp @@ -0,0 +1,3461 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp
new file mode 100644
index 0000000000..f811116a06
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32_mla_6x16/generic.cpp
@@ -0,0 +1,3461 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+#include "../../utils.hpp"
+
+#include <cassert>
+#include <limits>
+
+namespace arm_gemm {
+
+void a64_ffhybrid_fp32_mla_6x16 (
+    unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg,
+    size_t M, size_t N, const float *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg,
+    const float *bias, Activation act, bool accumulate
+)
+{
+    struct KernelArgs {
+        float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+        float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+        unsigned int num_strings = {};
+        const unsigned int *string_lengths = {};
+        size_t N = {};
+        const float *B_ptr = {};
+        const float *cur_B_ptr = {};
+        size_t B_stride = {};
+        size_t output_offset = {};
+        size_t input_initial_col = {};
+        size_t input_offset = {};
+    } ka;
+
+    unsigned long flags=0;
+    void *output_ptr;
+    void *input_ptr;
+
+    if (output_arg.is_indirect) {
+        output_ptr=(void *)(output_arg.indirect.ptr);
+        ka.output_offset=output_arg.indirect.offset;
+        flags |= 0x4; // flags bit 2: output is indirect
+    } else {
+        output_ptr=(void *)(output_arg.direct.base);
+        ka.output_offset=output_arg.direct.stride;
+    }
+
+    if (A_arg.is_indirect) {
+        input_ptr=(void *)(A_arg.indirect.ptr);
+        ka.input_offset=A_arg.indirect.start_row;
+        ka.input_initial_col=A_arg.indirect.start_col;
+        flags |= 0x8; // flags bit 3: input is indirect
+    } else {
+        assert(num_strings==1);
+        input_ptr=(void *)(A_arg.direct.base);
+        ka.input_offset=A_arg.direct.stride;
+    }
+    if (accumulate) {
+        flags |= 0x1; // flags bit 0: accumulate into the existing output
+    }
+    ka.num_strings = num_strings;
+    ka.string_lengths = string_lengths;
+    ka.N = N;
+    ka.B_ptr = B_ptr;
+    ka.B_stride = B_stride;
+    switch(act.type) {
+        default:
+        case Activation::Type::None:
+            break;
+        case Activation::Type::BoundedReLU:
+            ka.maxval = static_cast<float>(act.param1);
+            /* fall through */
+        case Activation::Type::ReLU:
+            ka.minval = 0;
+            flags |= 0x2; // flags bit 1: min/max clamping needed
+            break;
+    }
+    __asm__ __volatile__(
+      "1:" // Row loop
+      "cmp %x[M], #0x6\n"
+      "bge 171f\n"
+      "cmp %x[M], #0x4\n"
+      "bgt 137f\n"
+      "beq 103f\n"
+      "cmp %x[M], #0x2\n"
+      "bgt 69f\n"
+      "beq 35f\n"
+      "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n"
+      "mov x14, %x[bias]\n"
+      "ldr x13, [%x[args_ptr], %[offsetof_N]]\n"
+      "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "mov x12, %x[output_ptr]\n"
+      "2:" // Height 1: Column loop
+      "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n"
+      "add x10, x11, x19, LSL #2\n"
+      "add x9, x10, x19, LSL #2\n"
+      "add x28, x9, x19, LSL #2\n"
+      "add x19, x28, x19, LSL #2\n"
+      "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n"
+      "cmp x13, #0xc\n"
+      "bgt 3f\n"
+      "cmp x13, #0x8\n"
+      "mov x28, x11\n"
+      "bgt 3f\n"
+      "cmp x13, #0x4\n"
+      "mov x9, x11\n"
+      "bgt 3f\n"
+      "mov x10, x11\n"
+      "3:" // Height 1: B setup done
+      "cbz x14, 4f\n"
+      "ldr q8, [x14, #0x0]\n"
+      "ldr q9, [x14, #0x10]\n"
+      "ldr q10, [x14, #0x20]\n"
+      "ldr q11, [x14, #0x30]\n"
+      "add x14, x14, #0x40\n"
+      "b 15f\n"
+      "4:" // Height 1: no bias
+      "tbz %x[flags], #0, 14f\n"
+      "cmp x13, #0x10\n"
+      "bge 13f\n"
+      "tbz x13, #3, 8f\n"
+      "ld1 { v8.4s }, [x12], #0x10\n"
+      "ld1 { v9.4s }, [x12], #0x10\n"
+      "tbz x13, #2, 6f\n"
+      "ld1 { v10.4s }, [x12], #0x10\n"
+      "tbz x13, #1, 5f\n"
+      "ldr d11, [x12], #0x8\n"
+      "mov x19, #0x38\n"
+      "tbz x13, #0, 12f\n"
+      "ld1 { v11.s }[2], [x12]\n"
+      "b 12f\n"
+      "5:" // Height 1: Partial accumulate: partial_1_12
+      "mov x19, #0x30\n"
+      "tbz x13, #0, 12f\n"
+      "ldr s11, [x12, #0x0]\n"
+      "b 12f\n"
+      "6:" // Height 1: Partial accumulate: partial_2_8
+      "tbz x13, #1, 7f\n"
+      "ldr d10, [x12], #0x8\n"
+      "mov x19, #0x28\n"
+      "tbz x13, 
#0, 12f\n" + "ld1 { v10.s }[2], [x12]\n" + "b 12f\n" + "7:" // Height 1: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 12f\n" + "ldr s10, [x12, #0x0]\n" + "b 12f\n" + "8:" // Height 1: Partial accumulate: partial_4_0 + "tbz x13, #2, 10f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "tbz x13, #1, 9f\n" + "ldr d9, [x12], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 12f\n" + "ld1 { v9.s }[2], [x12]\n" + "b 12f\n" + "9:" // Height 1: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 12f\n" + "ldr s9, [x12, #0x0]\n" + "b 12f\n" + "10:" // Height 1: Partial accumulate: partial_2_0 + "tbz x13, #1, 11f\n" + "ldr d8, [x12], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 12f\n" + "ld1 { v8.s }[2], [x12]\n" + "b 12f\n" + "11:" // Height 1: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "mov x19, #0x0\n" + "12:" // Height 1: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 15f\n" + "13:" // Height 1: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "b 15f\n" + "14:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "15:" // Height 1: setup done + "mov x27, #0x0\n" + "16:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 17f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 18f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "b 18f\n" + "17:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "18:" // Height 1: input setup done + "cmp x26, #0x4\n" + "blt 21f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q7, [x10, #0x0]\n" + "blt 20f\n" + "19:" // Height 1: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x8\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "add x25, x25, #0x10\n" + "ldr q0, [x25, #0x0]\n" + "add x11, x11, #0x40\n" + "ldr q6, [x11, #0x0]\n" + "add x10, x10, #0x40\n" + "ldr q7, [x10, #0x0]\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "bge 19b\n" + "20:" // Height 1: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "ldr q6, [x9, #0x10]\n" + 
"fmla v9.4s, v7.4s, v0.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "sub x26, x26, #0x4\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "add x25, x25, #0x10\n" + "add x11, x11, #0x40\n" + "add x10, x10, #0x40\n" + "add x9, x9, #0x40\n" + "add x28, x28, #0x40\n" + "21:" // Height 1: Multiply loop: Main loop skip + "cbz x26, 23f\n" + "22:" // Height 1: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "sub x26, x26, #0x1\n" + "ldr q7, [x10, #0x0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "add x28, x28, #0x10\n" + "cbnz x26, 22b\n" + "23:" // Height 1: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 16b\n" + "tbz %x[flags], #1, 24f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "24:" // Height 1: No activation + "cmp x13, #0x10\n" + "bge 33f\n" + "tbz x13, #3, 28f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "tbz x13, #2, 26f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "tbz x13, #1, 25f\n" + "str d11, [x12], #0x8\n" + "tbz x13, #0, 32f\n" + "st1 { v11.s }[2], [x12]\n" + "b 32f\n" + "25:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x13, #0, 32f\n" + "str s11, [x12, #0x0]\n" + "b 32f\n" + "26:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x13, #1, 27f\n" + "str d10, [x12], #0x8\n" + "tbz x13, #0, 32f\n" + "st1 { v10.s }[2], [x12]\n" + "b 32f\n" + "27:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x13, #0, 32f\n" + "str s10, [x12, #0x0]\n" + "b 32f\n" + "28:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x13, #2, 30f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "tbz x13, #1, 29f\n" + "str d9, [x12], #0x8\n" + "tbz x13, #0, 32f\n" + "st1 { v9.s }[2], [x12]\n" + "b 32f\n" + "29:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x13, #0, 32f\n" + "str s9, [x12, #0x0]\n" + "b 32f\n" + "30:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x13, #1, 31f\n" + "str d8, [x12], #0x8\n" + "tbz x13, #0, 32f\n" + "st1 { v8.s }[2], [x12]\n" + "b 32f\n" + "31:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "32:" // Height 1: Partial direct writeback: Done + "b 34f\n" + "33:" // Height 1: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "34:" // Height 1: Writeback done + "subs x13, x13, #0x10\n" + "bgt 2b\n" 
+ "b 206f\n" + "35:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "36:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 37f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 37f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 37f\n" + "mov x10, x11\n" + "37:" // Height 2: B setup done + "cbz x14, 38f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "add x14, x14, #0x40\n" + "b 49f\n" + "38:" // Height 2: no bias + "tbz %x[flags], #0, 48f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x13, #0x10\n" + "add x24, x12, x19, LSL #2\n" + "bge 47f\n" + "tbz x13, #3, 42f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "tbz x13, #2, 40f\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "tbz x13, #1, 39f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "tbz x13, #0, 46f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "b 46f\n" + "39:" // Height 2: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 46f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "b 46f\n" + "40:" // Height 2: Partial accumulate: partial_2_8 + "tbz x13, #1, 41f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "tbz x13, #0, 46f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "b 46f\n" + "41:" // Height 2: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 46f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "b 46f\n" + "42:" // Height 2: Partial accumulate: partial_4_0 + "tbz x13, #2, 44f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "tbz x13, #1, 43f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 46f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "b 46f\n" + "43:" // Height 2: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 46f\n" + "ldr s9, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "b 46f\n" + "44:" // Height 2: Partial accumulate: partial_2_0 + "tbz x13, #1, 45f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 46f\n" + "ld1 { v8.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "b 46f\n" + "45:" // Height 2: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "46:" // Height 2: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 49f\n" + "47:" // Height 2: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "b 49f\n" + "48:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi 
v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "49:" // Height 2: setup done + "mov x27, #0x0\n" + "50:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 51f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 52f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "b 52f\n" + "51:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "52:" // Height 2: input setup done + "cmp x26, #0x4\n" + "blt 55f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 54f\n" + "53:" // Height 2: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "sub x26, x26, #0x4\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "cmp x26, #0x8\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "add x25, x25, #0x10\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "add x24, x24, #0x10\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 53b\n" + "54:" // Height 2: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "sub x26, x26, #0x4\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x25, x25, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "add x24, x24, #0x10\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, 
v1.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "55:" // Height 2: Multiply loop: Main loop skip + "cbz x26, 57f\n" + "56:" // Height 2: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr s1, [x24], #0x4\n" + "sub x26, x26, #0x1\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "add x9, x9, #0x10\n" + "add x28, x28, #0x10\n" + "cbnz x26, 56b\n" + "57:" // Height 2: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 50b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "tbz %x[flags], #1, 58f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "58:" // Height 2: No activation + "cmp x13, #0x10\n" + "bge 67f\n" + "tbz x13, #3, 62f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v13.4s }, [x24], #0x10\n" + "tbz x13, #2, 60f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x24], #0x10\n" + "tbz x13, #1, 59f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "tbz x13, #0, 66f\n" + "st1 { v11.s }[2], [x12]\n" + "st1 { v15.s }[2], [x24]\n" + "b 66f\n" + "59:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x13, #0, 66f\n" + "str s11, [x12, #0x0]\n" + "str s15, [x24, #0x0]\n" + "b 66f\n" + "60:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x13, #1, 61f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "tbz x13, #0, 66f\n" + "st1 { v10.s }[2], [x12]\n" + "st1 { v14.s }[2], [x24]\n" + "b 66f\n" + "61:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x13, #0, 66f\n" + "str s10, [x12, 
#0x0]\n" + "str s14, [x24, #0x0]\n" + "b 66f\n" + "62:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x13, #2, 64f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "tbz x13, #1, 63f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "tbz x13, #0, 66f\n" + "st1 { v9.s }[2], [x12]\n" + "st1 { v13.s }[2], [x24]\n" + "b 66f\n" + "63:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x13, #0, 66f\n" + "str s9, [x12, #0x0]\n" + "str s13, [x24, #0x0]\n" + "b 66f\n" + "64:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x13, #1, 65f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "tbz x13, #0, 66f\n" + "st1 { v8.s }[2], [x12]\n" + "st1 { v12.s }[2], [x24]\n" + "b 66f\n" + "65:" // Height 2: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "str s12, [x24, #0x0]\n" + "66:" // Height 2: Partial direct writeback: Done + "b 68f\n" + "67:" // Height 2: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "68:" // Height 2: Writeback done + "subs x13, x13, #0x10\n" + "bgt 36b\n" + "b 206f\n" + "69:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "70:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 71f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 71f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 71f\n" + "mov x10, x11\n" + "71:" // Height 3: B setup done + "cbz x14, 72f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "b 83f\n" + "72:" // Height 3: no bias + "tbz %x[flags], #0, 82f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x23, x24, x19, LSL #2\n" + "bge 81f\n" + "tbz x13, #3, 76f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "tbz x13, #2, 74f\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "tbz x13, #1, 73f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d19, [x23], #0x8\n" + "tbz x13, #0, 80f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "b 80f\n" + "73:" // Height 3: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 80f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "b 80f\n" + "74:" // Height 3: Partial accumulate: partial_2_8 + "tbz x13, #1, 75f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov 
x19, #0x28\n" + "ldr d18, [x23], #0x8\n" + "tbz x13, #0, 80f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "b 80f\n" + "75:" // Height 3: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 80f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "b 80f\n" + "76:" // Height 3: Partial accumulate: partial_4_0 + "tbz x13, #2, 78f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "tbz x13, #1, 77f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d17, [x23], #0x8\n" + "tbz x13, #0, 80f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "b 80f\n" + "77:" // Height 3: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 80f\n" + "ldr s9, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s17, [x23, #0x0]\n" + "b 80f\n" + "78:" // Height 3: Partial accumulate: partial_2_0 + "tbz x13, #1, 79f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d16, [x23], #0x8\n" + "tbz x13, #0, 80f\n" + "ld1 { v8.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v16.s }[2], [x23]\n" + "b 80f\n" + "79:" // Height 3: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s16, [x23, #0x0]\n" + "80:" // Height 3: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 83f\n" + "81:" // Height 3: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "b 83f\n" + "82:" // Height 3: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "83:" // Height 3: setup done + "mov x27, #0x0\n" + "84:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 85f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 86f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "b 86f\n" + "85:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "86:" // Height 3: input setup done + "cmp x26, #0x4\n" + "blt 89f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q2, [x23, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 88f\n" + "87:" // Height 3: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x8\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "add x25, x25, #0x10\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + 
"ldr q7, [x28, #0x0]\n" + "add x24, x24, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "add x23, x23, #0x10\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "add x11, x11, #0x40\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "add x9, x9, #0x40\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "ldr q2, [x23, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 87b\n" + "88:" // Height 3: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "add x25, x25, #0x10\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "add x24, x24, #0x10\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x23, x23, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, 
v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "add x11, x11, #0x40\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "add x9, x9, #0x40\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "89:" // Height 3: Multiply loop: Main loop skip + "cbz x26, 91f\n" + "90:" // Height 3: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr s1, [x24], #0x4\n" + "sub x26, x26, #0x1\n" + "ldr s2, [x23], #0x4\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x11, x11, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "add x28, x28, #0x10\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "cbnz x26, 90b\n" + "91:" // Height 3: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 84b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "tbz %x[flags], #1, 92f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "92:" // Height 3: No activation + "cmp x13, #0x10\n" + "bge 101f\n" + "tbz x13, #3, 96f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v13.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v17.4s }, [x23], #0x10\n" + "tbz x13, #2, 94f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x24], #0x10\n" + "st1 { v18.4s }, [x23], #0x10\n" + "tbz x13, #1, 93f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "tbz x13, #0, 100f\n" + "st1 { v11.s }[2], [x12]\n" + "st1 { v15.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "b 100f\n" + "93:" // Height 3: 
Partial direct writeback: partial_1_12 + "tbz x13, #0, 100f\n" + "str s11, [x12, #0x0]\n" + "str s15, [x24, #0x0]\n" + "str s19, [x23, #0x0]\n" + "b 100f\n" + "94:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x13, #1, 95f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "tbz x13, #0, 100f\n" + "st1 { v10.s }[2], [x12]\n" + "st1 { v14.s }[2], [x24]\n" + "st1 { v18.s }[2], [x23]\n" + "b 100f\n" + "95:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x13, #0, 100f\n" + "str s10, [x12, #0x0]\n" + "str s14, [x24, #0x0]\n" + "str s18, [x23, #0x0]\n" + "b 100f\n" + "96:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x13, #2, 98f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "tbz x13, #1, 97f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "tbz x13, #0, 100f\n" + "st1 { v9.s }[2], [x12]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v17.s }[2], [x23]\n" + "b 100f\n" + "97:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x13, #0, 100f\n" + "str s9, [x12, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s17, [x23, #0x0]\n" + "b 100f\n" + "98:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x13, #1, 99f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "tbz x13, #0, 100f\n" + "st1 { v8.s }[2], [x12]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v16.s }[2], [x23]\n" + "b 100f\n" + "99:" // Height 3: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "str s12, [x24, #0x0]\n" + "str s16, [x23, #0x0]\n" + "100:" // Height 3: Partial direct writeback: Done + "b 102f\n" + "101:" // Height 3: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "102:" // Height 3: Writeback done + "subs x13, x13, #0x10\n" + "bgt 70b\n" + "b 206f\n" + "103:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "104:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 105f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 105f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 105f\n" + "mov x10, x11\n" + "105:" // Height 4: B setup done + "cbz x14, 106f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "b 117f\n" + "106:" // Height 4: no bias + "tbz %x[flags], #0, 116f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, 
LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x22, x23, x19, LSL #2\n" + "bge 115f\n" + "tbz x13, #3, 110f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "tbz x13, #2, 108f\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "tbz x13, #1, 107f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "tbz x13, #0, 114f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "b 114f\n" + "107:" // Height 4: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 114f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "b 114f\n" + "108:" // Height 4: Partial accumulate: partial_2_8 + "tbz x13, #1, 109f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "tbz x13, #0, 114f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "b 114f\n" + "109:" // Height 4: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 114f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "b 114f\n" + "110:" // Height 4: Partial accumulate: partial_4_0 + "tbz x13, #2, 112f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "tbz x13, #1, 111f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "tbz x13, #0, 114f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "b 114f\n" + "111:" // Height 4: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 114f\n" + "ldr s9, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s17, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "b 114f\n" + "112:" // Height 4: Partial accumulate: partial_2_0 + "tbz x13, #1, 113f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "tbz x13, #0, 114f\n" + "ld1 { v8.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v16.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "b 114f\n" + "113:" // Height 4: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s16, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "114:" // Height 4: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 117f\n" + "115:" // Height 4: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "b 117f\n" 
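+      // Illustrative comment: label 115 above reloads the existing C tile
+      // because flags bit 0 (accumulate) was set by the wrapper; label 116
+      // below is the non-accumulating path, which zeroes the accumulators
+      // v8-v23 with movi instead.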
+ "116:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "117:" // Height 4: setup done + "mov x27, #0x0\n" + "118:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 119f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 120f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "b 120f\n" + "119:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "120:" // Height 4: input setup done + "cmp x26, #0x4\n" + "blt 123f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 122f\n" + "121:" // Height 4: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x8\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x25, x25, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x22, x22, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + 
"fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 121b\n" + "122:" // Height 4: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "add x25, x25, #0x10\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x24, x24, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, 
#0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "123:" // Height 4: Multiply loop: Main loop skip + "cbz x26, 125f\n" + "124:" // Height 4: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr s1, [x24], #0x4\n" + "sub x26, x26, #0x1\n" + "ldr s2, [x23], #0x4\n" + "ldr s3, [x22], #0x4\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x11, x11, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "add x10, x10, #0x10\n" + "add x9, x9, #0x10\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "cbnz x26, 124b\n" + "125:" // Height 4: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 118b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "tbz %x[flags], #1, 126f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "126:" // Height 4: No activation + "cmp x13, #0x10\n" + "bge 135f\n" + "tbz x13, #3, 130f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v13.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v17.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v21.4s }, [x22], #0x10\n" + "tbz x13, #2, 128f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x24], #0x10\n" + "st1 { v18.4s }, [x23], #0x10\n" + "st1 { v22.4s }, [x22], #0x10\n" + "tbz x13, #1, 127f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "tbz x13, #0, 134f\n" + "st1 { v11.s 
}[2], [x12]\n" + "st1 { v15.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "st1 { v23.s }[2], [x22]\n" + "b 134f\n" + "127:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x13, #0, 134f\n" + "str s11, [x12, #0x0]\n" + "str s15, [x24, #0x0]\n" + "str s19, [x23, #0x0]\n" + "str s23, [x22, #0x0]\n" + "b 134f\n" + "128:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x13, #1, 129f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "tbz x13, #0, 134f\n" + "st1 { v10.s }[2], [x12]\n" + "st1 { v14.s }[2], [x24]\n" + "st1 { v18.s }[2], [x23]\n" + "st1 { v22.s }[2], [x22]\n" + "b 134f\n" + "129:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x13, #0, 134f\n" + "str s10, [x12, #0x0]\n" + "str s14, [x24, #0x0]\n" + "str s18, [x23, #0x0]\n" + "str s22, [x22, #0x0]\n" + "b 134f\n" + "130:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x13, #2, 132f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "tbz x13, #1, 131f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "tbz x13, #0, 134f\n" + "st1 { v9.s }[2], [x12]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v17.s }[2], [x23]\n" + "st1 { v21.s }[2], [x22]\n" + "b 134f\n" + "131:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x13, #0, 134f\n" + "str s9, [x12, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s17, [x23, #0x0]\n" + "str s21, [x22, #0x0]\n" + "b 134f\n" + "132:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x13, #1, 133f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "tbz x13, #0, 134f\n" + "st1 { v8.s }[2], [x12]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v16.s }[2], [x23]\n" + "st1 { v20.s }[2], [x22]\n" + "b 134f\n" + "133:" // Height 4: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "str s12, [x24, #0x0]\n" + "str s16, [x23, #0x0]\n" + "str s20, [x22, #0x0]\n" + "134:" // Height 4: Partial direct writeback: Done + "b 136f\n" + "135:" // Height 4: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "136:" // Height 4: Writeback done + "subs x13, x13, #0x10\n" + "bgt 104b\n" + "b 206f\n" + "137:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "138:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 139f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 139f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 139f\n" + "mov x10, x11\n" + "139:" // Height 5: B setup done + "cbz x14, 140f\n" + "ldr q8, [x14, #0x0]\n" + 
"ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "mov v24.16b, v8.16b\n" + "mov v25.16b, v9.16b\n" + "mov v26.16b, v10.16b\n" + "mov v27.16b, v11.16b\n" + "b 151f\n" + "140:" // Height 5: no bias + "tbz %x[flags], #0, 150f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x21, x22, x19, LSL #2\n" + "bge 149f\n" + "tbz x13, #3, 144f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "tbz x13, #2, 142f\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "ld1 { v26.4s }, [x21], #0x10\n" + "tbz x13, #1, 141f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "ld1 { v27.s }[2], [x21]\n" + "b 148f\n" + "141:" // Height 5: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 148f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "ldr s27, [x21, #0x0]\n" + "b 148f\n" + "142:" // Height 5: Partial accumulate: partial_2_8 + "tbz x13, #1, 143f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "ld1 { v26.s }[2], [x21]\n" + "b 148f\n" + "143:" // Height 5: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 148f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "ldr s26, [x21, #0x0]\n" + "b 148f\n" + "144:" // Height 5: Partial accumulate: partial_4_0 + "tbz x13, #2, 146f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "tbz x13, #1, 145f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "b 148f\n" + "145:" // Height 5: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 148f\n" + "ldr s9, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s17, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "b 148f\n" + "146:" // 
Height 5: Partial accumulate: partial_2_0 + "tbz x13, #1, 147f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d24, [x21], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v8.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v16.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "ld1 { v24.s }[2], [x21]\n" + "b 148f\n" + "147:" // Height 5: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s16, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "ldr s24, [x21, #0x0]\n" + "148:" // Height 5: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 151f\n" + "149:" // Height 5: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q24, [x21, #0x0]\n" + "ldr q25, [x21, #0x10]\n" + "ldr q26, [x21, #0x20]\n" + "ldr q27, [x21, #0x30]\n" + "b 151f\n" + "150:" // Height 5: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "151:" // Height 5: setup done + "mov x27, #0x0\n" + "152:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 153f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 154f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "b 154f\n" + "153:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "154:" // Height 5: input setup done + "cmp x26, #0x4\n" + "blt 157f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 156f\n" + "155:" // Height 5: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x8\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "add x23, x23, #0x10\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "add x22, x22, 
#0x10\n" + "add x21, x21, #0x10\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "fmla v24.4s, v6.4s, v4.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "fmla v27.4s, v7.4s, v4.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "fmla v24.4s, v6.4s, v4.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "fmla v25.4s, v7.4s, v4.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "add x11, x11, #0x40\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "fmla v27.4s, v7.4s, v4.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "add x9, x9, #0x40\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "fmla v25.4s, v7.4s, v4.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "fmla v26.4s, v6.4s, v4.s[3]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "ldr q3, [x22, #0x0]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 155b\n" + "156:" // Height 5: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "add x25, x25, #0x10\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "add x22, x22, #0x10\n" + "fmla v13.4s, 
v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "add x21, x21, #0x10\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "fmla v24.4s, v6.4s, v4.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "fmla v27.4s, v7.4s, v4.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "fmla v24.4s, v6.4s, v4.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "fmla v25.4s, v7.4s, v4.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "add x11, x11, #0x40\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "fmla v27.4s, v7.4s, v4.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "add x9, x9, #0x40\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "fmla v25.4s, v7.4s, v4.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "fmla v26.4s, v6.4s, v4.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "157:" // Height 5: Multiply loop: Main loop skip + "cbz x26, 159f\n" + "158:" // Height 5: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr s1, [x24], #0x4\n" + "sub x26, x26, #0x1\n" + "ldr s2, [x23], #0x4\n" + "ldr s3, [x22], #0x4\n" + "ldr s4, [x21], #0x4\n" + "ldr q6, [x11, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x11, x11, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + 
"add x10, x10, #0x10\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "add x9, x9, #0x10\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "cbnz x26, 158b\n" + "159:" // Height 5: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 152b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "tbz %x[flags], #1, 160f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "160:" // Height 5: No activation + "cmp x13, #0x10\n" + "bge 169f\n" + "tbz x13, #3, 164f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v13.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v17.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v21.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "st1 { v25.4s }, [x21], #0x10\n" + "tbz x13, #2, 162f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x24], #0x10\n" + "st1 { v18.4s }, [x23], #0x10\n" + "st1 { v22.4s }, [x22], #0x10\n" + "st1 { v26.4s }, [x21], #0x10\n" + "tbz x13, #1, 161f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "tbz x13, #0, 168f\n" + "st1 { v11.s }[2], [x12]\n" + "st1 { v15.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "st1 { v23.s }[2], [x22]\n" + "st1 { v27.s }[2], [x21]\n" + "b 168f\n" + "161:" // Height 5: Partial direct writeback: partial_1_12 + "tbz x13, #0, 168f\n" + "str s11, [x12, #0x0]\n" + "str s15, [x24, 
#0x0]\n" + "str s19, [x23, #0x0]\n" + "str s23, [x22, #0x0]\n" + "str s27, [x21, #0x0]\n" + "b 168f\n" + "162:" // Height 5: Partial direct writeback: partial_2_8 + "tbz x13, #1, 163f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "tbz x13, #0, 168f\n" + "st1 { v10.s }[2], [x12]\n" + "st1 { v14.s }[2], [x24]\n" + "st1 { v18.s }[2], [x23]\n" + "st1 { v22.s }[2], [x22]\n" + "st1 { v26.s }[2], [x21]\n" + "b 168f\n" + "163:" // Height 5: Partial direct writeback: partial_1_8 + "tbz x13, #0, 168f\n" + "str s10, [x12, #0x0]\n" + "str s14, [x24, #0x0]\n" + "str s18, [x23, #0x0]\n" + "str s22, [x22, #0x0]\n" + "str s26, [x21, #0x0]\n" + "b 168f\n" + "164:" // Height 5: Partial direct writeback: partial_4_0 + "tbz x13, #2, 166f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "tbz x13, #1, 165f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "tbz x13, #0, 168f\n" + "st1 { v9.s }[2], [x12]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v17.s }[2], [x23]\n" + "st1 { v21.s }[2], [x22]\n" + "st1 { v25.s }[2], [x21]\n" + "b 168f\n" + "165:" // Height 5: Partial direct writeback: partial_1_4 + "tbz x13, #0, 168f\n" + "str s9, [x12, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s17, [x23, #0x0]\n" + "str s21, [x22, #0x0]\n" + "str s25, [x21, #0x0]\n" + "b 168f\n" + "166:" // Height 5: Partial direct writeback: partial_2_0 + "tbz x13, #1, 167f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "tbz x13, #0, 168f\n" + "st1 { v8.s }[2], [x12]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v16.s }[2], [x23]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v24.s }[2], [x21]\n" + "b 168f\n" + "167:" // Height 5: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "str s12, [x24, #0x0]\n" + "str s16, [x23, #0x0]\n" + "str s20, [x22, #0x0]\n" + "str s24, [x21, #0x0]\n" + "168:" // Height 5: Partial direct writeback: Done + "b 170f\n" + "169:" // Height 5: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "str q24, [x21, #0x0]\n" + "str q25, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "170:" // Height 5: Writeback done + "subs x13, x13, #0x10\n" + "bgt 138b\n" + "b 206f\n" + "171:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x18\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "172:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, 
LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0xc\n" + "bgt 173f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 173f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 173f\n" + "mov x10, x11\n" + "173:" // Height 6: B setup done + "cbz x14, 174f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "mov v12.16b, v8.16b\n" + "mov v13.16b, v9.16b\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "mov v14.16b, v10.16b\n" + "mov v15.16b, v11.16b\n" + "mov v16.16b, v8.16b\n" + "mov v17.16b, v9.16b\n" + "add x14, x14, #0x40\n" + "mov v18.16b, v10.16b\n" + "mov v19.16b, v11.16b\n" + "mov v20.16b, v8.16b\n" + "mov v21.16b, v9.16b\n" + "mov v22.16b, v10.16b\n" + "mov v23.16b, v11.16b\n" + "mov v24.16b, v8.16b\n" + "mov v25.16b, v9.16b\n" + "mov v26.16b, v10.16b\n" + "mov v27.16b, v11.16b\n" + "mov v28.16b, v8.16b\n" + "mov v29.16b, v9.16b\n" + "mov v30.16b, v10.16b\n" + "mov v31.16b, v11.16b\n" + "b 185f\n" + "174:" // Height 6: no bias + "tbz %x[flags], #0, 184f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "cmp x13, #0x10\n" + "add x20, x21, x19, LSL #2\n" + "bge 183f\n" + "tbz x13, #3, 178f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v13.4s }, [x24], #0x10\n" + "ld1 { v17.4s }, [x23], #0x10\n" + "ld1 { v21.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "ld1 { v29.4s }, [x20], #0x10\n" + "tbz x13, #2, 176f\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x24], #0x10\n" + "ld1 { v18.4s }, [x23], #0x10\n" + "ld1 { v22.4s }, [x22], #0x10\n" + "ld1 { v26.4s }, [x21], #0x10\n" + "ld1 { v30.4s }, [x20], #0x10\n" + "tbz x13, #1, 175f\n" + "ldr d11, [x12], #0x8\n" + "ldr d15, [x24], #0x8\n" + "mov x19, #0x38\n" + "ldr d19, [x23], #0x8\n" + "ldr d23, [x22], #0x8\n" + "ldr d27, [x21], #0x8\n" + "ldr d31, [x20], #0x8\n" + "tbz x13, #0, 182f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x24]\n" + "ld1 { v19.s }[2], [x23]\n" + "ld1 { v23.s }[2], [x22]\n" + "ld1 { v27.s }[2], [x21]\n" + "ld1 { v31.s }[2], [x20]\n" + "b 182f\n" + "175:" // Height 6: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 182f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s15, [x24, #0x0]\n" + "ldr s19, [x23, #0x0]\n" + "ldr s23, [x22, #0x0]\n" + "ldr s27, [x21, #0x0]\n" + "ldr s31, [x20, #0x0]\n" + "b 182f\n" + "176:" // Height 6: Partial accumulate: partial_2_8 + "tbz x13, #1, 177f\n" + "ldr d10, [x12], #0x8\n" + "ldr d14, [x24], #0x8\n" + "mov x19, #0x28\n" + "ldr d18, [x23], #0x8\n" + "ldr d22, [x22], #0x8\n" + "ldr d26, [x21], #0x8\n" + "ldr d30, [x20], #0x8\n" + "tbz x13, #0, 182f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x24]\n" + "ld1 { v18.s }[2], [x23]\n" + "ld1 { v22.s }[2], [x22]\n" + "ld1 { v26.s }[2], [x21]\n" + "ld1 { v30.s }[2], [x20]\n" + "b 182f\n" + "177:" // Height 6: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 182f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s14, [x24, #0x0]\n" + "ldr s18, [x23, #0x0]\n" + "ldr s22, [x22, #0x0]\n" + "ldr s26, [x21, #0x0]\n" + "ldr s30, [x20, #0x0]\n" + "b 182f\n" + "178:" // Height 6: Partial accumulate: partial_4_0 + "tbz x13, #2, 180f\n" + "ld1 { v8.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x24], #0x10\n" + "ld1 { 
v16.4s }, [x23], #0x10\n" + "ld1 { v20.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "tbz x13, #1, 179f\n" + "ldr d9, [x12], #0x8\n" + "ldr d13, [x24], #0x8\n" + "mov x19, #0x18\n" + "ldr d17, [x23], #0x8\n" + "ldr d21, [x22], #0x8\n" + "ldr d25, [x21], #0x8\n" + "ldr d29, [x20], #0x8\n" + "tbz x13, #0, 182f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v13.s }[2], [x24]\n" + "ld1 { v17.s }[2], [x23]\n" + "ld1 { v21.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "ld1 { v29.s }[2], [x20]\n" + "b 182f\n" + "179:" // Height 6: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 182f\n" + "ldr s9, [x12, #0x0]\n" + "ldr s13, [x24, #0x0]\n" + "ldr s17, [x23, #0x0]\n" + "ldr s21, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "ldr s29, [x20, #0x0]\n" + "b 182f\n" + "180:" // Height 6: Partial accumulate: partial_2_0 + "tbz x13, #1, 181f\n" + "ldr d8, [x12], #0x8\n" + "ldr d12, [x24], #0x8\n" + "mov x19, #0x8\n" + "ldr d16, [x23], #0x8\n" + "ldr d20, [x22], #0x8\n" + "ldr d24, [x21], #0x8\n" + "ldr d28, [x20], #0x8\n" + "tbz x13, #0, 182f\n" + "ld1 { v8.s }[2], [x12]\n" + "ld1 { v12.s }[2], [x24]\n" + "ld1 { v16.s }[2], [x23]\n" + "ld1 { v20.s }[2], [x22]\n" + "ld1 { v24.s }[2], [x21]\n" + "ld1 { v28.s }[2], [x20]\n" + "b 182f\n" + "181:" // Height 6: Partial accumulate: partial_1_0 + "ldr s8, [x12, #0x0]\n" + "ldr s12, [x24, #0x0]\n" + "mov x19, #0x0\n" + "ldr s16, [x23, #0x0]\n" + "ldr s20, [x22, #0x0]\n" + "ldr s24, [x21, #0x0]\n" + "ldr s28, [x20, #0x0]\n" + "182:" // Height 6: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 185f\n" + "183:" // Height 6: full accumulate + "ldr q8, [x12, #0x0]\n" + "ldr q9, [x12, #0x10]\n" + "ldr q10, [x12, #0x20]\n" + "ldr q11, [x12, #0x30]\n" + "ldr q12, [x24, #0x0]\n" + "ldr q13, [x24, #0x10]\n" + "ldr q14, [x24, #0x20]\n" + "ldr q15, [x24, #0x30]\n" + "ldr q16, [x23, #0x0]\n" + "ldr q17, [x23, #0x10]\n" + "ldr q18, [x23, #0x20]\n" + "ldr q19, [x23, #0x30]\n" + "ldr q20, [x22, #0x0]\n" + "ldr q21, [x22, #0x10]\n" + "ldr q22, [x22, #0x20]\n" + "ldr q23, [x22, #0x30]\n" + "ldr q24, [x21, #0x0]\n" + "ldr q25, [x21, #0x10]\n" + "ldr q26, [x21, #0x20]\n" + "ldr q27, [x21, #0x30]\n" + "ldr q28, [x20, #0x0]\n" + "ldr q29, [x20, #0x10]\n" + "ldr q30, [x20, #0x20]\n" + "ldr q31, [x20, #0x30]\n" + "b 185f\n" + "184:" // Height 6: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "185:" // Height 6: setup done + "mov x27, #0x0\n" + "186:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 187f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 188f\n" + "ldr x19, [%x[args_ptr], 
%[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "add x20, x20, x19, LSL #2\n" + "b 188f\n" + "187:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "188:" // Height 6: input setup done + "cmp x26, #0x4\n" + "blt 191f\n" + "ldr q0, [x25, #0x0]\n" + "ldr q1, [x24, #0x0]\n" + "cmp x26, #0x8\n" + "ldr q2, [x23, #0x0]\n" + "ldr q3, [x22, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + "ldr q5, [x20, #0x0]\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "blt 190f\n" + "189:" // Height 6: Multiply loop: Main loop head + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "cmp x26, #0x8\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x23, x23, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "add x20, x20, #0x10\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "fmla v29.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "fmla v24.4s, v6.4s, v4.s[1]\n" + "fmla v28.4s, v6.4s, v5.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "fmla v29.4s, v7.4s, v5.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[1]\n" + "fmla v30.4s, v6.4s, v5.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "fmla v27.4s, v7.4s, v4.s[1]\n" + "fmla v31.4s, v7.4s, v5.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + "fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "fmla v24.4s, v6.4s, v4.s[2]\n" + "fmla v28.4s, v6.4s, v5.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "fmla v25.4s, v7.4s, v4.s[2]\n" + "fmla v29.4s, v7.4s, v5.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "fmla v30.4s, v6.4s, v5.s[2]\n" + "ldr q6, 
[x11, #0x30]\n" + "add x11, x11, #0x40\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "fmla v27.4s, v7.4s, v4.s[2]\n" + "fmla v31.4s, v7.4s, v5.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[3]\n" + "fmla v28.4s, v6.4s, v5.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "fmla v25.4s, v7.4s, v4.s[3]\n" + "fmla v29.4s, v7.4s, v5.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "fmla v26.4s, v6.4s, v4.s[3]\n" + "fmla v30.4s, v6.4s, v5.s[3]\n" + "ldr q6, [x11, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "ldr q0, [x25, #0x0]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "ldr q1, [x24, #0x0]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "ldr q2, [x23, #0x0]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "ldr q3, [x22, #0x0]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "ldr q4, [x21, #0x0]\n" + "fmla v31.4s, v7.4s, v5.s[3]\n" + "ldr q5, [x20, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "bge 189b\n" + "190:" // Height 6: Multiply loop: Single iteration only + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "sub x26, x26, #0x4\n" + "add x25, x25, #0x10\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x22, x22, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "add x21, x21, #0x10\n" + "add x20, x20, #0x10\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "fmla v29.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "ldr q6, [x11, #0x10]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x10, #0x10]\n" + "fmla v8.4s, v6.4s, v0.s[1]\n" + "fmla v12.4s, v6.4s, v1.s[1]\n" + "fmla v16.4s, v6.4s, v2.s[1]\n" + "fmla v20.4s, v6.4s, v3.s[1]\n" + "fmla v24.4s, v6.4s, v4.s[1]\n" + "fmla v28.4s, v6.4s, v5.s[1]\n" + "ldr q6, [x9, #0x10]\n" + "fmla v9.4s, v7.4s, v0.s[1]\n" + "fmla v13.4s, v7.4s, v1.s[1]\n" + "fmla v17.4s, v7.4s, v2.s[1]\n" + "fmla v21.4s, v7.4s, v3.s[1]\n" + "fmla v25.4s, v7.4s, v4.s[1]\n" + "fmla v29.4s, v7.4s, v5.s[1]\n" + "ldr q7, [x28, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v1.s[1]\n" + "fmla v18.4s, v6.4s, v2.s[1]\n" + "fmla v22.4s, v6.4s, v3.s[1]\n" + "fmla v26.4s, v6.4s, v4.s[1]\n" + "fmla v30.4s, v6.4s, v5.s[1]\n" + "ldr q6, [x11, #0x20]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v15.4s, v7.4s, v1.s[1]\n" + "fmla v19.4s, v7.4s, v2.s[1]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "fmla v27.4s, v7.4s, v4.s[1]\n" + "fmla v31.4s, v7.4s, v5.s[1]\n" + "ldr q7, [x10, #0x20]\n" + "fmla v8.4s, v6.4s, v0.s[2]\n" + 
"fmla v12.4s, v6.4s, v1.s[2]\n" + "fmla v16.4s, v6.4s, v2.s[2]\n" + "fmla v20.4s, v6.4s, v3.s[2]\n" + "fmla v24.4s, v6.4s, v4.s[2]\n" + "fmla v28.4s, v6.4s, v5.s[2]\n" + "ldr q6, [x9, #0x20]\n" + "fmla v9.4s, v7.4s, v0.s[2]\n" + "fmla v13.4s, v7.4s, v1.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[2]\n" + "fmla v21.4s, v7.4s, v3.s[2]\n" + "fmla v25.4s, v7.4s, v4.s[2]\n" + "fmla v29.4s, v7.4s, v5.s[2]\n" + "ldr q7, [x28, #0x20]\n" + "fmla v10.4s, v6.4s, v0.s[2]\n" + "fmla v14.4s, v6.4s, v1.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[2]\n" + "fmla v22.4s, v6.4s, v3.s[2]\n" + "fmla v26.4s, v6.4s, v4.s[2]\n" + "fmla v30.4s, v6.4s, v5.s[2]\n" + "ldr q6, [x11, #0x30]\n" + "add x11, x11, #0x40\n" + "fmla v11.4s, v7.4s, v0.s[2]\n" + "fmla v15.4s, v7.4s, v1.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[2]\n" + "fmla v23.4s, v7.4s, v3.s[2]\n" + "fmla v27.4s, v7.4s, v4.s[2]\n" + "fmla v31.4s, v7.4s, v5.s[2]\n" + "ldr q7, [x10, #0x30]\n" + "add x10, x10, #0x40\n" + "fmla v8.4s, v6.4s, v0.s[3]\n" + "fmla v12.4s, v6.4s, v1.s[3]\n" + "fmla v16.4s, v6.4s, v2.s[3]\n" + "fmla v20.4s, v6.4s, v3.s[3]\n" + "fmla v24.4s, v6.4s, v4.s[3]\n" + "fmla v28.4s, v6.4s, v5.s[3]\n" + "ldr q6, [x9, #0x30]\n" + "add x9, x9, #0x40\n" + "fmla v9.4s, v7.4s, v0.s[3]\n" + "fmla v13.4s, v7.4s, v1.s[3]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v21.4s, v7.4s, v3.s[3]\n" + "fmla v25.4s, v7.4s, v4.s[3]\n" + "fmla v29.4s, v7.4s, v5.s[3]\n" + "ldr q7, [x28, #0x30]\n" + "add x28, x28, #0x40\n" + "fmla v10.4s, v6.4s, v0.s[3]\n" + "fmla v14.4s, v6.4s, v1.s[3]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v22.4s, v6.4s, v3.s[3]\n" + "fmla v26.4s, v6.4s, v4.s[3]\n" + "fmla v30.4s, v6.4s, v5.s[3]\n" + "fmla v11.4s, v7.4s, v0.s[3]\n" + "fmla v15.4s, v7.4s, v1.s[3]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "fmla v23.4s, v7.4s, v3.s[3]\n" + "fmla v27.4s, v7.4s, v4.s[3]\n" + "fmla v31.4s, v7.4s, v5.s[3]\n" + "191:" // Height 6: Multiply loop: Main loop skip + "cbz x26, 193f\n" + "192:" // Height 6: Multiply loop: Odd block loop + "ldr s0, [x25], #0x4\n" + "ldr s1, [x24], #0x4\n" + "sub x26, x26, #0x1\n" + "ldr s2, [x23], #0x4\n" + "ldr s3, [x22], #0x4\n" + "ldr s4, [x21], #0x4\n" + "ldr s5, [x20], #0x4\n" + "ldr q6, [x11, #0x0]\n" + "ldr q7, [x10, #0x0]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v12.4s, v6.4s, v1.s[0]\n" + "fmla v16.4s, v6.4s, v2.s[0]\n" + "fmla v20.4s, v6.4s, v3.s[0]\n" + "add x11, x11, #0x10\n" + "add x10, x10, #0x10\n" + "fmla v24.4s, v6.4s, v4.s[0]\n" + "fmla v28.4s, v6.4s, v5.s[0]\n" + "ldr q6, [x9, #0x0]\n" + "add x9, x9, #0x10\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v13.4s, v7.4s, v1.s[0]\n" + "fmla v17.4s, v7.4s, v2.s[0]\n" + "fmla v21.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v4.s[0]\n" + "fmla v29.4s, v7.4s, v5.s[0]\n" + "ldr q7, [x28, #0x0]\n" + "add x28, x28, #0x10\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v14.4s, v6.4s, v1.s[0]\n" + "fmla v18.4s, v6.4s, v2.s[0]\n" + "fmla v22.4s, v6.4s, v3.s[0]\n" + "fmla v26.4s, v6.4s, v4.s[0]\n" + "fmla v30.4s, v6.4s, v5.s[0]\n" + "fmla v11.4s, v7.4s, v0.s[0]\n" + "fmla v15.4s, v7.4s, v1.s[0]\n" + "fmla v19.4s, v7.4s, v2.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[0]\n" + "fmla v27.4s, v7.4s, v4.s[0]\n" + "fmla v31.4s, v7.4s, v5.s[0]\n" + "cbnz x26, 192b\n" + "193:" // Height 6: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 186b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" 
+ "add x20, x21, x19, LSL #2\n" + "tbz %x[flags], #1, 194f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v31.4s, v31.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v31.4s, v31.4s, v0.4s\n" + "194:" // Height 6: No activation + "cmp x13, #0x10\n" + "bge 203f\n" + "tbz x13, #3, 198f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v13.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v17.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v21.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "st1 { v25.4s }, [x21], #0x10\n" + "st1 { v28.4s }, [x20], #0x10\n" + "st1 { v29.4s }, [x20], #0x10\n" + "tbz x13, #2, 196f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x24], #0x10\n" + "st1 { v18.4s }, [x23], #0x10\n" + "st1 { v22.4s }, [x22], #0x10\n" + "st1 { v26.4s }, [x21], #0x10\n" + "st1 { v30.4s }, [x20], #0x10\n" + "tbz x13, #1, 195f\n" + "str d11, [x12], #0x8\n" + "str d15, [x24], #0x8\n" + "str d19, [x23], #0x8\n" + "str d23, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "str d31, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "st1 { v11.s }[2], [x12]\n" + "st1 { v15.s }[2], [x24]\n" + "st1 { v19.s }[2], [x23]\n" + "st1 { v23.s }[2], [x22]\n" + "st1 { v27.s }[2], [x21]\n" + "st1 { v31.s }[2], [x20]\n" + "b 202f\n" + "195:" // Height 6: Partial direct writeback: partial_1_12 + "tbz x13, #0, 202f\n" + "str s11, [x12, #0x0]\n" + "str s15, [x24, #0x0]\n" + "str s19, [x23, #0x0]\n" + "str s23, [x22, #0x0]\n" + "str s27, [x21, #0x0]\n" + "str s31, [x20, #0x0]\n" + "b 202f\n" + "196:" // Height 6: Partial direct writeback: partial_2_8 + "tbz x13, #1, 197f\n" + "str d10, [x12], #0x8\n" + "str d14, [x24], #0x8\n" + "str d18, [x23], #0x8\n" + "str d22, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "str d30, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "st1 { v10.s }[2], [x12]\n" + "st1 { v14.s }[2], [x24]\n" + "st1 { v18.s }[2], [x23]\n" + 
"st1 { v22.s }[2], [x22]\n" + "st1 { v26.s }[2], [x21]\n" + "st1 { v30.s }[2], [x20]\n" + "b 202f\n" + "197:" // Height 6: Partial direct writeback: partial_1_8 + "tbz x13, #0, 202f\n" + "str s10, [x12, #0x0]\n" + "str s14, [x24, #0x0]\n" + "str s18, [x23, #0x0]\n" + "str s22, [x22, #0x0]\n" + "str s26, [x21, #0x0]\n" + "str s30, [x20, #0x0]\n" + "b 202f\n" + "198:" // Height 6: Partial direct writeback: partial_4_0 + "tbz x13, #2, 200f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x24], #0x10\n" + "st1 { v16.4s }, [x23], #0x10\n" + "st1 { v20.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "st1 { v28.4s }, [x20], #0x10\n" + "tbz x13, #1, 199f\n" + "str d9, [x12], #0x8\n" + "str d13, [x24], #0x8\n" + "str d17, [x23], #0x8\n" + "str d21, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "str d29, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "st1 { v9.s }[2], [x12]\n" + "st1 { v13.s }[2], [x24]\n" + "st1 { v17.s }[2], [x23]\n" + "st1 { v21.s }[2], [x22]\n" + "st1 { v25.s }[2], [x21]\n" + "st1 { v29.s }[2], [x20]\n" + "b 202f\n" + "199:" // Height 6: Partial direct writeback: partial_1_4 + "tbz x13, #0, 202f\n" + "str s9, [x12, #0x0]\n" + "str s13, [x24, #0x0]\n" + "str s17, [x23, #0x0]\n" + "str s21, [x22, #0x0]\n" + "str s25, [x21, #0x0]\n" + "str s29, [x20, #0x0]\n" + "b 202f\n" + "200:" // Height 6: Partial direct writeback: partial_2_0 + "tbz x13, #1, 201f\n" + "str d8, [x12], #0x8\n" + "str d12, [x24], #0x8\n" + "str d16, [x23], #0x8\n" + "str d20, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "str d28, [x20], #0x8\n" + "tbz x13, #0, 202f\n" + "st1 { v8.s }[2], [x12]\n" + "st1 { v12.s }[2], [x24]\n" + "st1 { v16.s }[2], [x23]\n" + "st1 { v20.s }[2], [x22]\n" + "st1 { v24.s }[2], [x21]\n" + "st1 { v28.s }[2], [x20]\n" + "b 202f\n" + "201:" // Height 6: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "str s12, [x24, #0x0]\n" + "str s16, [x23, #0x0]\n" + "str s20, [x22, #0x0]\n" + "str s24, [x21, #0x0]\n" + "str s28, [x20, #0x0]\n" + "202:" // Height 6: Partial direct writeback: Done + "b 204f\n" + "203:" // Height 6: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "add x12, x12, #0x40\n" + "str q12, [x24, #0x0]\n" + "str q13, [x24, #0x10]\n" + "str q14, [x24, #0x20]\n" + "str q15, [x24, #0x30]\n" + "str q16, [x23, #0x0]\n" + "str q17, [x23, #0x10]\n" + "str q18, [x23, #0x20]\n" + "str q19, [x23, #0x30]\n" + "str q20, [x22, #0x0]\n" + "str q21, [x22, #0x10]\n" + "str q22, [x22, #0x20]\n" + "str q23, [x22, #0x30]\n" + "str q24, [x21, #0x0]\n" + "str q25, [x21, #0x10]\n" + "str q26, [x21, #0x20]\n" + "str q27, [x21, #0x30]\n" + "str q28, [x20, #0x0]\n" + "str q29, [x20, #0x10]\n" + "str q30, [x20, #0x20]\n" + "str q31, [x20, #0x30]\n" + "204:" // Height 6: Writeback done + "subs x13, x13, #0x10\n" + "bgt 172b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 206f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 205f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "205:" // Update direct input + "mov x19, #0x18\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "206:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" 
(offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths))
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"
+ );
+}
+
+} // namespace arm_gemm
+#endif // __aarch64__
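+
+// The "Partial accumulate" and "Partial direct writeback" paths in the kernel
+// above decode the remaining column count bit by bit: tbz on bit 3 selects two
+// q-register transfers (8 floats), bit 2 one q transfer, bit 1 one d transfer,
+// and bit 0 a single s transfer. A minimal scalar sketch of the same ladder;
+// the names here are illustrative and not part of arm_gemm:
+#include <cstring>
+
+// Store n (n < 16) floats from acc to dst; each set bit of n maps to one
+// progressively narrower store, exactly like the tbz chain in the kernel.
+inline void store_partial_row(float *dst, const float *acc, unsigned int n)
+{
+    unsigned int i = 0;
+    if (n & 8) { std::memcpy(dst + i, acc + i, 8 * sizeof(float)); i += 8; } // two "st1 { v.4s }"
+    if (n & 4) { std::memcpy(dst + i, acc + i, 4 * sizeof(float)); i += 4; } // one "st1 { v.4s }"
+    if (n & 2) { std::memcpy(dst + i, acc + i, 2 * sizeof(float)); i += 2; } // one "str d"
+    if (n & 1) { dst[i] = acc[i]; }                                          // one "str s" / "st1 { v.s }[lane]"
+}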
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
new file mode 100644
index 0000000000..af2c1e5ae0
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+#pragma once
+#ifdef __aarch64__
+
+#include "../std_transforms_fixed.hpp"
+#include "../bfloat.hpp"
+#include "../kernel_weight_format.hpp"
+#include "../performance_parameters.hpp"
+
+#define ARGLIST \
+    unsigned int, const unsigned int *, \
+    IndirectInputArg<float>, \
+    size_t, size_t, \
+    const bfloat16 *, \
+    size_t, \
+    IndirectOutputArg<float>, \
+    const float *, Activation, bool
+
+namespace arm_gemm
+{
+// Actual kernel implementations
+void a64_ffhybrid_fp32bf16fp32_mmla_4x24( ARGLIST );
+
+class cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24
+{
+public:
+    typedef float lhs_operand_type;
+    typedef bfloat16 rhs_operand_type;
+    typedef float result_type;
+
+    typedef void (*kern_type)( ARGLIST );
+
+    /* Kernel blocking parameters */
+    static constexpr unsigned int out_height()
+    {
+        return 4;
+    }
+    static unsigned int stripe_width()
+    {
+        return 4;
+    }
+
+    static KernelWeightFormat kernel_weight_format()
+    {
+        return KernelWeightFormat::VL256_BL64_BF16;
+    }
+
+    static unsigned int out_width()
+    {
+        return 24;
+    }
+
+    static constexpr unsigned int k_unroll()
+    {
+        return 4;
+    }
+
+    static constexpr bool supports_accumulate()
+    {
+        return true;
+    }
+
+    StdTransformsFixed<rhs_operand_type, result_type, 4, 24, 4> transforms = {};
+    template<typename T>
+    static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
+    {
+        if (std::is_same<T, float>::value) {
+            switch (ci->get_cpu_model()) {
+                default:
+                    return { 28.48 };
+            }
+        }
+
+        return { 1.0 };
+    }
+
+    // Default to the generic kernel
+    kern_type kernel=a64_ffhybrid_fp32bf16fp32_mmla_4x24;
+    cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24(const CPUInfo *)
+    {
+    }
+};
+
+} // namespace arm_gemm
+
+#undef ARGLIST
+#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp
new file mode 100644
index 0000000000..245e653a43
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24/generic.cpp
@@ -0,0 +1,2561 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
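+
+// Each height variant below opens with a "Column loop" that materialises six
+// B-panel pointers (x11, x10, x9, x28, x27, x26), one per 4-wide block of the
+// 24-column output stripe. When fewer than 24 columns remain, every pointer
+// past the last valid panel is aliased back to the first panel, so the extra
+// loads stay in-bounds and the writeback ladder simply never stores those
+// columns. A minimal sketch of that fallback, assuming illustrative names and
+// bfloat16 held as uint16_t; none of these identifiers exist in arm_gemm:
+#include <cstddef>
+#include <cstdint>
+
+inline void setup_b_panels(const uint16_t *B, size_t stride_elems,
+                           size_t n_remaining, const uint16_t *panel[6])
+{
+    for (unsigned int p = 0; p < 6; ++p) {
+        // Mirrors the "cmp x13, #0x14 ... mov x26, x11" ladder below: panel p
+        // is real only if more than 4*p columns remain, else alias panel 0.
+        panel[p] = (4u * p < n_remaining) ? B + p * stride_elems : B;
+    }
+}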
+#ifdef __aarch64__
+
+#include "arm_gemm.hpp"
+#include "../../utils.hpp"
+#include "../../bfloat.hpp"
+
+#include <cassert>
+#include <limits>
+
+namespace arm_gemm {
+
+void a64_ffhybrid_fp32bf16fp32_mmla_4x24 (
+    unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg,
+    size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg,
+    const float *bias, Activation act, bool accumulate
+)
+{
+    struct KernelArgs {
+        float maxval = static_cast<float>(std::numeric_limits<float>::infinity());
+        float minval = - static_cast<float>(std::numeric_limits<float>::infinity());
+        unsigned int num_strings = {};
+        const unsigned int *string_lengths = {};
+        size_t N = {};
+        const bfloat16 *B_ptr = {};
+        const bfloat16 *cur_B_ptr = {};
+        size_t B_stride = {};
+        size_t output_offset = {};
+        size_t input_initial_col = {};
+        size_t input_offset = {};
+    } ka;
+
+    unsigned long flags=0;
+    void *output_ptr;
+    void *input_ptr;
+
+    if (output_arg.is_indirect) {
+        output_ptr=(void *)(output_arg.indirect.ptr);
+        ka.output_offset=output_arg.indirect.offset;
+        flags |= 0x4;
+    } else {
+        output_ptr=(void *)(output_arg.direct.base);
+        ka.output_offset=output_arg.direct.stride;
+    }
+
+    if (A_arg.is_indirect) {
+        input_ptr=(void *)(A_arg.indirect.ptr);
+        ka.input_offset=A_arg.indirect.start_row;
+        ka.input_initial_col=A_arg.indirect.start_col;
+        flags |= 0x8;
+    } else {
+        assert(num_strings==1);
+        input_ptr=(void *)(A_arg.direct.base);
+        ka.input_offset=A_arg.direct.stride;
+    }
+    if (accumulate) {
+        flags |= 0x1;
+    }
+    ka.num_strings = num_strings;
+    ka.string_lengths = string_lengths;
+    ka.N = N;
+    ka.B_ptr = B_ptr;
+    ka.B_stride = B_stride;
+    switch(act.type) {
+        default:
+        case Activation::Type::None:
+            break;
+        case Activation::Type::BoundedReLU:
+            ka.maxval = static_cast<float>(act.param1);
+            /* fall through */
+        case Activation::Type::ReLU:
+            ka.minval = 0;
+            flags |= 0x2;
+            break;
+    }
+    __asm__ __volatile__(
+ "1:" // Row loop + "cmp %x[M], #0x4\n" + "bge 133f\n" + "cmp %x[M], #0x2\n" + "bgt 89f\n" + "beq 45f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x14\n" + "bgt 3f\n" + "cmp x13, #0x10\n" + "mov x26, x11\n" + "bgt 3f\n" + "cmp x13, #0xc\n" + "mov x27, x11\n" + "bgt 3f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 3f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "cbz x14, 4f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "ldr q12, [x14, #0x40]\n" + "ldr q13, [x14, #0x50]\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "add x14, x14, #0x60\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "b 20f\n" + "4:" // Height 1: no bias
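+ // The even/odd register pairing above reflects BFMMLA's output shape: each
+ // accumulator holds a 2x2 tile (two rows by two columns) of fp32 results, so
+ // the zip1/zip2 pairs replicate each adjacent pair of bias values into both
+ // rows of the corresponding tile. The "MMLA fixup" below applies the inverse
+ // zip when accumulating into existing output, and the uzp1 sequence after
+ // the string loop flattens the tiles back to row-major order for writeback.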
+ "tbz %x[flags], #0, 19f\n" + "cmp x13, #0x18\n" + "bge 17f\n" + "tbz x13, #4, 8f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v12.4s }, [x12], #0x10\n" + "tbz x13, #2, 6f\n" + "ld1 { v13.4s }, [x12], #0x10\n" + "tbz x13, #1, 5f\n" + "ldr d20, [x12], #0x8\n" + "mov x19, #0x58\n" + "tbz x13, #0, 16f\n" + "ld1 { v20.s }[2], [x12]\n" + "b 16f\n" + "5:" // Height 1: Partial accumulate: partial_1_20 + "mov x19, #0x50\n" + "tbz x13, #0, 16f\n" + "ldr s20, [x12, #0x0]\n" + "b 16f\n" + "6:" // Height 1: Partial accumulate: partial_2_16 + "tbz x13, #1, 7f\n" + "ldr d13, [x12], #0x8\n" + "mov x19, #0x48\n" + "tbz x13, #0, 16f\n" + "ld1 { v13.s }[2], [x12]\n" + "b 16f\n" + "7:" // Height 1: Partial accumulate: partial_1_16 + "mov x19, #0x40\n" + "tbz x13, #0, 16f\n" + "ldr s13, [x12, #0x0]\n" + "b 16f\n" + "8:" // Height 1: Partial accumulate: partial_8_0 + "tbz x13, #3, 12f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "tbz x13, #2, 10f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "tbz x13, #1, 9f\n" + "ldr d12, [x12], #0x8\n" + "mov x19, #0x38\n" + "tbz x13, #0, 16f\n" + "ld1 { v12.s }[2], [x12]\n" + "b 16f\n" + "9:" // Height 1: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 16f\n" + "ldr s12, [x12, #0x0]\n" + "b 16f\n" + "10:" // Height 1: Partial accumulate: partial_2_8 + "tbz x13, #1, 11f\n" + "ldr d11, [x12], #0x8\n" + "mov x19, #0x28\n" + "tbz x13, #0, 16f\n" + "ld1 { v11.s }[2], [x12]\n" + "b 16f\n" + "11:" // Height 1: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 16f\n" + "ldr s11, [x12, #0x0]\n" + "b 16f\n" + "12:" // Height 1: Partial accumulate: partial_4_0 + "tbz x13, #2, 14f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "tbz x13, #1, 13f\n" + "ldr d10, [x12], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 16f\n" + "ld1 { v10.s }[2], [x12]\n" + "b 16f\n" + "13:" // Height 1: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 16f\n" + "ldr s10, [x12, #0x0]\n" + "b 16f\n" + "14:" // Height 1: Partial accumulate: partial_2_0 + "tbz x13, #1, 15f\n" + "ldr d9, [x12], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 16f\n" + "ld1 { v9.s }[2], [x12]\n" + "b 16f\n" + "15:" // Height 1: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "mov x19, #0x0\n" + "16:" // Height 1: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 18f\n" + "17:" // Height 1: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q12, [x12, #0x30]\n" + "ldr q13, [x12, #0x40]\n" + "ldr q20, [x12, #0x50]\n" + "18:" // Height 1: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "b 20f\n" + "19:" // Height 1: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "20:" // Height 1: setup done + "mov x25, #0x0\n" + "21:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + 
"ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 22f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "cbnz x25, 23f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "b 23f\n" + "22:" // Height 1: setup direct input + "mov x23, %x[input_ptr]\n" + "23:" // Height 1: input setup done + "cmp x24, #0x4\n" + "blt 26f\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "ldr q4, [x11, #0x0]\n" + "cmp x24, #0x8\n" + "ldr q5, [x11, #0x10]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "blt 25f\n" + "24:" // Height 1: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x8\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "add x11, x11, #0x20\n" + "ldr q4, [x11, #0x0]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "ldr q7, [x10, #0x10]\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "bge 24b\n" + "25:" // Height 1: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "sub x24, x24, #0x4\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "26:" // Height 1: Multiply loop: Main loop skip + "cbz x24, 29f\n" + "cbz x24, 29f\n" + "tbz x24, #1, 27f\n" + "ldr d0, [x23], #0x8\n" + "tbz x24, #0, 28f\n" + "ld1 { v0.s }[2], [x23]\n" + "b 28f\n" + "27:" // Height 1: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x23, #0x0]\n" + "28:" // Height 1: Multiply loop: Ragged operand read: Done + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn 
v0.4h, v0.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q4, [x9, #0x0]\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q6, [x28, #0x0]\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q4, [x27, #0x0]\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "ldr q6, [x26, #0x0]\n" + "ldr q7, [x26, #0x10]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "29:" // Height 1: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 21b\n" + "uzp1 v8.2d, v8.2d, v14.2d\n" + "uzp1 v9.2d, v9.2d, v15.2d\n" + "uzp1 v10.2d, v10.2d, v16.2d\n" + "uzp1 v11.2d, v11.2d, v17.2d\n" + "uzp1 v12.2d, v12.2d, v18.2d\n" + "uzp1 v13.2d, v13.2d, v19.2d\n" + "tbz %x[flags], #1, 30f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "30:" // Height 1: No activation + "cmp x13, #0x18\n" + "bge 43f\n" + "tbz x13, #4, 34f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x12], #0x10\n" + "st1 { v11.4s }, [x12], #0x10\n" + "tbz x13, #2, 32f\n" + "st1 { v12.4s }, [x12], #0x10\n" + "tbz x13, #1, 31f\n" + "str d13, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v13.s }[2], [x12]\n" + "b 42f\n" + "31:" // Height 1: Partial direct writeback: partial_1_20 + "tbz x13, #0, 42f\n" + "str s13, [x12, #0x0]\n" + "b 42f\n" + "32:" // Height 1: Partial direct writeback: partial_2_16 + "tbz x13, #1, 33f\n" + "str d12, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v12.s }[2], [x12]\n" + "b 42f\n" + "33:" // Height 1: Partial direct writeback: partial_1_16 + "tbz x13, #0, 42f\n" + "str s12, [x12, #0x0]\n" + "b 42f\n" + "34:" // Height 1: Partial direct writeback: partial_8_0 + "tbz x13, #3, 38f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "st1 { v9.4s }, [x12], #0x10\n" + "tbz x13, #2, 36f\n" + "st1 { v10.4s }, [x12], #0x10\n" + "tbz x13, #1, 35f\n" + "str d11, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v11.s }[2], [x12]\n" + "b 42f\n" + "35:" // Height 1: Partial direct writeback: partial_1_12 + "tbz x13, #0, 42f\n" + "str s11, [x12, #0x0]\n" + "b 42f\n" + "36:" // Height 1: Partial direct writeback: partial_2_8 + "tbz x13, #1, 37f\n" + "str d10, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v10.s }[2], [x12]\n" + "b 42f\n" + "37:" // Height 1: Partial direct writeback: partial_1_8 + "tbz x13, #0, 42f\n" + "str s10, [x12, 
#0x0]\n" + "b 42f\n" + "38:" // Height 1: Partial direct writeback: partial_4_0 + "tbz x13, #2, 40f\n" + "st1 { v8.4s }, [x12], #0x10\n" + "tbz x13, #1, 39f\n" + "str d9, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v9.s }[2], [x12]\n" + "b 42f\n" + "39:" // Height 1: Partial direct writeback: partial_1_4 + "tbz x13, #0, 42f\n" + "str s9, [x12, #0x0]\n" + "b 42f\n" + "40:" // Height 1: Partial direct writeback: partial_2_0 + "tbz x13, #1, 41f\n" + "str d8, [x12], #0x8\n" + "tbz x13, #0, 42f\n" + "st1 { v8.s }[2], [x12]\n" + "b 42f\n" + "41:" // Height 1: Partial direct writeback: partial_1_0 + "str s8, [x12, #0x0]\n" + "42:" // Height 1: Partial direct writeback: Done + "b 44f\n" + "43:" // Height 1: Full writeback + "str q8, [x12, #0x0]\n" + "str q9, [x12, #0x10]\n" + "str q10, [x12, #0x20]\n" + "str q11, [x12, #0x30]\n" + "str q12, [x12, #0x40]\n" + "str q13, [x12, #0x50]\n" + "add x12, x12, #0x60\n" + "44:" // Height 1: Writeback done + "subs x13, x13, #0x18\n" + "bgt 2b\n" + "b 178f\n" + "45:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "46:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x14\n" + "bgt 47f\n" + "cmp x13, #0x10\n" + "mov x26, x11\n" + "bgt 47f\n" + "cmp x13, #0xc\n" + "mov x27, x11\n" + "bgt 47f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 47f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 47f\n" + "mov x10, x11\n" + "47:" // Height 2: B setup done + "cbz x14, 48f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "ldr q12, [x14, #0x40]\n" + "ldr q13, [x14, #0x50]\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "add x14, x14, #0x60\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "b 64f\n" + "48:" // Height 2: no bias + "tbz %x[flags], #0, 63f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "cmp x13, #0x18\n" + "add x22, x12, x19, LSL #2\n" + "bge 61f\n" + "tbz x13, #4, 52f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "ld1 { v12.4s }, [x12], #0x10\n" + "ld1 { v17.4s }, [x22], #0x10\n" + "tbz x13, #2, 50f\n" + "ld1 { v13.4s }, [x12], #0x10\n" + "ld1 { v18.4s }, [x22], #0x10\n" + "tbz x13, #1, 49f\n" + "ldr d20, [x12], #0x8\n" + "ldr d19, [x22], #0x8\n" + "mov x19, #0x58\n" + "tbz x13, #0, 60f\n" + "ld1 { v20.s }[2], [x12]\n" + "ld1 { v19.s }[2], [x22]\n" + "b 60f\n" + "49:" // Height 2: Partial accumulate: partial_1_20 + "mov x19, #0x50\n" + "tbz x13, #0, 60f\n" + "ldr s20, [x12, #0x0]\n" + "ldr s19, [x22, #0x0]\n" + "b 60f\n" + "50:" // Height 2: Partial accumulate: partial_2_16 + "tbz x13, #1, 51f\n" + "ldr d13, 
[x12], #0x8\n" + "ldr d18, [x22], #0x8\n" + "mov x19, #0x48\n" + "tbz x13, #0, 60f\n" + "ld1 { v13.s }[2], [x12]\n" + "ld1 { v18.s }[2], [x22]\n" + "b 60f\n" + "51:" // Height 2: Partial accumulate: partial_1_16 + "mov x19, #0x40\n" + "tbz x13, #0, 60f\n" + "ldr s13, [x12, #0x0]\n" + "ldr s18, [x22, #0x0]\n" + "b 60f\n" + "52:" // Height 2: Partial accumulate: partial_8_0 + "tbz x13, #3, 56f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "tbz x13, #2, 54f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "tbz x13, #1, 53f\n" + "ldr d12, [x12], #0x8\n" + "ldr d17, [x22], #0x8\n" + "mov x19, #0x38\n" + "tbz x13, #0, 60f\n" + "ld1 { v12.s }[2], [x12]\n" + "ld1 { v17.s }[2], [x22]\n" + "b 60f\n" + "53:" // Height 2: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 60f\n" + "ldr s12, [x12, #0x0]\n" + "ldr s17, [x22, #0x0]\n" + "b 60f\n" + "54:" // Height 2: Partial accumulate: partial_2_8 + "tbz x13, #1, 55f\n" + "ldr d11, [x12], #0x8\n" + "ldr d16, [x22], #0x8\n" + "mov x19, #0x28\n" + "tbz x13, #0, 60f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v16.s }[2], [x22]\n" + "b 60f\n" + "55:" // Height 2: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 60f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "b 60f\n" + "56:" // Height 2: Partial accumulate: partial_4_0 + "tbz x13, #2, 58f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "tbz x13, #1, 57f\n" + "ldr d10, [x12], #0x8\n" + "ldr d15, [x22], #0x8\n" + "mov x19, #0x18\n" + "tbz x13, #0, 60f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x22]\n" + "b 60f\n" + "57:" // Height 2: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 60f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s15, [x22, #0x0]\n" + "b 60f\n" + "58:" // Height 2: Partial accumulate: partial_2_0 + "tbz x13, #1, 59f\n" + "ldr d9, [x12], #0x8\n" + "ldr d14, [x22], #0x8\n" + "mov x19, #0x8\n" + "tbz x13, #0, 60f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x22]\n" + "b 60f\n" + "59:" // Height 2: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s14, [x22, #0x0]\n" + "mov x19, #0x0\n" + "60:" // Height 2: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 62f\n" + "61:" // Height 2: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q12, [x12, #0x30]\n" + "ldr q13, [x12, #0x40]\n" + "ldr q20, [x12, #0x50]\n" + "ldr q14, [x22, #0x0]\n" + "ldr q15, [x22, #0x10]\n" + "ldr q16, [x22, #0x20]\n" + "ldr q17, [x22, #0x30]\n" + "ldr q18, [x22, #0x40]\n" + "ldr q19, [x22, #0x50]\n" + "62:" // Height 2: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "b 64f\n" + "63:" // Height 2: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "64:" // Height 2: setup done + "mov x25, #0x0\n" + 
"65:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 66f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "cbnz x25, 67f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "b 67f\n" + "66:" // Height 2: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "67:" // Height 2: input setup done + "cmp x24, #0x4\n" + "blt 70f\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "ld1 { v1.4s }, [x22], #0x10\n" + "cmp x24, #0x8\n" + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "blt 69f\n" + "68:" // Height 2: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x22], #0x10\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x8\n" + "add x11, x11, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "ldr q4, [x11, #0x0]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "add x9, x9, #0x20\n" + "ldr q7, [x10, #0x10]\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "bge 68b\n" + "69:" // Height 2: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "sub x24, x24, #0x4\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "70:" // Height 2: Multiply loop: Main loop skip + "cbz x24, 73f\n" + "cbz x24, 73f\n" + "tbz x24, 
#1, 71f\n" + "ldr d0, [x23], #0x8\n" + "ldr d1, [x22], #0x8\n" + "tbz x24, #0, 72f\n" + "ld1 { v0.s }[2], [x23]\n" + "ld1 { v1.s }[2], [x22]\n" + "b 72f\n" + "71:" // Height 2: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x23, #0x0]\n" + "ldr s1, [x22, #0x0]\n" + "72:" // Height 2: Multiply loop: Ragged operand read: Done + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "ldr q4, [x9, #0x0]\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "ldr q6, [x28, #0x0]\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "ldr q4, [x27, #0x0]\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "ldr q6, [x26, #0x0]\n" + "ldr q7, [x26, #0x10]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "add x11, x11, #0x20\n" + "add x10, x10, #0x20\n" + "add x9, x9, #0x20\n" + "add x28, x28, #0x20\n" + "add x27, x27, #0x20\n" + "add x26, x26, #0x20\n" + "73:" // Height 2: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 65b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "add x22, x12, x19, LSL #2\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "tbz %x[flags], #1, 74f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v4.4s, v4.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmax v4.4s, v4.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "74:" // Height 2: No activation + "cmp x13, #0x18\n" + "bge 87f\n" + "tbz x13, #4, 78f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v16.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v10.4s }, [x22], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "tbz x13, #2, 76f\n" + 
"st1 { v17.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x22], #0x10\n" + "tbz x13, #1, 75f\n" + "str d18, [x12], #0x8\n" + "str d13, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v18.s }[2], [x12]\n" + "st1 { v13.s }[2], [x22]\n" + "b 86f\n" + "75:" // Height 2: Partial direct writeback: partial_1_20 + "tbz x13, #0, 86f\n" + "str s18, [x12, #0x0]\n" + "str s13, [x22, #0x0]\n" + "b 86f\n" + "76:" // Height 2: Partial direct writeback: partial_2_16 + "tbz x13, #1, 77f\n" + "str d17, [x12], #0x8\n" + "str d12, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v17.s }[2], [x12]\n" + "st1 { v12.s }[2], [x22]\n" + "b 86f\n" + "77:" // Height 2: Partial direct writeback: partial_1_16 + "tbz x13, #0, 86f\n" + "str s17, [x12, #0x0]\n" + "str s12, [x22, #0x0]\n" + "b 86f\n" + "78:" // Height 2: Partial direct writeback: partial_8_0 + "tbz x13, #3, 82f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "tbz x13, #2, 80f\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x22], #0x10\n" + "tbz x13, #1, 79f\n" + "str d16, [x12], #0x8\n" + "str d11, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v16.s }[2], [x12]\n" + "st1 { v11.s }[2], [x22]\n" + "b 86f\n" + "79:" // Height 2: Partial direct writeback: partial_1_12 + "tbz x13, #0, 86f\n" + "str s16, [x12, #0x0]\n" + "str s11, [x22, #0x0]\n" + "b 86f\n" + "80:" // Height 2: Partial direct writeback: partial_2_8 + "tbz x13, #1, 81f\n" + "str d15, [x12], #0x8\n" + "str d10, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v15.s }[2], [x12]\n" + "st1 { v10.s }[2], [x22]\n" + "b 86f\n" + "81:" // Height 2: Partial direct writeback: partial_1_8 + "tbz x13, #0, 86f\n" + "str s15, [x12, #0x0]\n" + "str s10, [x22, #0x0]\n" + "b 86f\n" + "82:" // Height 2: Partial direct writeback: partial_4_0 + "tbz x13, #2, 84f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "tbz x13, #1, 83f\n" + "str d14, [x12], #0x8\n" + "str d9, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v9.s }[2], [x22]\n" + "b 86f\n" + "83:" // Height 2: Partial direct writeback: partial_1_4 + "tbz x13, #0, 86f\n" + "str s14, [x12, #0x0]\n" + "str s9, [x22, #0x0]\n" + "b 86f\n" + "84:" // Height 2: Partial direct writeback: partial_2_0 + "tbz x13, #1, 85f\n" + "str d4, [x12], #0x8\n" + "str d8, [x22], #0x8\n" + "tbz x13, #0, 86f\n" + "st1 { v4.s }[2], [x12]\n" + "st1 { v8.s }[2], [x22]\n" + "b 86f\n" + "85:" // Height 2: Partial direct writeback: partial_1_0 + "str s4, [x12, #0x0]\n" + "str s8, [x22, #0x0]\n" + "86:" // Height 2: Partial direct writeback: Done + "b 88f\n" + "87:" // Height 2: Full writeback + "str q4, [x12, #0x0]\n" + "str q14, [x12, #0x10]\n" + "str q15, [x12, #0x20]\n" + "str q16, [x12, #0x30]\n" + "str q17, [x12, #0x40]\n" + "str q18, [x12, #0x50]\n" + "add x12, x12, #0x60\n" + "str q8, [x22, #0x0]\n" + "str q9, [x22, #0x10]\n" + "str q10, [x22, #0x20]\n" + "str q11, [x22, #0x30]\n" + "str q12, [x22, #0x40]\n" + "str q13, [x22, #0x50]\n" + "88:" // Height 2: Writeback done + "subs x13, x13, #0x18\n" + "bgt 46b\n" + "b 178f\n" + "89:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "90:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL 
#1\n" + "add x28, x9, x19, LSL #1\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x14\n" + "bgt 91f\n" + "cmp x13, #0x10\n" + "mov x26, x11\n" + "bgt 91f\n" + "cmp x13, #0xc\n" + "mov x27, x11\n" + "bgt 91f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 91f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 91f\n" + "mov x10, x11\n" + "91:" // Height 3: B setup done + "cbz x14, 92f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "ldr q12, [x14, #0x40]\n" + "ldr q13, [x14, #0x50]\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "add x14, x14, #0x60\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "mov v20.16b, v8.16b\n" + "mov v26.16b, v14.16b\n" + "mov v21.16b, v9.16b\n" + "mov v27.16b, v15.16b\n" + "mov v22.16b, v10.16b\n" + "mov v28.16b, v16.16b\n" + "mov v23.16b, v11.16b\n" + "mov v29.16b, v17.16b\n" + "mov v24.16b, v12.16b\n" + "mov v30.16b, v18.16b\n" + "mov v25.16b, v13.16b\n" + "mov v31.16b, v19.16b\n" + "b 108f\n" + "92:" // Height 3: no bias + "tbz %x[flags], #0, 107f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "cmp x13, #0x18\n" + "add x21, x22, x19, LSL #2\n" + "bge 105f\n" + "tbz x13, #4, 96f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "ld1 { v22.4s }, [x21], #0x10\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "ld1 { v23.4s }, [x21], #0x10\n" + "ld1 { v12.4s }, [x12], #0x10\n" + "ld1 { v17.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "tbz x13, #2, 94f\n" + "ld1 { v13.4s }, [x12], #0x10\n" + "ld1 { v18.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "tbz x13, #1, 93f\n" + "ldr d20, [x12], #0x8\n" + "ldr d19, [x22], #0x8\n" + "mov x19, #0x58\n" + "ldr d4, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v20.s }[2], [x12]\n" + "ld1 { v19.s }[2], [x22]\n" + "ld1 { v4.s }[2], [x21]\n" + "b 104f\n" + "93:" // Height 3: Partial accumulate: partial_1_20 + "mov x19, #0x50\n" + "tbz x13, #0, 104f\n" + "ldr s20, [x12, #0x0]\n" + "ldr s19, [x22, #0x0]\n" + "ldr s4, [x21, #0x0]\n" + "b 104f\n" + "94:" // Height 3: Partial accumulate: partial_2_16 + "tbz x13, #1, 95f\n" + "ldr d13, [x12], #0x8\n" + "ldr d18, [x22], #0x8\n" + "mov x19, #0x48\n" + "ldr d25, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v13.s }[2], [x12]\n" + "ld1 { v18.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "b 104f\n" + "95:" // Height 3: Partial accumulate: partial_1_16 + "mov x19, #0x40\n" + "tbz x13, #0, 104f\n" + "ldr s13, [x12, #0x0]\n" + "ldr s18, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "b 104f\n" + "96:" // Height 3: Partial accumulate: partial_8_0 + "tbz x13, #3, 100f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "ld1 { v22.4s }, [x21], #0x10\n" + "tbz x13, #2, 98f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "ld1 { v23.4s }, 
[x21], #0x10\n" + "tbz x13, #1, 97f\n" + "ldr d12, [x12], #0x8\n" + "ldr d17, [x22], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v12.s }[2], [x12]\n" + "ld1 { v17.s }[2], [x22]\n" + "ld1 { v24.s }[2], [x21]\n" + "b 104f\n" + "97:" // Height 3: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 104f\n" + "ldr s12, [x12, #0x0]\n" + "ldr s17, [x22, #0x0]\n" + "ldr s24, [x21, #0x0]\n" + "b 104f\n" + "98:" // Height 3: Partial accumulate: partial_2_8 + "tbz x13, #1, 99f\n" + "ldr d11, [x12], #0x8\n" + "ldr d16, [x22], #0x8\n" + "mov x19, #0x28\n" + "ldr d23, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v23.s }[2], [x21]\n" + "b 104f\n" + "99:" // Height 3: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 104f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s23, [x21, #0x0]\n" + "b 104f\n" + "100:" // Height 3: Partial accumulate: partial_4_0 + "tbz x13, #2, 102f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "tbz x13, #1, 101f\n" + "ldr d10, [x12], #0x8\n" + "ldr d15, [x22], #0x8\n" + "mov x19, #0x18\n" + "ldr d22, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x22]\n" + "ld1 { v22.s }[2], [x21]\n" + "b 104f\n" + "101:" // Height 3: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 104f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s15, [x22, #0x0]\n" + "ldr s22, [x21, #0x0]\n" + "b 104f\n" + "102:" // Height 3: Partial accumulate: partial_2_0 + "tbz x13, #1, 103f\n" + "ldr d9, [x12], #0x8\n" + "ldr d14, [x22], #0x8\n" + "mov x19, #0x8\n" + "ldr d21, [x21], #0x8\n" + "tbz x13, #0, 104f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x22]\n" + "ld1 { v21.s }[2], [x21]\n" + "b 104f\n" + "103:" // Height 3: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s14, [x22, #0x0]\n" + "mov x19, #0x0\n" + "ldr s21, [x21, #0x0]\n" + "104:" // Height 3: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 106f\n" + "105:" // Height 3: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q12, [x12, #0x30]\n" + "ldr q13, [x12, #0x40]\n" + "ldr q20, [x12, #0x50]\n" + "ldr q14, [x22, #0x0]\n" + "ldr q15, [x22, #0x10]\n" + "ldr q16, [x22, #0x20]\n" + "ldr q17, [x22, #0x30]\n" + "ldr q18, [x22, #0x40]\n" + "ldr q19, [x22, #0x50]\n" + "ldr q21, [x21, #0x0]\n" + "ldr q22, [x21, #0x10]\n" + "ldr q23, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "ldr q25, [x21, #0x40]\n" + "ldr q4, [x21, #0x50]\n" + "106:" // Height 3: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "zip1 v20.2d, v21.2d, v26.2d\n" + "zip2 v26.2d, v21.2d, v26.2d\n" + "zip1 v21.2d, v22.2d, v27.2d\n" + "zip2 v27.2d, v22.2d, v27.2d\n" + "zip1 v22.2d, v23.2d, v28.2d\n" + "zip2 v28.2d, v23.2d, v28.2d\n" + "zip1 v23.2d, v24.2d, v29.2d\n" + "zip2 v29.2d, v24.2d, v29.2d\n" + "zip1 v24.2d, v25.2d, v30.2d\n" + "zip2 v30.2d, v25.2d, v30.2d\n" + "zip1 v25.2d, v4.2d, v31.2d\n" + "zip2 v31.2d, v4.2d, v31.2d\n" + "b 108f\n" + "107:" // Height 3: no 
accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "108:" // Height 3: setup done + "mov x25, #0x0\n" + "109:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 110f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "ldr x21, [x20, #0x10]\n" + "cbnz x25, 111f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "b 111f\n" + "110:" // Height 3: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "111:" // Height 3: input setup done + "cmp x24, #0x4\n" + "blt 114f\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "ld1 { v1.4s }, [x22], #0x10\n" + "cmp x24, #0x8\n" + "ld1 { v2.4s }, [x21], #0x10\n" + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "blt 113f\n" + "112:" // Height 3: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x22], #0x10\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "sub x24, x24, #0x4\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "cmp x24, #0x8\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "add x27, x27, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + "ldr q4, [x11, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + 
"ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x23], #0x10\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x21], #0x10\n" + "ldr q7, [x10, #0x10]\n" + "bge 112b\n" + "113:" // Height 3: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "sub x24, x24, #0x4\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x27, x27, #0x20\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "114:" // Height 3: Multiply loop: Main loop skip + "cbz x24, 117f\n" + "cbz x24, 117f\n" + "tbz x24, #1, 115f\n" + "ldr d0, [x23], #0x8\n" + "ldr d1, [x22], #0x8\n" + "ldr d2, [x21], #0x8\n" + "tbz x24, #0, 116f\n" + "ld1 { v0.s }[2], [x23]\n" + "ld1 { v1.s }[2], [x22]\n" + "ld1 { v2.s }[2], [x21]\n" + "b 116f\n" + "115:" // Height 3: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x23, #0x0]\n" + "ldr s1, [x22, #0x0]\n" + "ldr s2, [x21, #0x0]\n" + "116:" // Height 3: Multiply loop: Ragged operand read: Done + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + "add x11, x11, #0x20\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, 
v4.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x27, x27, #0x20\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "117:" // Height 3: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 109b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "add x21, x22, x19, LSL #2\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "uzp1 v20.2d, v20.2d, v26.2d\n" + "uzp1 v21.2d, v21.2d, v27.2d\n" + "uzp1 v22.2d, v22.2d, v28.2d\n" + "uzp1 v23.2d, v23.2d, v29.2d\n" + "uzp1 v24.2d, v24.2d, v30.2d\n" + "uzp1 v25.2d, v25.2d, v31.2d\n" + "tbz %x[flags], #1, 118f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v4.4s, v4.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmax v4.4s, v4.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "118:" // Height 3: No activation + "cmp x13, #0x18\n" + "bge 131f\n" + "tbz x13, #4, 122f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v16.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + 
"st1 { v10.4s }, [x22], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v20.4s }, [x21], #0x10\n" + "st1 { v21.4s }, [x21], #0x10\n" + "st1 { v22.4s }, [x21], #0x10\n" + "st1 { v23.4s }, [x21], #0x10\n" + "tbz x13, #2, 120f\n" + "st1 { v17.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x22], #0x10\n" + "st1 { v24.4s }, [x21], #0x10\n" + "tbz x13, #1, 119f\n" + "str d18, [x12], #0x8\n" + "str d13, [x22], #0x8\n" + "str d25, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v18.s }[2], [x12]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v25.s }[2], [x21]\n" + "b 130f\n" + "119:" // Height 3: Partial direct writeback: partial_1_20 + "tbz x13, #0, 130f\n" + "str s18, [x12, #0x0]\n" + "str s13, [x22, #0x0]\n" + "str s25, [x21, #0x0]\n" + "b 130f\n" + "120:" // Height 3: Partial direct writeback: partial_2_16 + "tbz x13, #1, 121f\n" + "str d17, [x12], #0x8\n" + "str d12, [x22], #0x8\n" + "str d24, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v17.s }[2], [x12]\n" + "st1 { v12.s }[2], [x22]\n" + "st1 { v24.s }[2], [x21]\n" + "b 130f\n" + "121:" // Height 3: Partial direct writeback: partial_1_16 + "tbz x13, #0, 130f\n" + "str s17, [x12, #0x0]\n" + "str s12, [x22, #0x0]\n" + "str s24, [x21, #0x0]\n" + "b 130f\n" + "122:" // Height 3: Partial direct writeback: partial_8_0 + "tbz x13, #3, 126f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v20.4s }, [x21], #0x10\n" + "st1 { v21.4s }, [x21], #0x10\n" + "tbz x13, #2, 124f\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x22], #0x10\n" + "st1 { v22.4s }, [x21], #0x10\n" + "tbz x13, #1, 123f\n" + "str d16, [x12], #0x8\n" + "str d11, [x22], #0x8\n" + "str d23, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v16.s }[2], [x12]\n" + "st1 { v11.s }[2], [x22]\n" + "st1 { v23.s }[2], [x21]\n" + "b 130f\n" + "123:" // Height 3: Partial direct writeback: partial_1_12 + "tbz x13, #0, 130f\n" + "str s16, [x12, #0x0]\n" + "str s11, [x22, #0x0]\n" + "str s23, [x21, #0x0]\n" + "b 130f\n" + "124:" // Height 3: Partial direct writeback: partial_2_8 + "tbz x13, #1, 125f\n" + "str d15, [x12], #0x8\n" + "str d10, [x22], #0x8\n" + "str d22, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v15.s }[2], [x12]\n" + "st1 { v10.s }[2], [x22]\n" + "st1 { v22.s }[2], [x21]\n" + "b 130f\n" + "125:" // Height 3: Partial direct writeback: partial_1_8 + "tbz x13, #0, 130f\n" + "str s15, [x12, #0x0]\n" + "str s10, [x22, #0x0]\n" + "str s22, [x21, #0x0]\n" + "b 130f\n" + "126:" // Height 3: Partial direct writeback: partial_4_0 + "tbz x13, #2, 128f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v20.4s }, [x21], #0x10\n" + "tbz x13, #1, 127f\n" + "str d14, [x12], #0x8\n" + "str d9, [x22], #0x8\n" + "str d21, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v21.s }[2], [x21]\n" + "b 130f\n" + "127:" // Height 3: Partial direct writeback: partial_1_4 + "tbz x13, #0, 130f\n" + "str s14, [x12, #0x0]\n" + "str s9, [x22, #0x0]\n" + "str s21, [x21, #0x0]\n" + "b 130f\n" + "128:" // Height 3: Partial direct writeback: partial_2_0 + "tbz x13, #1, 129f\n" + "str d4, [x12], #0x8\n" + "str d8, [x22], #0x8\n" + "str d20, [x21], #0x8\n" + "tbz x13, #0, 130f\n" + "st1 { v4.s }[2], [x12]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v20.s }[2], [x21]\n" + "b 130f\n" + "129:" // Height 3: Partial direct writeback: partial_1_0 + "str s4, [x12, #0x0]\n" + "str s8, [x22, #0x0]\n" + "str s20, [x21, #0x0]\n" + "130:" // 
Height 3: Partial direct writeback: Done + "b 132f\n" + "131:" // Height 3: Full writeback + "str q4, [x12, #0x0]\n" + "str q14, [x12, #0x10]\n" + "str q15, [x12, #0x20]\n" + "str q16, [x12, #0x30]\n" + "str q17, [x12, #0x40]\n" + "str q18, [x12, #0x50]\n" + "add x12, x12, #0x60\n" + "str q8, [x22, #0x0]\n" + "str q9, [x22, #0x10]\n" + "str q10, [x22, #0x20]\n" + "str q11, [x22, #0x30]\n" + "str q12, [x22, #0x40]\n" + "str q13, [x22, #0x50]\n" + "str q20, [x21, #0x0]\n" + "str q21, [x21, #0x10]\n" + "str q22, [x21, #0x20]\n" + "str q23, [x21, #0x30]\n" + "str q24, [x21, #0x40]\n" + "str q25, [x21, #0x50]\n" + "132:" // Height 3: Writeback done + "subs x13, x13, #0x18\n" + "bgt 90b\n" + "b 178f\n" + "133:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x10\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "134:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x13, #0x14\n" + "bgt 135f\n" + "cmp x13, #0x10\n" + "mov x26, x11\n" + "bgt 135f\n" + "cmp x13, #0xc\n" + "mov x27, x11\n" + "bgt 135f\n" + "cmp x13, #0x8\n" + "mov x28, x11\n" + "bgt 135f\n" + "cmp x13, #0x4\n" + "mov x9, x11\n" + "bgt 135f\n" + "mov x10, x11\n" + "135:" // Height 4: B setup done + "cbz x14, 136f\n" + "ldr q8, [x14, #0x0]\n" + "ldr q9, [x14, #0x10]\n" + "zip2 v14.2d, v8.2d, v8.2d\n" + "zip1 v8.2d, v8.2d, v8.2d\n" + "ldr q10, [x14, #0x20]\n" + "ldr q11, [x14, #0x30]\n" + "zip2 v15.2d, v9.2d, v9.2d\n" + "zip1 v9.2d, v9.2d, v9.2d\n" + "ldr q12, [x14, #0x40]\n" + "ldr q13, [x14, #0x50]\n" + "zip2 v16.2d, v10.2d, v10.2d\n" + "zip1 v10.2d, v10.2d, v10.2d\n" + "zip2 v17.2d, v11.2d, v11.2d\n" + "zip1 v11.2d, v11.2d, v11.2d\n" + "add x14, x14, #0x60\n" + "zip2 v18.2d, v12.2d, v12.2d\n" + "zip1 v12.2d, v12.2d, v12.2d\n" + "zip2 v19.2d, v13.2d, v13.2d\n" + "zip1 v13.2d, v13.2d, v13.2d\n" + "mov v20.16b, v8.16b\n" + "mov v26.16b, v14.16b\n" + "mov v21.16b, v9.16b\n" + "mov v27.16b, v15.16b\n" + "mov v22.16b, v10.16b\n" + "mov v28.16b, v16.16b\n" + "mov v23.16b, v11.16b\n" + "mov v29.16b, v17.16b\n" + "mov v24.16b, v12.16b\n" + "mov v30.16b, v18.16b\n" + "mov v25.16b, v13.16b\n" + "mov v31.16b, v19.16b\n" + "b 152f\n" + "136:" // Height 4: no bias + "tbz %x[flags], #0, 151f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "cmp x13, #0x18\n" + "add x20, x21, x19, LSL #2\n" + "bge 149f\n" + "tbz x13, #4, 140f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "ld1 { v26.4s }, [x20], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "ld1 { v22.4s }, [x21], #0x10\n" + "ld1 { v27.4s }, [x20], #0x10\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "ld1 { v23.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "ld1 { v12.4s }, [x12], #0x10\n" + "ld1 { v17.4s }, [x22], #0x10\n" + "ld1 { v24.4s }, [x21], #0x10\n" + "ld1 { v29.4s }, [x20], #0x10\n" + "tbz x13, #2, 138f\n" + "ld1 { 
v13.4s }, [x12], #0x10\n" + "ld1 { v18.4s }, [x22], #0x10\n" + "ld1 { v25.4s }, [x21], #0x10\n" + "ld1 { v30.4s }, [x20], #0x10\n" + "tbz x13, #1, 137f\n" + "ldr d20, [x12], #0x8\n" + "ldr d19, [x22], #0x8\n" + "mov x19, #0x58\n" + "ldr d4, [x21], #0x8\n" + "ldr d31, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v20.s }[2], [x12]\n" + "ld1 { v19.s }[2], [x22]\n" + "ld1 { v4.s }[2], [x21]\n" + "ld1 { v31.s }[2], [x20]\n" + "b 148f\n" + "137:" // Height 4: Partial accumulate: partial_1_20 + "mov x19, #0x50\n" + "tbz x13, #0, 148f\n" + "ldr s20, [x12, #0x0]\n" + "ldr s19, [x22, #0x0]\n" + "ldr s4, [x21, #0x0]\n" + "ldr s31, [x20, #0x0]\n" + "b 148f\n" + "138:" // Height 4: Partial accumulate: partial_2_16 + "tbz x13, #1, 139f\n" + "ldr d13, [x12], #0x8\n" + "ldr d18, [x22], #0x8\n" + "mov x19, #0x48\n" + "ldr d25, [x21], #0x8\n" + "ldr d30, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v13.s }[2], [x12]\n" + "ld1 { v18.s }[2], [x22]\n" + "ld1 { v25.s }[2], [x21]\n" + "ld1 { v30.s }[2], [x20]\n" + "b 148f\n" + "139:" // Height 4: Partial accumulate: partial_1_16 + "mov x19, #0x40\n" + "tbz x13, #0, 148f\n" + "ldr s13, [x12, #0x0]\n" + "ldr s18, [x22, #0x0]\n" + "ldr s25, [x21, #0x0]\n" + "ldr s30, [x20, #0x0]\n" + "b 148f\n" + "140:" // Height 4: Partial accumulate: partial_8_0 + "tbz x13, #3, 144f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "ld1 { v26.4s }, [x20], #0x10\n" + "ld1 { v10.4s }, [x12], #0x10\n" + "ld1 { v15.4s }, [x22], #0x10\n" + "ld1 { v22.4s }, [x21], #0x10\n" + "ld1 { v27.4s }, [x20], #0x10\n" + "tbz x13, #2, 142f\n" + "ld1 { v11.4s }, [x12], #0x10\n" + "ld1 { v16.4s }, [x22], #0x10\n" + "ld1 { v23.4s }, [x21], #0x10\n" + "ld1 { v28.4s }, [x20], #0x10\n" + "tbz x13, #1, 141f\n" + "ldr d12, [x12], #0x8\n" + "ldr d17, [x22], #0x8\n" + "mov x19, #0x38\n" + "ldr d24, [x21], #0x8\n" + "ldr d29, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v12.s }[2], [x12]\n" + "ld1 { v17.s }[2], [x22]\n" + "ld1 { v24.s }[2], [x21]\n" + "ld1 { v29.s }[2], [x20]\n" + "b 148f\n" + "141:" // Height 4: Partial accumulate: partial_1_12 + "mov x19, #0x30\n" + "tbz x13, #0, 148f\n" + "ldr s12, [x12, #0x0]\n" + "ldr s17, [x22, #0x0]\n" + "ldr s24, [x21, #0x0]\n" + "ldr s29, [x20, #0x0]\n" + "b 148f\n" + "142:" // Height 4: Partial accumulate: partial_2_8 + "tbz x13, #1, 143f\n" + "ldr d11, [x12], #0x8\n" + "ldr d16, [x22], #0x8\n" + "mov x19, #0x28\n" + "ldr d23, [x21], #0x8\n" + "ldr d28, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v11.s }[2], [x12]\n" + "ld1 { v16.s }[2], [x22]\n" + "ld1 { v23.s }[2], [x21]\n" + "ld1 { v28.s }[2], [x20]\n" + "b 148f\n" + "143:" // Height 4: Partial accumulate: partial_1_8 + "mov x19, #0x20\n" + "tbz x13, #0, 148f\n" + "ldr s11, [x12, #0x0]\n" + "ldr s16, [x22, #0x0]\n" + "ldr s23, [x21, #0x0]\n" + "ldr s28, [x20, #0x0]\n" + "b 148f\n" + "144:" // Height 4: Partial accumulate: partial_4_0 + "tbz x13, #2, 146f\n" + "ld1 { v9.4s }, [x12], #0x10\n" + "ld1 { v14.4s }, [x22], #0x10\n" + "ld1 { v21.4s }, [x21], #0x10\n" + "ld1 { v26.4s }, [x20], #0x10\n" + "tbz x13, #1, 145f\n" + "ldr d10, [x12], #0x8\n" + "ldr d15, [x22], #0x8\n" + "mov x19, #0x18\n" + "ldr d22, [x21], #0x8\n" + "ldr d27, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v10.s }[2], [x12]\n" + "ld1 { v15.s }[2], [x22]\n" + "ld1 { v22.s }[2], [x21]\n" + "ld1 { v27.s }[2], [x20]\n" + "b 148f\n" + "145:" // Height 4: Partial accumulate: partial_1_4 + "mov x19, #0x10\n" + "tbz x13, #0, 148f\n" + "ldr s10, [x12, #0x0]\n" + "ldr s15, [x22, 
#0x0]\n" + "ldr s22, [x21, #0x0]\n" + "ldr s27, [x20, #0x0]\n" + "b 148f\n" + "146:" // Height 4: Partial accumulate: partial_2_0 + "tbz x13, #1, 147f\n" + "ldr d9, [x12], #0x8\n" + "ldr d14, [x22], #0x8\n" + "mov x19, #0x8\n" + "ldr d21, [x21], #0x8\n" + "ldr d26, [x20], #0x8\n" + "tbz x13, #0, 148f\n" + "ld1 { v9.s }[2], [x12]\n" + "ld1 { v14.s }[2], [x22]\n" + "ld1 { v21.s }[2], [x21]\n" + "ld1 { v26.s }[2], [x20]\n" + "b 148f\n" + "147:" // Height 4: Partial accumulate: partial_1_0 + "ldr s9, [x12, #0x0]\n" + "ldr s14, [x22, #0x0]\n" + "mov x19, #0x0\n" + "ldr s21, [x21, #0x0]\n" + "ldr s26, [x20, #0x0]\n" + "148:" // Height 4: Partial accumulate: Done + "sub x12, x12, x19\n" + "b 150f\n" + "149:" // Height 4: full accumulate + "ldr q9, [x12, #0x0]\n" + "ldr q10, [x12, #0x10]\n" + "ldr q11, [x12, #0x20]\n" + "ldr q12, [x12, #0x30]\n" + "ldr q13, [x12, #0x40]\n" + "ldr q20, [x12, #0x50]\n" + "ldr q14, [x22, #0x0]\n" + "ldr q15, [x22, #0x10]\n" + "ldr q16, [x22, #0x20]\n" + "ldr q17, [x22, #0x30]\n" + "ldr q18, [x22, #0x40]\n" + "ldr q19, [x22, #0x50]\n" + "ldr q21, [x21, #0x0]\n" + "ldr q22, [x21, #0x10]\n" + "ldr q23, [x21, #0x20]\n" + "ldr q24, [x21, #0x30]\n" + "ldr q25, [x21, #0x40]\n" + "ldr q4, [x21, #0x50]\n" + "ldr q26, [x20, #0x0]\n" + "ldr q27, [x20, #0x10]\n" + "ldr q28, [x20, #0x20]\n" + "ldr q29, [x20, #0x30]\n" + "ldr q30, [x20, #0x40]\n" + "ldr q31, [x20, #0x50]\n" + "150:" // Height 4: MMLA fixup + "zip1 v8.2d, v9.2d, v14.2d\n" + "zip2 v14.2d, v9.2d, v14.2d\n" + "zip1 v9.2d, v10.2d, v15.2d\n" + "zip2 v15.2d, v10.2d, v15.2d\n" + "zip1 v10.2d, v11.2d, v16.2d\n" + "zip2 v16.2d, v11.2d, v16.2d\n" + "zip1 v11.2d, v12.2d, v17.2d\n" + "zip2 v17.2d, v12.2d, v17.2d\n" + "zip1 v12.2d, v13.2d, v18.2d\n" + "zip2 v18.2d, v13.2d, v18.2d\n" + "zip1 v13.2d, v20.2d, v19.2d\n" + "zip2 v19.2d, v20.2d, v19.2d\n" + "zip1 v20.2d, v21.2d, v26.2d\n" + "zip2 v26.2d, v21.2d, v26.2d\n" + "zip1 v21.2d, v22.2d, v27.2d\n" + "zip2 v27.2d, v22.2d, v27.2d\n" + "zip1 v22.2d, v23.2d, v28.2d\n" + "zip2 v28.2d, v23.2d, v28.2d\n" + "zip1 v23.2d, v24.2d, v29.2d\n" + "zip2 v29.2d, v24.2d, v29.2d\n" + "zip1 v24.2d, v25.2d, v30.2d\n" + "zip2 v30.2d, v25.2d, v30.2d\n" + "zip1 v25.2d, v4.2d, v31.2d\n" + "zip2 v31.2d, v4.2d, v31.2d\n" + "b 152f\n" + "151:" // Height 4: no accumulate + "movi v8.16b, #0x0\n" + "movi v9.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "152:" // Height 4: setup done + "mov x25, #0x0\n" + "153:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 154f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "ldr x21, [x20, #0x10]\n" + "ldr x20, [x20, #0x18]\n" + "cbnz x25, 155f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "add x20, x20, x19, LSL 
#2\n" + "b 155f\n" + "154:" // Height 4: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "155:" // Height 4: input setup done + "cmp x24, #0x4\n" + "blt 158f\n" + "ld1 { v0.4s }, [x23], #0x10\n" + "ld1 { v2.4s }, [x21], #0x10\n" + "cmp x24, #0x8\n" + "ld1 { v1.4s }, [x22], #0x10\n" + "ld1 { v3.4s }, [x20], #0x10\n" + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + "blt 157f\n" + "156:" // Height 4: Multiply loop: Main loop head + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x8\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + "ld1 { v1.4s }, [x22], #0x10\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + "ld1 { v3.4s }, [x20], #0x10\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + "ldr q4, [x11, #0x0]\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + "ldr q6, [x10, #0x0]\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + "ld1 { v0.4s }, [x23], #0x10\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "ld1 { v2.4s }, [x21], #0x10\n" + "ldr q7, [x10, #0x10]\n" + "bge 156b\n" + "157:" // Height 4: Multiply loop: Single iteration only + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "sub x24, x24, #0x4\n" + "add x11, x11, #0x20\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "add x10, x10, #0x20\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + "add x9, x9, #0x20\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5b 
// bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + "add x28, x28, #0x20\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + "add x27, x27, #0x20\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "158:" // Height 4: Multiply loop: Main loop skip + "cbz x24, 161f\n" + "cbz x24, 161f\n" + "tbz x24, #1, 159f\n" + "ldr d0, [x23], #0x8\n" + "ldr d1, [x22], #0x8\n" + "ldr d2, [x21], #0x8\n" + "ldr d3, [x20], #0x8\n" + "tbz x24, #0, 160f\n" + "ld1 { v0.s }[2], [x23]\n" + "ld1 { v1.s }[2], [x22]\n" + "ld1 { v2.s }[2], [x21]\n" + "ld1 { v3.s }[2], [x20]\n" + "b 160f\n" + "159:" // Height 4: Multiply loop: Ragged operand read: partial_1_0 + "ldr s0, [x23, #0x0]\n" + "ldr s1, [x22, #0x0]\n" + "ldr s2, [x21, #0x0]\n" + "ldr s3, [x20, #0x0]\n" + "160:" // Height 4: Multiply loop: Ragged operand read: Done + "ldr q4, [x11, #0x0]\n" + "ldr q5, [x11, #0x10]\n" + ".inst 0x0ea16800 // bfcvtn v0.4h, v0.4s\n" + ".inst 0x0ea16842 // bfcvtn v2.4h, v2.4s\n" + "ldr q6, [x10, #0x0]\n" + "ldr q7, [x10, #0x10]\n" + ".inst 0x4ea16820 // bfcvtn2 v0.8h, v1.4s\n" + ".inst 0x4ea16862 // bfcvtn2 v2.8h, v3.4s\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "ldr q4, [x9, #0x0]\n" + "add x11, x11, #0x20\n" + ".inst 0x6e45ec0e // bfmmla v14.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5a // bfmmla v26.4s, v2.8h, v5.8h\n" + "ldr q5, [x9, #0x10]\n" + "add x10, x10, #0x20\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "ldr q6, [x28, #0x0]\n" + "add x9, x9, #0x20\n" + ".inst 0x6e47ec0f // bfmmla v15.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5b // bfmmla v27.4s, v2.8h, v7.8h\n" + "ldr q7, [x28, #0x10]\n" + "add x28, x28, #0x20\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + "ldr q4, [x27, #0x0]\n" + ".inst 0x6e45ec10 // bfmmla v16.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5c // bfmmla v28.4s, v2.8h, v5.8h\n" + "ldr q5, [x27, #0x10]\n" + "add x27, x27, #0x20\n" + ".inst 0x6e46ec0b // bfmmla v11.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec57 // bfmmla v23.4s, v2.8h, v6.8h\n" + "ldr q6, [x26, #0x0]\n" + ".inst 0x6e47ec11 // bfmmla v17.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5d // bfmmla v29.4s, v2.8h, v7.8h\n" + "ldr q7, [x26, #0x10]\n" + "add x26, x26, #0x20\n" + ".inst 0x6e44ec0c // bfmmla v12.4s, v0.8h, v4.8h\n" + ".inst 0x6e44ec58 // bfmmla v24.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec12 // bfmmla v18.4s, v0.8h, v5.8h\n" + ".inst 0x6e45ec5e // bfmmla v30.4s, v2.8h, v5.8h\n" + ".inst 0x6e46ec0d // bfmmla v13.4s, v0.8h, v6.8h\n" + ".inst 0x6e46ec59 // bfmmla v25.4s, v2.8h, 
v6.8h\n" + ".inst 0x6e47ec13 // bfmmla v19.4s, v0.8h, v7.8h\n" + ".inst 0x6e47ec5f // bfmmla v31.4s, v2.8h, v7.8h\n" + "161:" // Height 4: Multiply loop: No odd multiplies + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 153b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "uzp1 v4.2d, v8.2d, v14.2d\n" + "uzp2 v8.2d, v8.2d, v14.2d\n" + "uzp1 v14.2d, v9.2d, v15.2d\n" + "add x20, x21, x19, LSL #2\n" + "uzp2 v9.2d, v9.2d, v15.2d\n" + "uzp1 v15.2d, v10.2d, v16.2d\n" + "uzp2 v10.2d, v10.2d, v16.2d\n" + "uzp1 v16.2d, v11.2d, v17.2d\n" + "uzp2 v11.2d, v11.2d, v17.2d\n" + "uzp1 v17.2d, v12.2d, v18.2d\n" + "uzp2 v12.2d, v12.2d, v18.2d\n" + "uzp1 v18.2d, v13.2d, v19.2d\n" + "uzp2 v13.2d, v13.2d, v19.2d\n" + "uzp1 v19.2d, v20.2d, v26.2d\n" + "uzp2 v20.2d, v20.2d, v26.2d\n" + "uzp1 v26.2d, v21.2d, v27.2d\n" + "uzp2 v21.2d, v21.2d, v27.2d\n" + "uzp1 v27.2d, v22.2d, v28.2d\n" + "uzp2 v22.2d, v22.2d, v28.2d\n" + "uzp1 v28.2d, v23.2d, v29.2d\n" + "uzp2 v23.2d, v23.2d, v29.2d\n" + "uzp1 v29.2d, v24.2d, v30.2d\n" + "uzp2 v24.2d, v24.2d, v30.2d\n" + "uzp1 v30.2d, v25.2d, v31.2d\n" + "uzp2 v25.2d, v25.2d, v31.2d\n" + "tbz %x[flags], #1, 162f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1r { v1.4s }, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1r { v0.4s }, [x19]\n" + "fmin v4.4s, v4.4s, v1.4s\n" + "fmin v14.4s, v14.4s, v1.4s\n" + "fmin v15.4s, v15.4s, v1.4s\n" + "fmin v16.4s, v16.4s, v1.4s\n" + "fmin v17.4s, v17.4s, v1.4s\n" + "fmin v18.4s, v18.4s, v1.4s\n" + "fmin v8.4s, v8.4s, v1.4s\n" + "fmin v9.4s, v9.4s, v1.4s\n" + "fmin v10.4s, v10.4s, v1.4s\n" + "fmin v11.4s, v11.4s, v1.4s\n" + "fmin v12.4s, v12.4s, v1.4s\n" + "fmin v13.4s, v13.4s, v1.4s\n" + "fmin v19.4s, v19.4s, v1.4s\n" + "fmin v26.4s, v26.4s, v1.4s\n" + "fmin v27.4s, v27.4s, v1.4s\n" + "fmin v28.4s, v28.4s, v1.4s\n" + "fmin v29.4s, v29.4s, v1.4s\n" + "fmin v30.4s, v30.4s, v1.4s\n" + "fmin v20.4s, v20.4s, v1.4s\n" + "fmin v21.4s, v21.4s, v1.4s\n" + "fmin v22.4s, v22.4s, v1.4s\n" + "fmin v23.4s, v23.4s, v1.4s\n" + "fmin v24.4s, v24.4s, v1.4s\n" + "fmin v25.4s, v25.4s, v1.4s\n" + "fmax v4.4s, v4.4s, v0.4s\n" + "fmax v14.4s, v14.4s, v0.4s\n" + "fmax v15.4s, v15.4s, v0.4s\n" + "fmax v16.4s, v16.4s, v0.4s\n" + "fmax v17.4s, v17.4s, v0.4s\n" + "fmax v18.4s, v18.4s, v0.4s\n" + "fmax v8.4s, v8.4s, v0.4s\n" + "fmax v9.4s, v9.4s, v0.4s\n" + "fmax v10.4s, v10.4s, v0.4s\n" + "fmax v11.4s, v11.4s, v0.4s\n" + "fmax v12.4s, v12.4s, v0.4s\n" + "fmax v13.4s, v13.4s, v0.4s\n" + "fmax v19.4s, v19.4s, v0.4s\n" + "fmax v26.4s, v26.4s, v0.4s\n" + "fmax v27.4s, v27.4s, v0.4s\n" + "fmax v28.4s, v28.4s, v0.4s\n" + "fmax v29.4s, v29.4s, v0.4s\n" + "fmax v30.4s, v30.4s, v0.4s\n" + "fmax v20.4s, v20.4s, v0.4s\n" + "fmax v21.4s, v21.4s, v0.4s\n" + "fmax v22.4s, v22.4s, v0.4s\n" + "fmax v23.4s, v23.4s, v0.4s\n" + "fmax v24.4s, v24.4s, v0.4s\n" + "fmax v25.4s, v25.4s, v0.4s\n" + "162:" // Height 4: No activation + "cmp x13, #0x18\n" + "bge 175f\n" + "tbz x13, #4, 166f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v16.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v10.4s }, [x22], #0x10\n" + "st1 { v11.4s }, [x22], #0x10\n" + "st1 { v19.4s }, [x21], #0x10\n" + "st1 { v26.4s }, [x21], #0x10\n" + "st1 { v27.4s }, [x21], #0x10\n" + "st1 { v28.4s }, [x21], #0x10\n" + "st1 { v20.4s }, [x20], #0x10\n" + "st1 { 
v21.4s }, [x20], #0x10\n" + "st1 { v22.4s }, [x20], #0x10\n" + "st1 { v23.4s }, [x20], #0x10\n" + "tbz x13, #2, 164f\n" + "st1 { v17.4s }, [x12], #0x10\n" + "st1 { v12.4s }, [x22], #0x10\n" + "st1 { v29.4s }, [x21], #0x10\n" + "st1 { v24.4s }, [x20], #0x10\n" + "tbz x13, #1, 163f\n" + "str d18, [x12], #0x8\n" + "str d13, [x22], #0x8\n" + "str d30, [x21], #0x8\n" + "str d25, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v18.s }[2], [x12]\n" + "st1 { v13.s }[2], [x22]\n" + "st1 { v30.s }[2], [x21]\n" + "st1 { v25.s }[2], [x20]\n" + "b 174f\n" + "163:" // Height 4: Partial direct writeback: partial_1_20 + "tbz x13, #0, 174f\n" + "str s18, [x12, #0x0]\n" + "str s13, [x22, #0x0]\n" + "str s30, [x21, #0x0]\n" + "str s25, [x20, #0x0]\n" + "b 174f\n" + "164:" // Height 4: Partial direct writeback: partial_2_16 + "tbz x13, #1, 165f\n" + "str d17, [x12], #0x8\n" + "str d12, [x22], #0x8\n" + "str d29, [x21], #0x8\n" + "str d24, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v17.s }[2], [x12]\n" + "st1 { v12.s }[2], [x22]\n" + "st1 { v29.s }[2], [x21]\n" + "st1 { v24.s }[2], [x20]\n" + "b 174f\n" + "165:" // Height 4: Partial direct writeback: partial_1_16 + "tbz x13, #0, 174f\n" + "str s17, [x12, #0x0]\n" + "str s12, [x22, #0x0]\n" + "str s29, [x21, #0x0]\n" + "str s24, [x20, #0x0]\n" + "b 174f\n" + "166:" // Height 4: Partial direct writeback: partial_8_0 + "tbz x13, #3, 170f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v14.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v9.4s }, [x22], #0x10\n" + "st1 { v19.4s }, [x21], #0x10\n" + "st1 { v26.4s }, [x21], #0x10\n" + "st1 { v20.4s }, [x20], #0x10\n" + "st1 { v21.4s }, [x20], #0x10\n" + "tbz x13, #2, 168f\n" + "st1 { v15.4s }, [x12], #0x10\n" + "st1 { v10.4s }, [x22], #0x10\n" + "st1 { v27.4s }, [x21], #0x10\n" + "st1 { v22.4s }, [x20], #0x10\n" + "tbz x13, #1, 167f\n" + "str d16, [x12], #0x8\n" + "str d11, [x22], #0x8\n" + "str d28, [x21], #0x8\n" + "str d23, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v16.s }[2], [x12]\n" + "st1 { v11.s }[2], [x22]\n" + "st1 { v28.s }[2], [x21]\n" + "st1 { v23.s }[2], [x20]\n" + "b 174f\n" + "167:" // Height 4: Partial direct writeback: partial_1_12 + "tbz x13, #0, 174f\n" + "str s16, [x12, #0x0]\n" + "str s11, [x22, #0x0]\n" + "str s28, [x21, #0x0]\n" + "str s23, [x20, #0x0]\n" + "b 174f\n" + "168:" // Height 4: Partial direct writeback: partial_2_8 + "tbz x13, #1, 169f\n" + "str d15, [x12], #0x8\n" + "str d10, [x22], #0x8\n" + "str d27, [x21], #0x8\n" + "str d22, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v15.s }[2], [x12]\n" + "st1 { v10.s }[2], [x22]\n" + "st1 { v27.s }[2], [x21]\n" + "st1 { v22.s }[2], [x20]\n" + "b 174f\n" + "169:" // Height 4: Partial direct writeback: partial_1_8 + "tbz x13, #0, 174f\n" + "str s15, [x12, #0x0]\n" + "str s10, [x22, #0x0]\n" + "str s27, [x21, #0x0]\n" + "str s22, [x20, #0x0]\n" + "b 174f\n" + "170:" // Height 4: Partial direct writeback: partial_4_0 + "tbz x13, #2, 172f\n" + "st1 { v4.4s }, [x12], #0x10\n" + "st1 { v8.4s }, [x22], #0x10\n" + "st1 { v19.4s }, [x21], #0x10\n" + "st1 { v20.4s }, [x20], #0x10\n" + "tbz x13, #1, 171f\n" + "str d14, [x12], #0x8\n" + "str d9, [x22], #0x8\n" + "str d26, [x21], #0x8\n" + "str d21, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v14.s }[2], [x12]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v26.s }[2], [x21]\n" + "st1 { v21.s }[2], [x20]\n" + "b 174f\n" + "171:" // Height 4: Partial direct writeback: partial_1_4 + "tbz x13, #0, 174f\n" + "str s14, [x12, #0x0]\n" + "str s9, [x22, #0x0]\n" + "str s26, [x21, 
#0x0]\n" + "str s21, [x20, #0x0]\n" + "b 174f\n" + "172:" // Height 4: Partial direct writeback: partial_2_0 + "tbz x13, #1, 173f\n" + "str d4, [x12], #0x8\n" + "str d8, [x22], #0x8\n" + "str d19, [x21], #0x8\n" + "str d20, [x20], #0x8\n" + "tbz x13, #0, 174f\n" + "st1 { v4.s }[2], [x12]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v19.s }[2], [x21]\n" + "st1 { v20.s }[2], [x20]\n" + "b 174f\n" + "173:" // Height 4: Partial direct writeback: partial_1_0 + "str s4, [x12, #0x0]\n" + "str s8, [x22, #0x0]\n" + "str s19, [x21, #0x0]\n" + "str s20, [x20, #0x0]\n" + "174:" // Height 4: Partial direct writeback: Done + "b 176f\n" + "175:" // Height 4: Full writeback + "str q4, [x12, #0x0]\n" + "str q14, [x12, #0x10]\n" + "str q15, [x12, #0x20]\n" + "str q16, [x12, #0x30]\n" + "str q17, [x12, #0x40]\n" + "str q18, [x12, #0x50]\n" + "add x12, x12, #0x60\n" + "str q8, [x22, #0x0]\n" + "str q9, [x22, #0x10]\n" + "str q10, [x22, #0x20]\n" + "str q11, [x22, #0x30]\n" + "str q12, [x22, #0x40]\n" + "str q13, [x22, #0x50]\n" + "str q19, [x21, #0x0]\n" + "str q26, [x21, #0x10]\n" + "str q27, [x21, #0x20]\n" + "str q28, [x21, #0x30]\n" + "str q29, [x21, #0x40]\n" + "str q30, [x21, #0x50]\n" + "str q20, [x20, #0x0]\n" + "str q21, [x20, #0x10]\n" + "str q22, [x20, #0x20]\n" + "str q23, [x20, #0x30]\n" + "str q24, [x20, #0x40]\n" + "str q25, [x20, #0x50]\n" + "176:" // Height 4: Writeback done + "subs x13, x13, #0x18\n" + "bgt 134b\n" + "subs %x[M], %x[M], #0x4\n" + "beq 178f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 177f\n" + "add x20, x20, #0x4\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "177:" // Update direct input + "mov x19, #0x10\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "178:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp new file mode 100644 index 0000000000..e24dab68e8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const bfloat16 *, const bfloat16 *, size_t, \ + float *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffinterleaved_bf16fp32_dot_8x12( ARGLIST ); + +class cls_a64_ffinterleaved_bf16fp32_dot_8x12 +{ +public: + typedef bfloat16 operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 12; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL128_BL32; + } + + static constexpr unsigned int k_unroll() + { + return 2; + } + + + StdTransformsFixed<operand_type, result_type, 8, 12, 2> transforms = {}; + StdTransformsFixed<operand_type, result_type, 8, 12, 2, true> transforms_quantized = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same<T, bfloat16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 22.16, 8.25, 3.26 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffinterleaved_bf16fp32_dot_8x12; + cls_a64_ffinterleaved_bf16fp32_dot_8x12(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp new file mode 100644 index 0000000000..967396c377 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_dot_8x12/generic.cpp @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2022 Arm Limited. 
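The dot-product variant implemented next trades BFMMLA's 2x2 tiles for indexed BFDOT: the index names one bf16 pair in the A register, and every fp32 lane of the accumulator gains that pair's dot product with the matching pair of the B register. A scalar sketch of a single "bfdot vd.4s, vn.8h, vm.h[idx]", with illustrative names:

#include <cstdint>
#include <cstring>

static float bf16_widen(uint16_t b) {   // bf16 bits occupy the top half of an fp32
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

// acc: four fp32 lanes; n: eight bf16 B values; m: eight bf16 A values.
// idx in 0..3 selects the A pair, matching the v0.h[0]..v0.h[3] comments below.
static void bfdot_lane_ref(float acc[4], const uint16_t n[8], const uint16_t m[8], int idx) {
    for (int lane = 0; lane < 4; lane++)
        acc[lane] += bf16_widen(n[2 * lane])     * bf16_widen(m[2 * idx]) +
                     bf16_widen(n[2 * lane + 1]) * bf16_widen(m[2 * idx + 1]);
}

With out_height 8 and out_width 12 this needs 8 x 3 = 24 fp32 accumulator registers, which is exactly the movi-cleared block v8..v31 at label 3 in the implementation that follows.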
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef __aarch64__ + +#include <cstddef> +#include "../../bfloat.hpp" + +namespace arm_gemm { + +void a64_ffinterleaved_bf16fp32_dot_8x12( + const bfloat16 *Apanel, + const bfloat16 *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const bfloat16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const bfloat16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/2) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "1:" // Height loop + "ldr x24, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "str x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x22, %x[Apanel]\n" + "2:" // Width loop + "ldr x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x21, x24, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x23, #0x8\n" + "mov %x[Apanel], x22\n" + "bgt 3f\n" + "cmp x23, #0x4\n" + "mov x20, x24\n" + "bgt 3f\n" + "mov x21, x24\n" + "3:" // B setup done + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "movi v8.16b, #0x0\n" + "ldr q4, [x24, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "movi v9.16b, #0x0\n" + "ldr q6, [x20, #0x0]\n" + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + ".inst 0x4f40f088 // bfdot v8.4s, v4.8h, v0.h[0]\n" + ".inst 0x4f60f08b // bfdot v11.4s, v4.8h, v0.h[1]\n" + ".inst 0x4f40f88e // bfdot v14.4s, v4.8h, v0.h[2]\n" + "sub x19, x19, #0x2\n" + ".inst 0x4f60f891 // bfdot v17.4s, v4.8h, v0.h[3]\n" + ".inst 0x4f41f094 // bfdot v20.4s, v4.8h, v1.h[0]\n" + "cmp x19, #0x2\n" + ".inst 0x4f61f097 // bfdot v23.4s, v4.8h, 
v1.h[1]\n" + ".inst 0x4f41f89a // bfdot v26.4s, v4.8h, v1.h[2]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + ".inst 0x4f61f89d // bfdot v29.4s, v4.8h, v1.h[3]\n" + "ldr q4, [x24, #0x10]\n" + ".inst 0x4f40f0a9 // bfdot v9.4s, v5.8h, v0.h[0]\n" + ".inst 0x4f60f0ac // bfdot v12.4s, v5.8h, v0.h[1]\n" + ".inst 0x4f40f8af // bfdot v15.4s, v5.8h, v0.h[2]\n" + "add x24, x24, #0x20\n" + ".inst 0x4f60f8b2 // bfdot v18.4s, v5.8h, v0.h[3]\n" + ".inst 0x4f41f0b5 // bfdot v21.4s, v5.8h, v1.h[0]\n" + ".inst 0x4f61f0b8 // bfdot v24.4s, v5.8h, v1.h[1]\n" + ".inst 0x4f41f8bb // bfdot v27.4s, v5.8h, v1.h[2]\n" + ".inst 0x4f61f8be // bfdot v30.4s, v5.8h, v1.h[3]\n" + "ldr q5, [x21, #0x10]\n" + ".inst 0x4f40f0ca // bfdot v10.4s, v6.8h, v0.h[0]\n" + ".inst 0x4f60f0cd // bfdot v13.4s, v6.8h, v0.h[1]\n" + ".inst 0x4f40f8d0 // bfdot v16.4s, v6.8h, v0.h[2]\n" + "add x21, x21, #0x20\n" + ".inst 0x4f60f8d3 // bfdot v19.4s, v6.8h, v0.h[3]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + ".inst 0x4f41f0d6 // bfdot v22.4s, v6.8h, v1.h[0]\n" + ".inst 0x4f61f0d9 // bfdot v25.4s, v6.8h, v1.h[1]\n" + ".inst 0x4f41f8dc // bfdot v28.4s, v6.8h, v1.h[2]\n" + ".inst 0x4f61f8df // bfdot v31.4s, v6.8h, v1.h[3]\n" + "ldr q6, [x20, #0x10]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "add x20, x20, #0x20\n" + ".inst 0x4f42f088 // bfdot v8.4s, v4.8h, v2.h[0]\n" + ".inst 0x4f62f08b // bfdot v11.4s, v4.8h, v2.h[1]\n" + ".inst 0x4f42f88e // bfdot v14.4s, v4.8h, v2.h[2]\n" + ".inst 0x4f62f891 // bfdot v17.4s, v4.8h, v2.h[3]\n" + ".inst 0x4f43f094 // bfdot v20.4s, v4.8h, v3.h[0]\n" + ".inst 0x4f63f097 // bfdot v23.4s, v4.8h, v3.h[1]\n" + ".inst 0x4f43f89a // bfdot v26.4s, v4.8h, v3.h[2]\n" + ".inst 0x4f63f89d // bfdot v29.4s, v4.8h, v3.h[3]\n" + "ldr q4, [x24, #0x0]\n" + ".inst 0x4f42f0a9 // bfdot v9.4s, v5.8h, v2.h[0]\n" + ".inst 0x4f62f0ac // bfdot v12.4s, v5.8h, v2.h[1]\n" + ".inst 0x4f42f8af // bfdot v15.4s, v5.8h, v2.h[2]\n" + ".inst 0x4f62f8b2 // bfdot v18.4s, v5.8h, v2.h[3]\n" + ".inst 0x4f43f0b5 // bfdot v21.4s, v5.8h, v3.h[0]\n" + ".inst 0x4f63f0b8 // bfdot v24.4s, v5.8h, v3.h[1]\n" + ".inst 0x4f43f8bb // bfdot v27.4s, v5.8h, v3.h[2]\n" + ".inst 0x4f63f8be // bfdot v30.4s, v5.8h, v3.h[3]\n" + "ldr q5, [x21, #0x0]\n" + ".inst 0x4f42f0ca // bfdot v10.4s, v6.8h, v2.h[0]\n" + ".inst 0x4f62f0cd // bfdot v13.4s, v6.8h, v2.h[1]\n" + ".inst 0x4f42f8d0 // bfdot v16.4s, v6.8h, v2.h[2]\n" + ".inst 0x4f62f8d3 // bfdot v19.4s, v6.8h, v2.h[3]\n" + ".inst 0x4f43f0d6 // bfdot v22.4s, v6.8h, v3.h[0]\n" + ".inst 0x4f63f0d9 // bfdot v25.4s, v6.8h, v3.h[1]\n" + ".inst 0x4f43f8dc // bfdot v28.4s, v6.8h, v3.h[2]\n" + ".inst 0x4f63f8df // bfdot v31.4s, v6.8h, v3.h[3]\n" + "ldr q6, [x20, #0x0]\n" + "bge 4b\n" + "5:" // main loop skip + ".inst 0x4f40f088 // bfdot v8.4s, v4.8h, v0.h[0]\n" + ".inst 0x4f60f08b // bfdot v11.4s, v4.8h, v0.h[1]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + ".inst 0x4f40f88e // bfdot v14.4s, v4.8h, v0.h[2]\n" + ".inst 0x4f60f891 // bfdot v17.4s, v4.8h, v0.h[3]\n" + "add x24, x24, #0x10\n" + ".inst 0x4f41f094 // bfdot v20.4s, v4.8h, v1.h[0]\n" + ".inst 0x4f61f097 // bfdot v23.4s, v4.8h, v1.h[1]\n" + "add x21, x21, #0x10\n" + ".inst 0x4f41f89a // bfdot v26.4s, v4.8h, v1.h[2]\n" + ".inst 0x4f61f89d // bfdot v29.4s, v4.8h, v1.h[3]\n" + "add x20, x20, #0x10\n" + ".inst 0x4f40f0a9 // bfdot v9.4s, v5.8h, v0.h[0]\n" + ".inst 0x4f60f0ac // bfdot v12.4s, v5.8h, v0.h[1]\n" + ".inst 0x4f40f8af // bfdot v15.4s, v5.8h, v0.h[2]\n" + ".inst 0x4f60f8b2 // bfdot v18.4s, v5.8h, v0.h[3]\n" + ".inst 0x4f41f0b5 // bfdot v21.4s, v5.8h, v1.h[0]\n" + ".inst 0x4f61f0b8 // 
bfdot v24.4s, v5.8h, v1.h[1]\n" + ".inst 0x4f41f8bb // bfdot v27.4s, v5.8h, v1.h[2]\n" + ".inst 0x4f61f8be // bfdot v30.4s, v5.8h, v1.h[3]\n" + ".inst 0x4f40f0ca // bfdot v10.4s, v6.8h, v0.h[0]\n" + ".inst 0x4f60f0cd // bfdot v13.4s, v6.8h, v0.h[1]\n" + ".inst 0x4f40f8d0 // bfdot v16.4s, v6.8h, v0.h[2]\n" + ".inst 0x4f60f8d3 // bfdot v19.4s, v6.8h, v0.h[3]\n" + ".inst 0x4f41f0d6 // bfdot v22.4s, v6.8h, v1.h[0]\n" + ".inst 0x4f61f0d9 // bfdot v25.4s, v6.8h, v1.h[1]\n" + ".inst 0x4f41f8dc // bfdot v28.4s, v6.8h, v1.h[2]\n" + ".inst 0x4f61f8df // bfdot v31.4s, v6.8h, v1.h[3]\n" + "cbz x19, 6f\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "ldr q7, [x24, #0x0]\n" + "ldr q4, [x21, #0x0]\n" + ".inst 0x4f40f0e8 // bfdot v8.4s, v7.8h, v0.h[0]\n" + "ldr q5, [x20, #0x0]\n" + ".inst 0x4f60f0eb // bfdot v11.4s, v7.8h, v0.h[1]\n" + ".inst 0x4f40f8ee // bfdot v14.4s, v7.8h, v0.h[2]\n" + ".inst 0x4f60f8f1 // bfdot v17.4s, v7.8h, v0.h[3]\n" + ".inst 0x4f41f0f4 // bfdot v20.4s, v7.8h, v1.h[0]\n" + ".inst 0x4f61f0f7 // bfdot v23.4s, v7.8h, v1.h[1]\n" + ".inst 0x4f41f8fa // bfdot v26.4s, v7.8h, v1.h[2]\n" + ".inst 0x4f61f8fd // bfdot v29.4s, v7.8h, v1.h[3]\n" + ".inst 0x4f40f089 // bfdot v9.4s, v4.8h, v0.h[0]\n" + ".inst 0x4f60f08c // bfdot v12.4s, v4.8h, v0.h[1]\n" + ".inst 0x4f40f88f // bfdot v15.4s, v4.8h, v0.h[2]\n" + ".inst 0x4f60f892 // bfdot v18.4s, v4.8h, v0.h[3]\n" + ".inst 0x4f41f095 // bfdot v21.4s, v4.8h, v1.h[0]\n" + ".inst 0x4f61f098 // bfdot v24.4s, v4.8h, v1.h[1]\n" + ".inst 0x4f41f89b // bfdot v27.4s, v4.8h, v1.h[2]\n" + ".inst 0x4f61f89e // bfdot v30.4s, v4.8h, v1.h[3]\n" + ".inst 0x4f40f0aa // bfdot v10.4s, v5.8h, v0.h[0]\n" + ".inst 0x4f60f0ad // bfdot v13.4s, v5.8h, v0.h[1]\n" + ".inst 0x4f40f8b0 // bfdot v16.4s, v5.8h, v0.h[2]\n" + ".inst 0x4f60f8b3 // bfdot v19.4s, v5.8h, v0.h[3]\n" + ".inst 0x4f41f0b6 // bfdot v22.4s, v5.8h, v1.h[0]\n" + ".inst 0x4f61f0b9 // bfdot v25.4s, v5.8h, v1.h[1]\n" + ".inst 0x4f41f8bc // bfdot v28.4s, v5.8h, v1.h[2]\n" + ".inst 0x4f61f8bf // bfdot v31.4s, v5.8h, v1.h[3]\n" + "6:" // multiply loop done + "subs x23, x23, #0xc\n" + "str q8, [%x[Cpanel], #0x0]\n" + "str q9, [%x[Cpanel], #0x10]\n" + "str q10, [%x[Cpanel], #0x20]\n" + "str q11, [%x[Cpanel], #0x30]\n" + "str q12, [%x[Cpanel], #0x40]\n" + "str q13, [%x[Cpanel], #0x50]\n" + "str q14, [%x[Cpanel], #0x60]\n" + "str q15, [%x[Cpanel], #0x70]\n" + "str q16, [%x[Cpanel], #0x80]\n" + "str q17, [%x[Cpanel], #0x90]\n" + "str q18, [%x[Cpanel], #0xa0]\n" + "str q19, [%x[Cpanel], #0xb0]\n" + "str q20, [%x[Cpanel], #0xc0]\n" + "str q21, [%x[Cpanel], #0xd0]\n" + "str q22, [%x[Cpanel], #0xe0]\n" + "str q23, [%x[Cpanel], #0xf0]\n" + "str q24, [%x[Cpanel], #0x100]\n" + "str q25, [%x[Cpanel], #0x110]\n" + "str q26, [%x[Cpanel], #0x120]\n" + "str q27, [%x[Cpanel], #0x130]\n" + "str q28, [%x[Cpanel], #0x140]\n" + "str q29, [%x[Cpanel], #0x150]\n" + "str q30, [%x[Cpanel], #0x160]\n" + "str q31, [%x[Cpanel], #0x170]\n" + "add %x[Cpanel], %x[Cpanel], #0x180\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", 
"v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x19", "x20", "x21", "x22", "x23", "x24" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp new file mode 100644 index 0000000000..c61315b80a --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const bfloat16 *, const bfloat16 *, size_t, \ + float *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffinterleaved_bf16fp32_mmla_8x12( ARGLIST ); + +class cls_a64_ffinterleaved_bf16fp32_mmla_8x12 +{ +public: + typedef bfloat16 operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 12; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL256_BL64; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + + StdTransformsFixed transforms = {}; + StdTransformsFixed transforms_quantized = {}; + template + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same::value) { + switch (ci->get_cpu_model()) { + default: + return { 31.62, 9.07, 3.23 }; + } + } + + + if (std::is_same::value) { + switch (ci->get_cpu_model()) { + default: + return { 38.10, 5.23, 3.15 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffinterleaved_bf16fp32_mmla_8x12; + cls_a64_ffinterleaved_bf16fp32_mmla_8x12(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp 
b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp new file mode 100644 index 0000000000..509f2afa09 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_bf16fp32_mmla_8x12/generic.cpp @@ -0,0 +1,314 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef __aarch64__ + +#include <cstddef> +#include "../../bfloat.hpp" + +namespace arm_gemm { + +void a64_ffinterleaved_bf16fp32_mmla_8x12( + const bfloat16 *Apanel, + const bfloat16 *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const bfloat16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const bfloat16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/4) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "1:" // Height loop + "ldr x24, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "str x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x22, %x[Apanel]\n" + "2:" // Width loop + "ldr x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x21, x24, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x23, #0x8\n" + "mov %x[Apanel], x22\n" + "bgt 3f\n" + "cmp x23, #0x4\n" + "mov x20, x24\n" + "bgt 3f\n" + "mov x21, x24\n" + "3:" // B setup done + "ldr q4, [x24, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "movi v8.16b, #0x0\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q5, [x24, #0x10]\n" + "movi v9.16b, #0x0\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "add x24, x24, #0x20\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q6, 
[x21, #0x0]\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q7, [x21, #0x10]\n" + ".inst 0x6e45ec0b // bfmmla v11.4s, v0.8h, v5.8h\n" + ".inst 0x6e44ec2e // bfmmla v14.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec31 // bfmmla v17.4s, v1.8h, v5.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "sub x19, x19, #0x2\n" + ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec7a // bfmmla v26.4s, v3.8h, v4.8h\n" + "ldr q4, [x20, #0x0]\n" + ".inst 0x6e45ec7d // bfmmla v29.4s, v3.8h, v5.8h\n" + "ldr q5, [x20, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "cmp x19, #0x2\n" + ".inst 0x6e47ec32 // bfmmla v18.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec58 // bfmmla v24.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec7b // bfmmla v27.4s, v3.8h, v6.8h\n" + "ldr q6, [x24, #0x0]\n" + ".inst 0x6e47ec7e // bfmmla v30.4s, v3.8h, v7.8h\n" + "ldr q7, [x24, #0x10]\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n" + "ldr q0, [%x[Apanel], #0x10]\n" + ".inst 0x6e44ec30 // bfmmla v16.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec33 // bfmmla v19.4s, v1.8h, v5.8h\n" + "ldr q1, [%x[Apanel], #0x20]\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec59 // bfmmla v25.4s, v2.8h, v5.8h\n" + "ldr q2, [%x[Apanel], #0x30]\n" + ".inst 0x6e44ec7c // bfmmla v28.4s, v3.8h, v4.8h\n" + "ldr q4, [x21, #0x20]\n" + ".inst 0x6e45ec7f // bfmmla v31.4s, v3.8h, v5.8h\n" + "ldr q3, [%x[Apanel], #0x40]\n" + "ldr q5, [x21, #0x30]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + ".inst 0x6e47ec31 // bfmmla v17.4s, v1.8h, v7.8h\n" + "add x21, x21, #0x40\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec7a // bfmmla v26.4s, v3.8h, v6.8h\n" + "ldr q6, [x20, #0x20]\n" + ".inst 0x6e47ec7d // bfmmla v29.4s, v3.8h, v7.8h\n" + "ldr q7, [x20, #0x30]\n" + ".inst 0x6e44ec09 // bfmmla v9.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n" + ".inst 0x6e44ec2f // bfmmla v15.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec32 // bfmmla v18.4s, v1.8h, v5.8h\n" + "add x20, x20, #0x40\n" + ".inst 0x6e44ec55 // bfmmla v21.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec58 // bfmmla v24.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec7b // bfmmla v27.4s, v3.8h, v4.8h\n" + "ldr q4, [x24, #0x20]\n" + ".inst 0x6e45ec7e // bfmmla v30.4s, v3.8h, v5.8h\n" + "ldr q5, [x24, #0x30]\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + "ldr q0, [%x[Apanel], #0x50]\n" + ".inst 0x6e46ec30 // bfmmla v16.4s, v1.8h, v6.8h\n" + ".inst 0x6e47ec33 // bfmmla v19.4s, v1.8h, v7.8h\n" + "ldr q1, [%x[Apanel], #0x60]\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec59 // bfmmla v25.4s, v2.8h, v7.8h\n" + "ldr q2, [%x[Apanel], #0x70]\n" + ".inst 0x6e46ec7c // bfmmla v28.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec7f // bfmmla v31.4s, v3.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "add x24, x24, #0x40\n" + "bge 4b\n" + "5:" // main loop skip + "ldr q3, [%x[Apanel], #0x0]\n" + "ldr q6, [x21, #0x0]\n" + ".inst 0x6e44ec08 // bfmmla v8.4s, v0.8h, v4.8h\n" + "ldr q7, [x21, #0x10]\n" + ".inst 0x6e45ec0b // bfmmla v11.4s, v0.8h, v5.8h\n" + ".inst 0x6e44ec2e // 
bfmmla v14.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec31 // bfmmla v17.4s, v1.8h, v5.8h\n" + ".inst 0x6e44ec54 // bfmmla v20.4s, v2.8h, v4.8h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6e45ec57 // bfmmla v23.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec7a // bfmmla v26.4s, v3.8h, v4.8h\n" + "ldr q4, [x20, #0x0]\n" + ".inst 0x6e45ec7d // bfmmla v29.4s, v3.8h, v5.8h\n" + "ldr q5, [x20, #0x10]\n" + ".inst 0x6e46ec09 // bfmmla v9.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0c // bfmmla v12.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec2f // bfmmla v15.4s, v1.8h, v6.8h\n" + "add x21, x21, #0x20\n" + ".inst 0x6e47ec32 // bfmmla v18.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec55 // bfmmla v21.4s, v2.8h, v6.8h\n" + "add x20, x20, #0x20\n" + ".inst 0x6e47ec58 // bfmmla v24.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec7b // bfmmla v27.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec7e // bfmmla v30.4s, v3.8h, v7.8h\n" + ".inst 0x6e44ec0a // bfmmla v10.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0d // bfmmla v13.4s, v0.8h, v5.8h\n" + ".inst 0x6e44ec30 // bfmmla v16.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec33 // bfmmla v19.4s, v1.8h, v5.8h\n" + ".inst 0x6e44ec56 // bfmmla v22.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec59 // bfmmla v25.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec7c // bfmmla v28.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec7f // bfmmla v31.4s, v3.8h, v5.8h\n" + "cbz x19, 6f\n" + "ldr q6, [x24, #0x0]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + ".inst 0x6e46ec08 // bfmmla v8.4s, v0.8h, v6.8h\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q7, [x24, #0x10]\n" + ".inst 0x6e47ec0b // bfmmla v11.4s, v0.8h, v7.8h\n" + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + ".inst 0x6e46ec2e // bfmmla v14.4s, v1.8h, v6.8h\n" + "ldr q4, [x21, #0x0]\n" + "ldr q5, [x21, #0x10]\n" + ".inst 0x6e47ec31 // bfmmla v17.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec54 // bfmmla v20.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec57 // bfmmla v23.4s, v2.8h, v7.8h\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + ".inst 0x6e46ec7a // bfmmla v26.4s, v3.8h, v6.8h\n" + "ldr q6, [x20, #0x0]\n" + ".inst 0x6e47ec7d // bfmmla v29.4s, v3.8h, v7.8h\n" + "ldr q7, [x20, #0x10]\n" + ".inst 0x6e44ec09 // bfmmla v9.4s, v0.8h, v4.8h\n" + ".inst 0x6e45ec0c // bfmmla v12.4s, v0.8h, v5.8h\n" + ".inst 0x6e44ec2f // bfmmla v15.4s, v1.8h, v4.8h\n" + ".inst 0x6e45ec32 // bfmmla v18.4s, v1.8h, v5.8h\n" + ".inst 0x6e44ec55 // bfmmla v21.4s, v2.8h, v4.8h\n" + ".inst 0x6e45ec58 // bfmmla v24.4s, v2.8h, v5.8h\n" + ".inst 0x6e44ec7b // bfmmla v27.4s, v3.8h, v4.8h\n" + ".inst 0x6e45ec7e // bfmmla v30.4s, v3.8h, v5.8h\n" + ".inst 0x6e46ec0a // bfmmla v10.4s, v0.8h, v6.8h\n" + ".inst 0x6e47ec0d // bfmmla v13.4s, v0.8h, v7.8h\n" + ".inst 0x6e46ec30 // bfmmla v16.4s, v1.8h, v6.8h\n" + ".inst 0x6e47ec33 // bfmmla v19.4s, v1.8h, v7.8h\n" + ".inst 0x6e46ec56 // bfmmla v22.4s, v2.8h, v6.8h\n" + ".inst 0x6e47ec59 // bfmmla v25.4s, v2.8h, v7.8h\n" + ".inst 0x6e46ec7c // bfmmla v28.4s, v3.8h, v6.8h\n" + ".inst 0x6e47ec7f // bfmmla v31.4s, v3.8h, v7.8h\n" + "6:" // multiply loop done + "subs x23, x23, #0xc\n" + "uzp1 v4.2d, v8.2d, v11.2d\n" + "uzp2 v8.2d, v8.2d, v11.2d\n" + "uzp1 v11.2d, v9.2d, v12.2d\n" + "uzp2 v9.2d, v9.2d, v12.2d\n" + "str q4, [%x[Cpanel], #0x0]\n" + "uzp1 v12.2d, v10.2d, v13.2d\n" + "uzp2 v10.2d, v10.2d, v13.2d\n" + "str q11, [%x[Cpanel], #0x10]\n" + "str q12, [%x[Cpanel], #0x20]\n" + "uzp1 v13.2d, v14.2d, v17.2d\n" + "uzp2 v14.2d, v14.2d, v17.2d\n" + "str q8, [%x[Cpanel], #0x30]\n" + "uzp1 v17.2d, v15.2d, v18.2d\n" + "uzp2 v15.2d, v15.2d, v18.2d\n" + "str q9, [%x[Cpanel], #0x40]\n" + "uzp1 v18.2d, v16.2d, v19.2d\n" + 
"uzp2 v16.2d, v16.2d, v19.2d\n" + "str q10, [%x[Cpanel], #0x50]\n" + "uzp1 v19.2d, v20.2d, v23.2d\n" + "uzp2 v20.2d, v20.2d, v23.2d\n" + "str q13, [%x[Cpanel], #0x60]\n" + "uzp1 v23.2d, v21.2d, v24.2d\n" + "uzp2 v21.2d, v21.2d, v24.2d\n" + "str q17, [%x[Cpanel], #0x70]\n" + "uzp1 v24.2d, v22.2d, v25.2d\n" + "uzp2 v22.2d, v22.2d, v25.2d\n" + "str q18, [%x[Cpanel], #0x80]\n" + "uzp1 v25.2d, v26.2d, v29.2d\n" + "uzp2 v26.2d, v26.2d, v29.2d\n" + "str q14, [%x[Cpanel], #0x90]\n" + "uzp1 v29.2d, v27.2d, v30.2d\n" + "uzp2 v27.2d, v27.2d, v30.2d\n" + "str q15, [%x[Cpanel], #0xa0]\n" + "uzp1 v30.2d, v28.2d, v31.2d\n" + "uzp2 v28.2d, v28.2d, v31.2d\n" + "str q16, [%x[Cpanel], #0xb0]\n" + "str q19, [%x[Cpanel], #0xc0]\n" + "str q23, [%x[Cpanel], #0xd0]\n" + "str q24, [%x[Cpanel], #0xe0]\n" + "str q20, [%x[Cpanel], #0xf0]\n" + "str q21, [%x[Cpanel], #0x100]\n" + "str q22, [%x[Cpanel], #0x110]\n" + "str q25, [%x[Cpanel], #0x120]\n" + "str q29, [%x[Cpanel], #0x130]\n" + "str q30, [%x[Cpanel], #0x140]\n" + "str q26, [%x[Cpanel], #0x150]\n" + "str q27, [%x[Cpanel], #0x160]\n" + "str q28, [%x[Cpanel], #0x170]\n" + "add %x[Cpanel], %x[Cpanel], #0x180\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x19", "x20", "x21", "x22", "x23", "x24" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24.hpp new file mode 100644 index 0000000000..1495306879 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24.hpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const __fp16 *, const __fp16 *, size_t, \ + __fp16 *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffinterleaved_fp16_mla_8x24( ARGLIST ); + +class cls_a64_ffinterleaved_fp16_mla_8x24 +{ +public: + typedef __fp16 operand_type; + typedef __fp16 result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 24; + } + static unsigned int stripe_width() + { + return 8; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL128_BL16; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + + StdTransformsFixed<operand_type, result_type, 8, 24, 1> transforms = {}; + StdTransformsFixed<operand_type, result_type, 8, 24, 1, true> transforms_quantized = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same<T, __fp16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 22.87, 7.77, 2.03 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffinterleaved_fp16_mla_8x24; + cls_a64_ffinterleaved_fp16_mla_8x24(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp new file mode 100644 index 0000000000..19836f2e9d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp16_mla_8x24/generic.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
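The fp16 kernel declared above holds an 8x24 tile in v8..v31: every k step broadcasts eight A values (v0.h[0]..v0.h[7]) against three 8-wide B stripes. Functionally it computes the reference loop below; the flat panel layouts are an assumption for the sketch (the real B panel is split across three stripe pointers), and the fp32 accumulator is only for readability where the kernel itself accumulates in fp16:

// Plain reference for one 8x24 block over K steps (illustrative layout).
void ref_block_8x24(const __fp16 *Apanel, const __fp16 *Bpanel,
                    __fp16 *Cpanel, int K) {
    float acc[8][24] = {};
    for (int k = 0; k < K; k++)
        for (int r = 0; r < 8; r++)
            for (int c = 0; c < 24; c++)
                acc[r][c] += (float)Apanel[k * 8 + r] * (float)Bpanel[k * 24 + c];
    for (int r = 0; r < 8; r++)
        for (int c = 0; c < 24; c++)
            Cpanel[r * 24 + c] = (__fp16)acc[r][c];
}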
+ */ +#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) + +#include <cstddef> + +namespace arm_gemm { + +void a64_ffinterleaved_fp16_mla_8x24( + const __fp16 *Apanel, + const __fp16 *Bpanel, + size_t B_stride, + __fp16 *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const __fp16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const __fp16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "1:" // Height loop + "ldr x24, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "str x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x22, %x[Apanel]\n" + "2:" // Width loop + "ldr x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x21, x24, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x23, #0x10\n" + "mov %x[Apanel], x22\n" + "bgt 3f\n" + "cmp x23, #0x8\n" + "mov x20, x24\n" + "bgt 3f\n" + "mov x21, x24\n" + "3:" // B setup done + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q2, [x24, #0x0]\n" + "movi v8.16b, #0x0\n" + "ldr q3, [x21, #0x0]\n" + "ldr q4, [x20, #0x0]\n" + "movi v9.16b, #0x0\n" + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q5, [x24, #0x10]\n" + "fmla v8.8h, v2.8h, v0.h[0]\n" + "ldr q6, [x21, #0x10]\n" + "ldr q7, [x20, #0x10]\n" + "fmla v11.8h, v2.8h, v0.h[1]\n" + "fmla v14.8h, v2.8h, v0.h[2]\n" + "fmla v17.8h, v2.8h, v0.h[3]\n" + "sub x19, x19, #0x2\n" + "fmla v20.8h, v2.8h, v0.h[4]\n" + "fmla v23.8h, v2.8h, v0.h[5]\n" + "cmp x19, #0x2\n" + "fmla v26.8h, v2.8h, v0.h[6]\n" + "fmla v29.8h, v2.8h, v0.h[7]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla v9.8h, v3.8h, v0.h[0]\n" + "fmla v12.8h, v3.8h, v0.h[1]\n" + "add x24, x24, #0x20\n" + "ldr q2, [x24, #0x0]\n" + "fmla v15.8h, v3.8h, v0.h[2]\n" + "fmla v18.8h, v3.8h, v0.h[3]\n" + "fmla v21.8h, v3.8h, v0.h[4]\n" + "fmla v24.8h, v3.8h, v0.h[5]\n" + "add x21, x21, #0x20\n" + "fmla v27.8h, v3.8h, v0.h[6]\n" + "fmla v30.8h, v3.8h, v0.h[7]\n" + "ldr q3, [x21, #0x0]\n" + "fmla v10.8h, v4.8h, v0.h[0]\n" + "fmla v13.8h, v4.8h, v0.h[1]\n" + "add x20, x20, #0x20\n" + "fmla v16.8h, v4.8h, v0.h[2]\n" + "fmla v19.8h, v4.8h, v0.h[3]\n" + "fmla v22.8h, v4.8h, v0.h[4]\n" + "fmla v25.8h, v4.8h, v0.h[5]\n" + "fmla v28.8h, v4.8h, v0.h[6]\n" + "fmla v31.8h, v4.8h, v0.h[7]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q4, [x20, #0x0]\n" + "fmla v8.8h, v5.8h, v1.h[0]\n" + "fmla v11.8h, v5.8h, v1.h[1]\n" + "fmla v14.8h, v5.8h, v1.h[2]\n" + "fmla v17.8h, v5.8h, v1.h[3]\n" + "fmla v20.8h, v5.8h, v1.h[4]\n" + "fmla v23.8h, v5.8h, v1.h[5]\n" + "fmla v26.8h, v5.8h, v1.h[6]\n" + "fmla v29.8h, v5.8h, v1.h[7]\n" + "fmla v9.8h, v6.8h, v1.h[0]\n" + "fmla v12.8h, v6.8h, v1.h[1]\n" + "fmla 
v15.8h, v6.8h, v1.h[2]\n" + "fmla v18.8h, v6.8h, v1.h[3]\n" + "fmla v21.8h, v6.8h, v1.h[4]\n" + "fmla v24.8h, v6.8h, v1.h[5]\n" + "fmla v27.8h, v6.8h, v1.h[6]\n" + "fmla v30.8h, v6.8h, v1.h[7]\n" + "fmla v10.8h, v7.8h, v1.h[0]\n" + "fmla v13.8h, v7.8h, v1.h[1]\n" + "fmla v16.8h, v7.8h, v1.h[2]\n" + "fmla v19.8h, v7.8h, v1.h[3]\n" + "fmla v22.8h, v7.8h, v1.h[4]\n" + "fmla v25.8h, v7.8h, v1.h[5]\n" + "fmla v28.8h, v7.8h, v1.h[6]\n" + "fmla v31.8h, v7.8h, v1.h[7]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla v8.8h, v2.8h, v0.h[0]\n" + "fmla v11.8h, v2.8h, v0.h[1]\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla v14.8h, v2.8h, v0.h[2]\n" + "fmla v17.8h, v2.8h, v0.h[3]\n" + "add x24, x24, #0x10\n" + "fmla v20.8h, v2.8h, v0.h[4]\n" + "fmla v23.8h, v2.8h, v0.h[5]\n" + "add x21, x21, #0x10\n" + "fmla v26.8h, v2.8h, v0.h[6]\n" + "fmla v29.8h, v2.8h, v0.h[7]\n" + "add x20, x20, #0x10\n" + "fmla v9.8h, v3.8h, v0.h[0]\n" + "fmla v12.8h, v3.8h, v0.h[1]\n" + "fmla v15.8h, v3.8h, v0.h[2]\n" + "fmla v18.8h, v3.8h, v0.h[3]\n" + "fmla v21.8h, v3.8h, v0.h[4]\n" + "fmla v24.8h, v3.8h, v0.h[5]\n" + "fmla v27.8h, v3.8h, v0.h[6]\n" + "fmla v30.8h, v3.8h, v0.h[7]\n" + "fmla v10.8h, v4.8h, v0.h[0]\n" + "fmla v13.8h, v4.8h, v0.h[1]\n" + "fmla v16.8h, v4.8h, v0.h[2]\n" + "fmla v19.8h, v4.8h, v0.h[3]\n" + "fmla v22.8h, v4.8h, v0.h[4]\n" + "fmla v25.8h, v4.8h, v0.h[5]\n" + "fmla v28.8h, v4.8h, v0.h[6]\n" + "fmla v31.8h, v4.8h, v0.h[7]\n" + "cbz x19, 6f\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q5, [x24, #0x0]\n" + "fmla v8.8h, v5.8h, v0.h[0]\n" + "ldr q6, [x21, #0x0]\n" + "ldr q7, [x20, #0x0]\n" + "fmla v11.8h, v5.8h, v0.h[1]\n" + "fmla v14.8h, v5.8h, v0.h[2]\n" + "fmla v17.8h, v5.8h, v0.h[3]\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla v20.8h, v5.8h, v0.h[4]\n" + "fmla v23.8h, v5.8h, v0.h[5]\n" + "fmla v26.8h, v5.8h, v0.h[6]\n" + "fmla v29.8h, v5.8h, v0.h[7]\n" + "fmla v9.8h, v6.8h, v0.h[0]\n" + "fmla v12.8h, v6.8h, v0.h[1]\n" + "fmla v15.8h, v6.8h, v0.h[2]\n" + "fmla v18.8h, v6.8h, v0.h[3]\n" + "fmla v21.8h, v6.8h, v0.h[4]\n" + "fmla v24.8h, v6.8h, v0.h[5]\n" + "fmla v27.8h, v6.8h, v0.h[6]\n" + "fmla v30.8h, v6.8h, v0.h[7]\n" + "fmla v10.8h, v7.8h, v0.h[0]\n" + "fmla v13.8h, v7.8h, v0.h[1]\n" + "fmla v16.8h, v7.8h, v0.h[2]\n" + "fmla v19.8h, v7.8h, v0.h[3]\n" + "fmla v22.8h, v7.8h, v0.h[4]\n" + "fmla v25.8h, v7.8h, v0.h[5]\n" + "fmla v28.8h, v7.8h, v0.h[6]\n" + "fmla v31.8h, v7.8h, v0.h[7]\n" + "6:" // multiply loop done + "subs x23, x23, #0x18\n" + "str q8, [%x[Cpanel], #0x0]\n" + "str q9, [%x[Cpanel], #0x10]\n" + "str q10, [%x[Cpanel], #0x20]\n" + "str q11, [%x[Cpanel], #0x30]\n" + "str q12, [%x[Cpanel], #0x40]\n" + "str q13, [%x[Cpanel], #0x50]\n" + "str q14, [%x[Cpanel], #0x60]\n" + "str q15, [%x[Cpanel], #0x70]\n" + "str q16, [%x[Cpanel], #0x80]\n" + "str q17, [%x[Cpanel], #0x90]\n" + "str q18, [%x[Cpanel], #0xa0]\n" + "str q19, [%x[Cpanel], #0xb0]\n" + "str q20, [%x[Cpanel], #0xc0]\n" + "str q21, [%x[Cpanel], #0xd0]\n" + "str q22, [%x[Cpanel], #0xe0]\n" + "str q23, [%x[Cpanel], #0xf0]\n" + "str q24, [%x[Cpanel], #0x100]\n" + "str q25, [%x[Cpanel], #0x110]\n" + "str q26, [%x[Cpanel], #0x120]\n" + "str q27, [%x[Cpanel], #0x130]\n" + "str q28, [%x[Cpanel], #0x140]\n" + "str q29, [%x[Cpanel], #0x150]\n" + "str q30, [%x[Cpanel], #0x160]\n" + "str q31, [%x[Cpanel], #0x170]\n" + "add %x[Cpanel], %x[Cpanel], #0x180\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), 
[offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x19", "x20", "x21", "x22", "x23", "x24" + ); +} + +} // namespace arm_gemm +#endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)) diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12.hpp new file mode 100644 index 0000000000..f2a836c9b4 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12.hpp @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
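All of these fixed-format kernels share the "B setup" idiom at label 3: three stripe pointers are derived from cur_B_ptr and B_stride (scaled by the element size, hence the LSL #1 or LSL #2), and when fewer than out_width columns remain the surplus pointers are aliased back to the first stripe so the loads stay in bounds while the unwanted results are never written back. Roughly, for the fp32 8x12 case (a sketch with illustrative names; B_stride counted in elements):

#include <cstddef>

struct BStripes { const float *b0, *b1, *b2; };

static BStripes setup_b_stripes(const float *&cur_B_ptr, size_t B_stride, size_t N) {
    const float *b0 = cur_B_ptr;       // stripe 0: columns 0..3
    const float *b1 = b0 + B_stride;   // stripe 1: columns 4..7
    const float *b2 = b1 + B_stride;   // stripe 2: columns 8..11
    cur_B_ptr = b2 + B_stride;         // start of the next 12-column block
    if (N <= 8) b2 = b0;               // third stripe not needed: alias it
    if (N <= 4) b1 = b0;               // second stripe not needed either
    return { b0, b1, b2 };
}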
+ */ +#pragma once +#ifdef __aarch64__ + +#include "../std_transforms_fixed.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const float *, const float *, size_t, \ + float *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void a64_ffinterleaved_fp32_mla_8x12( ARGLIST ); + +class cls_a64_ffinterleaved_fp32_mla_8x12 +{ +public: + typedef float operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return 12; + } + static unsigned int stripe_width() + { + return 4; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL128_BL32; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + + StdTransformsFixed<operand_type, result_type, 8, 12, 1> transforms = {}; + StdTransformsFixed<operand_type, result_type, 8, 12, 1, true> transforms_quantized = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + default: + return { 12.56, 9.83, 3.02 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=a64_ffinterleaved_fp32_mla_8x12; + cls_a64_ffinterleaved_fp32_mla_8x12(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp new file mode 100644 index 0000000000..bf804b5f43 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_ffinterleaved_fp32_mla_8x12/generic.cpp @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
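The `ka.K = (K/1) - 1` assignment in the file that follows (and the `(K/2) - 1` and `(K/4) - 1` forms in the bf16 kernels above) encodes a shared convention: the counter kept in x19 holds the number of unrolled k-chunks beyond a final one that is always executed after the main loop. For the bf16 dot kernel, whose main loop retires two chunks per pass (the fp32 kernel below retires four), the control flow is roughly (a sketch of the structure, not library code):

int chunks_left = (K / 2) - 1;   // k_unroll = 2; the value kept in x19
while (chunks_left >= 2) {       // "4:" main loop head: two chunks per pass
    // ... two unrolled k-chunks of bfdot work ...
    chunks_left -= 2;
}
// "5:" main loop skip: one k-chunk, always executed
if (chunks_left != 0) {
    // cbz-guarded tail: one final k-chunk
}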
+ */ +#ifdef __aarch64__ + +#include <cstddef> + +namespace arm_gemm { + +void a64_ffinterleaved_fp32_mla_8x12( + const float *Apanel, + const float *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const float *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const float *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "1:" // Height loop + "ldr x24, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x23, [%x[args_ptr], %[offsetof_N]]\n" + "str x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x22, %x[Apanel]\n" + "2:" // Width loop + "ldr x24, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x21, x24, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "add x19, x20, x19, LSL #2\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "cmp x23, #0x8\n" + "mov %x[Apanel], x22\n" + "bgt 3f\n" + "cmp x23, #0x4\n" + "mov x20, x24\n" + "bgt 3f\n" + "mov x21, x24\n" + "3:" // B setup done + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "movi v8.16b, #0x0\n" + "ldr q4, [x24, #0x0]\n" + "ldr q5, [x21, #0x0]\n" + "movi v9.16b, #0x0\n" + "ldr q6, [x20, #0x0]\n" + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x4\n" + "movi v10.16b, #0x0\n" + "movi v11.16b, #0x0\n" + "movi v12.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "movi v15.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v17.16b, #0x0\n" + "movi v18.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "movi v20.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v26.16b, #0x0\n" + "movi v27.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "movi v29.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v31.16b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "ldr q2, [%x[Apanel], #0x20]\n" + "ldr q3, [%x[Apanel], #0x30]\n" + "fmla v8.4s, v4.4s, v0.s[0]\n" + "ldr q7, [x24, #0x10]\n" + "fmla v11.4s, v4.4s, v0.s[1]\n" + "fmla v14.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v4.4s, v0.s[3]\n" + "fmla v20.4s, v4.4s, v1.s[0]\n" + "sub x19, x19, #0x4\n" + "fmla v23.4s, v4.4s, v1.s[1]\n" + "fmla v26.4s, v4.4s, v1.s[2]\n" + "cmp x19, #0x4\n" + "fmla v29.4s, v4.4s, v1.s[3]\n" + "ldr q4, [x21, #0x10]\n" + "fmla v9.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v5.4s, v0.s[1]\n" + "fmla v15.4s, v5.4s, v0.s[2]\n" + "fmla v18.4s, v5.4s, v0.s[3]\n" + "fmla v21.4s, v5.4s, v1.s[0]\n" + "fmla v24.4s, v5.4s, v1.s[1]\n" + "fmla v27.4s, v5.4s, v1.s[2]\n" + "fmla v30.4s, v5.4s, v1.s[3]\n" + "ldr q5, [x20, #0x10]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v6.4s, v0.s[1]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v19.4s, v6.4s, v0.s[3]\n" + "ldr q0, [%x[Apanel], #0x40]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "fmla v25.4s, v6.4s, v1.s[1]\n" + "fmla v28.4s, v6.4s, v1.s[2]\n" + "fmla v31.4s, v6.4s, v1.s[3]\n" + "ldr q1, [%x[Apanel], #0x50]\n" + "ldr q6, [x24, #0x20]\n" + "fmla v8.4s, v7.4s, v2.s[0]\n" + "fmla v11.4s, v7.4s, v2.s[1]\n" + "fmla v14.4s, v7.4s, v2.s[2]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" + "fmla v20.4s, v7.4s, v3.s[0]\n" + "fmla v23.4s, v7.4s, v3.s[1]\n" + "fmla v26.4s, v7.4s, v3.s[2]\n" + "fmla v29.4s, v7.4s, v3.s[3]\n" + "ldr q7, [x21, #0x20]\n" + "fmla v9.4s, v4.4s, v2.s[0]\n" + "fmla v12.4s, v4.4s, v2.s[1]\n" + "fmla v15.4s, v4.4s, v2.s[2]\n" + "fmla v18.4s, v4.4s, v2.s[3]\n" + "fmla v21.4s, v4.4s, v3.s[0]\n" + "fmla v24.4s, v4.4s, v3.s[1]\n" +
"fmla v27.4s, v4.4s, v3.s[2]\n" + "fmla v30.4s, v4.4s, v3.s[3]\n" + "ldr q4, [x20, #0x20]\n" + "fmla v10.4s, v5.4s, v2.s[0]\n" + "fmla v13.4s, v5.4s, v2.s[1]\n" + "fmla v16.4s, v5.4s, v2.s[2]\n" + "fmla v19.4s, v5.4s, v2.s[3]\n" + "ldr q2, [%x[Apanel], #0x60]\n" + "fmla v22.4s, v5.4s, v3.s[0]\n" + "fmla v25.4s, v5.4s, v3.s[1]\n" + "fmla v28.4s, v5.4s, v3.s[2]\n" + "fmla v31.4s, v5.4s, v3.s[3]\n" + "ldr q3, [%x[Apanel], #0x70]\n" + "ldr q5, [x24, #0x30]\n" + "fmla v8.4s, v6.4s, v0.s[0]\n" + "fmla v11.4s, v6.4s, v0.s[1]\n" + "fmla v14.4s, v6.4s, v0.s[2]\n" + "fmla v17.4s, v6.4s, v0.s[3]\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "fmla v20.4s, v6.4s, v1.s[0]\n" + "fmla v23.4s, v6.4s, v1.s[1]\n" + "add x24, x24, #0x40\n" + "fmla v26.4s, v6.4s, v1.s[2]\n" + "fmla v29.4s, v6.4s, v1.s[3]\n" + "ldr q6, [x21, #0x30]\n" + "fmla v9.4s, v7.4s, v0.s[0]\n" + "fmla v12.4s, v7.4s, v0.s[1]\n" + "add x21, x21, #0x40\n" + "fmla v15.4s, v7.4s, v0.s[2]\n" + "fmla v18.4s, v7.4s, v0.s[3]\n" + "fmla v21.4s, v7.4s, v1.s[0]\n" + "fmla v24.4s, v7.4s, v1.s[1]\n" + "fmla v27.4s, v7.4s, v1.s[2]\n" + "fmla v30.4s, v7.4s, v1.s[3]\n" + "ldr q7, [x20, #0x30]\n" + "fmla v10.4s, v4.4s, v0.s[0]\n" + "fmla v13.4s, v4.4s, v0.s[1]\n" + "add x20, x20, #0x40\n" + "fmla v16.4s, v4.4s, v0.s[2]\n" + "fmla v19.4s, v4.4s, v0.s[3]\n" + "ldr q0, [%x[Apanel], #0x0]\n" + "fmla v22.4s, v4.4s, v1.s[0]\n" + "fmla v25.4s, v4.4s, v1.s[1]\n" + "fmla v28.4s, v4.4s, v1.s[2]\n" + "fmla v31.4s, v4.4s, v1.s[3]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "ldr q4, [x24, #0x0]\n" + "fmla v8.4s, v5.4s, v2.s[0]\n" + "fmla v11.4s, v5.4s, v2.s[1]\n" + "fmla v14.4s, v5.4s, v2.s[2]\n" + "fmla v17.4s, v5.4s, v2.s[3]\n" + "fmla v20.4s, v5.4s, v3.s[0]\n" + "fmla v23.4s, v5.4s, v3.s[1]\n" + "fmla v26.4s, v5.4s, v3.s[2]\n" + "fmla v29.4s, v5.4s, v3.s[3]\n" + "ldr q5, [x21, #0x0]\n" + "fmla v9.4s, v6.4s, v2.s[0]\n" + "fmla v12.4s, v6.4s, v2.s[1]\n" + "fmla v15.4s, v6.4s, v2.s[2]\n" + "fmla v18.4s, v6.4s, v2.s[3]\n" + "fmla v21.4s, v6.4s, v3.s[0]\n" + "fmla v24.4s, v6.4s, v3.s[1]\n" + "fmla v27.4s, v6.4s, v3.s[2]\n" + "fmla v30.4s, v6.4s, v3.s[3]\n" + "ldr q6, [x20, #0x0]\n" + "fmla v10.4s, v7.4s, v2.s[0]\n" + "fmla v13.4s, v7.4s, v2.s[1]\n" + "fmla v16.4s, v7.4s, v2.s[2]\n" + "fmla v19.4s, v7.4s, v2.s[3]\n" + "fmla v22.4s, v7.4s, v3.s[0]\n" + "fmla v25.4s, v7.4s, v3.s[1]\n" + "fmla v28.4s, v7.4s, v3.s[2]\n" + "fmla v31.4s, v7.4s, v3.s[3]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla v8.4s, v4.4s, v0.s[0]\n" + "fmla v11.4s, v4.4s, v0.s[1]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla v14.4s, v4.4s, v0.s[2]\n" + "fmla v17.4s, v4.4s, v0.s[3]\n" + "add x24, x24, #0x10\n" + "fmla v20.4s, v4.4s, v1.s[0]\n" + "fmla v23.4s, v4.4s, v1.s[1]\n" + "add x21, x21, #0x10\n" + "fmla v26.4s, v4.4s, v1.s[2]\n" + "fmla v29.4s, v4.4s, v1.s[3]\n" + "add x20, x20, #0x10\n" + "fmla v9.4s, v5.4s, v0.s[0]\n" + "fmla v12.4s, v5.4s, v0.s[1]\n" + "fmla v15.4s, v5.4s, v0.s[2]\n" + "fmla v18.4s, v5.4s, v0.s[3]\n" + "fmla v21.4s, v5.4s, v1.s[0]\n" + "fmla v24.4s, v5.4s, v1.s[1]\n" + "fmla v27.4s, v5.4s, v1.s[2]\n" + "fmla v30.4s, v5.4s, v1.s[3]\n" + "fmla v10.4s, v6.4s, v0.s[0]\n" + "fmla v13.4s, v6.4s, v0.s[1]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" + "fmla v19.4s, v6.4s, v0.s[3]\n" + "fmla v22.4s, v6.4s, v1.s[0]\n" + "fmla v25.4s, v6.4s, v1.s[1]\n" + "fmla v28.4s, v6.4s, v1.s[2]\n" + "fmla v31.4s, v6.4s, v1.s[3]\n" + "cbz x19, 7f\n" + "6:" // odd loop + "ldr q0, [%x[Apanel], #0x0]\n" + "ldr q1, [%x[Apanel], #0x10]\n" + "subs x19, x19, #0x1\n" + "ldr q7, [x24, #0x0]\n" + "ldr q4, [x21, 
#0x0]\n" + "fmla v8.4s, v7.4s, v0.s[0]\n" + "ldr q5, [x20, #0x0]\n" + "fmla v11.4s, v7.4s, v0.s[1]\n" + "fmla v14.4s, v7.4s, v0.s[2]\n" + "fmla v17.4s, v7.4s, v0.s[3]\n" + "fmla v20.4s, v7.4s, v1.s[0]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla v23.4s, v7.4s, v1.s[1]\n" + "fmla v26.4s, v7.4s, v1.s[2]\n" + "add x24, x24, #0x10\n" + "fmla v29.4s, v7.4s, v1.s[3]\n" + "fmla v9.4s, v4.4s, v0.s[0]\n" + "add x21, x21, #0x10\n" + "fmla v12.4s, v4.4s, v0.s[1]\n" + "fmla v15.4s, v4.4s, v0.s[2]\n" + "add x20, x20, #0x10\n" + "fmla v18.4s, v4.4s, v0.s[3]\n" + "fmla v21.4s, v4.4s, v1.s[0]\n" + "fmla v24.4s, v4.4s, v1.s[1]\n" + "fmla v27.4s, v4.4s, v1.s[2]\n" + "fmla v30.4s, v4.4s, v1.s[3]\n" + "fmla v10.4s, v5.4s, v0.s[0]\n" + "fmla v13.4s, v5.4s, v0.s[1]\n" + "fmla v16.4s, v5.4s, v0.s[2]\n" + "fmla v19.4s, v5.4s, v0.s[3]\n" + "fmla v22.4s, v5.4s, v1.s[0]\n" + "fmla v25.4s, v5.4s, v1.s[1]\n" + "fmla v28.4s, v5.4s, v1.s[2]\n" + "fmla v31.4s, v5.4s, v1.s[3]\n" + "bne 6b\n" + "7:" // multiply loop done + "subs x23, x23, #0xc\n" + "str q8, [%x[Cpanel], #0x0]\n" + "str q9, [%x[Cpanel], #0x10]\n" + "str q10, [%x[Cpanel], #0x20]\n" + "str q11, [%x[Cpanel], #0x30]\n" + "str q12, [%x[Cpanel], #0x40]\n" + "str q13, [%x[Cpanel], #0x50]\n" + "str q14, [%x[Cpanel], #0x60]\n" + "str q15, [%x[Cpanel], #0x70]\n" + "str q16, [%x[Cpanel], #0x80]\n" + "str q17, [%x[Cpanel], #0x90]\n" + "str q18, [%x[Cpanel], #0xa0]\n" + "str q19, [%x[Cpanel], #0xb0]\n" + "str q20, [%x[Cpanel], #0xc0]\n" + "str q21, [%x[Cpanel], #0xd0]\n" + "str q22, [%x[Cpanel], #0xe0]\n" + "str q23, [%x[Cpanel], #0xf0]\n" + "str q24, [%x[Cpanel], #0x100]\n" + "str q25, [%x[Cpanel], #0x110]\n" + "str q26, [%x[Cpanel], #0x120]\n" + "str q27, [%x[Cpanel], #0x130]\n" + "str q28, [%x[Cpanel], #0x140]\n" + "str q29, [%x[Cpanel], #0x150]\n" + "str q30, [%x[Cpanel], #0x160]\n" + "str q31, [%x[Cpanel], #0x170]\n" + "add %x[Cpanel], %x[Cpanel], #0x180\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x19", "x20", "x21", "x22", "x23", "x24" + ); +} + +} // namespace arm_gemm +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp index 6ec6bd2ed8..17c93faca2 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -79,7 +79,7 @@ class cls_a64_interleaved_bf16fp32_mmla_8x12 default: return { 31.54, 4.30, 7.33 }; case CPUModel::V1: - return { 41.44, 5.01, 5.64 }; + return { 59.94, 5.08, 9.83 }; case CPUModel::A510: return { 7.82, 4.05, 3.07 }; } @@ -91,7 +91,7 @@ class cls_a64_interleaved_bf16fp32_mmla_8x12 default: return { 31.15, 2.51, 5.25 }; case CPUModel::V1: - return { 59.44, 3.18, 7.26 }; + return { 41.44, 5.01, 5.64 }; case CPUModel::A510: return { 7.83, 2.53, 2.71 }; } diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp new file mode 100644 index 0000000000..e07fa549f3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<bfloat16>, \ + size_t, size_t, \ + const bfloat16 *, \ + size_t, \ + IndirectOutputArg<float>, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffhybrid_bf16fp32_mmla_6x4VL( ARGLIST ); + +class cls_sve_ffhybrid_bf16fp32_mmla_6x4VL +{ +public: + typedef bfloat16 lhs_operand_type; + typedef bfloat16 rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return get_vector_length<float>() * 1; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL2VL_BL64; + } + + static unsigned int out_width() + { + return get_vector_length<float>() * 4; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsSVE transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, bfloat16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 49.10 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffhybrid_bf16fp32_mmla_6x4VL; + cls_sve_ffhybrid_bf16fp32_mmla_6x4VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp new file mode 100644 index 0000000000..c0b6b30762 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL/generic.cpp @@ -0,0 +1,2227 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
+ */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" +#include "../../bfloat.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void sve_ffhybrid_bf16fp32_mmla_6x4VL ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<bfloat16> A_arg, + size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const bfloat16 *B_ptr = {}; + const bfloat16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p5.b\n" + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 71f\n" + "cmp %x[M], #0x4\n" + "bgt 57f\n" + "beq 43f\n" + "cmp %x[M], #0x2\n" + "bgt 29f\n" + "beq 15f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 4f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz
%x[flags], #0, 5f\n" + "ld1w { z9.s }, p4/Z, [x12]\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "zip1 z8.d, z9.d, z12.d\n" + "zip2 z12.d, z9.d, z12.d\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "zip1 z9.d, z10.d, z13.d\n" + "zip2 z13.d, z10.d, z13.d\n" + "zip1 z10.d, z11.d, z14.d\n" + "zip2 z14.d, z11.d, z14.d\n" + "zip1 z11.d, z16.d, z15.d\n" + "zip2 z15.d, z16.d, z15.d\n" + "b 6f\n" + "5:" // Height 1: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "6:" // Height 1: setup done + "mov x27, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "cmp x26, #0x8\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "trn2 z1.d, z1.d, z2.d\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x8\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + "add x25, x25, #0x10\n" + "addvl x11, x11, #4\n" + "addvl x10, x10, #4\n" + "addvl x9, x9, #4\n" + "addvl x28, x28, #4\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla 
z9.s, z0.h, z7.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "subs x26, x26, #0x4\n" + "trn2 z1.d, z1.d, z2.d\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "ble 12f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "12:" // Height 1: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 7b\n" + "uzp1 z8.d, z8.d, z12.d\n" + "uzp1 z9.d, z9.d, z13.d\n" + "uzp1 z10.d, z10.d, z14.d\n" + "uzp1 z11.d, z11.d, z15.d\n" + "tbz %x[flags], #1, 13f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "13:" // Height 1: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "14:" // Height 1: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 86f\n" + "15:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "16:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 17f\n" + "mov x10, x11\n" + "17:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 18f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 
z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "b 20f\n" + "18:" // Height 2: no bias + "tbz %x[flags], #0, 19f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "ld1w { z9.s }, p4/Z, [x12]\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "zip1 z8.d, z9.d, z12.d\n" + "zip2 z12.d, z9.d, z12.d\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z13.d\n" + "zip2 z13.d, z10.d, z13.d\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "zip1 z10.d, z11.d, z14.d\n" + "zip2 z14.d, z11.d, z14.d\n" + "zip1 z11.d, z16.d, z15.d\n" + "zip2 z15.d, z16.d, z15.d\n" + "b 20f\n" + "19:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "20:" // Height 2: setup done + "mov x27, #0x0\n" + "21:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 22f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 23f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "b 23f\n" + "22:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "23:" // Height 2: input setup done + "cmp x26, #0x8\n" + "ble 25f\n" + "24:" // Height 2: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "trn2 z1.d, z1.d, z2.d\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, 
p5/Z, [x28, #2, MUL VL]\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + "sub x26, x26, #0x8\n" + "cmp x26, #0x8\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "addvl x11, x11, #4\n" + "addvl x10, x10, #4\n" + "addvl x9, x9, #4\n" + "addvl x28, x28, #4\n" + "bgt 24b\n" + "25:" // Height 2: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "subs x26, x26, #0x4\n" + "trn2 z1.d, z1.d, z2.d\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "ble 26f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "26:" // Height 2: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 21b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 z7.d, z8.d, z12.d\n" + "uzp2 z8.d, z8.d, z12.d\n" + "add x24, x12, x19, LSL #2\n" + "uzp1 z12.d, z9.d, z13.d\n" + "uzp2 z9.d, z9.d, z13.d\n" + "uzp1 z13.d, z10.d, z14.d\n" + "uzp2 z10.d, z10.d, z14.d\n" + "uzp1 z14.d, z11.d, z15.d\n" + "uzp2 z11.d, z11.d, z15.d\n" + "tbz %x[flags], #1, 27f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z7.s, p5/M, z7.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmax z7.s, p5/M, z7.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "27:" // Height 2: No activation + "st1w { z7.s 
}, p4, [x12]\n" + "st1w { z12.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z13.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z14.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z8.s }, p4, [x24]\n" + "st1w { z9.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x24, #3, MUL VL]\n" + "28:" // Height 2: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 16b\n" + "b 86f\n" + "29:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "30:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 31f\n" + "mov x10, x11\n" + "31:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 32f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z20.d, z12.d\n" + "mov z17.d, z9.d\n" + "mov z21.d, z13.d\n" + "mov z18.d, z10.d\n" + "mov z22.d, z14.d\n" + "mov z19.d, z11.d\n" + "mov z23.d, z15.d\n" + "b 34f\n" + "32:" // Height 3: no bias + "tbz %x[flags], #0, 33f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z9.s }, p4/Z, [x12]\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "zip1 z8.d, z9.d, z12.d\n" + "zip2 z12.d, z9.d, z12.d\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z13.d\n" + "zip2 z13.d, z10.d, z13.d\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z17.s }, p4/Z, [x23]\n" + "zip1 z10.d, z11.d, z14.d\n" + "zip2 z14.d, z11.d, z14.d\n" + "ld1w { z18.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z19.s }, p2/Z, [x23, #2, MUL VL]\n" + "zip1 z11.d, z16.d, z15.d\n" + "zip2 z15.d, z16.d, z15.d\n" + "ld1w { z24.s }, p1/Z, [x23, #3, MUL VL]\n" + "zip1 z16.d, z17.d, z20.d\n" + "zip2 z20.d, z17.d, z20.d\n" + "zip1 z17.d, z18.d, z21.d\n" + "zip2 z21.d, z18.d, z21.d\n" + "zip1 z18.d, z19.d, z22.d\n" + "zip2 z22.d, z19.d, z22.d\n" + "zip1 z19.d, z24.d, z23.d\n" + "zip2 z23.d, z24.d, z23.d\n" + "b 34f\n" + "33:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + 
"mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "34:" // Height 3: setup done + "mov x27, #0x0\n" + "35:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 36f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 37f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "b 37f\n" + "36:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "37:" // Height 3: input setup done + "cmp x26, #0x8\n" + "ble 39f\n" + "38:" // Height 3: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "trn1 z0.d, z1.d, z2.d\n" + "trn2 z1.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "trn2 z3.d, z3.d, z4.d\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "cmp x26, #0x8\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "add x25, x25, #0x10\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + "add x24, x24, #0x10\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "add x23, x23, #0x10\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "addvl x11, x11, #4\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + ".inst 0x6467e42b // bfmmla z11.s, 
z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + "bgt 38b\n" + "39:" // Height 3: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "trn1 z0.d, z1.d, z2.d\n" + "trn2 z1.d, z1.d, z2.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x4\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "trn2 z3.d, z3.d, z4.d\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + "ble 40f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #2\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + "40:" // Height 3: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 35b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "uzp1 z7.d, z8.d, z12.d\n" + "uzp2 z8.d, z8.d, z12.d\n" + "uzp1 z12.d, z9.d, z13.d\n" + "uzp2 z9.d, z9.d, z13.d\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 z13.d, z10.d, z14.d\n" + "uzp2 z10.d, z10.d, z14.d\n" + "uzp1 z14.d, z11.d, z15.d\n" + "uzp2 z11.d, z11.d, z15.d\n" + "uzp1 z16.d, z16.d, z20.d\n" + "uzp1 z17.d, z17.d, 
z21.d\n" + "uzp1 z18.d, z18.d, z22.d\n" + "uzp1 z19.d, z19.d, z23.d\n" + "tbz %x[flags], #1, 41f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z7.s, p5/M, z7.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmax z7.s, p5/M, z7.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "41:" // Height 3: No activation + "st1w { z7.s }, p4, [x12]\n" + "st1w { z12.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z13.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z14.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z8.s }, p4, [x24]\n" + "st1w { z9.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x23]\n" + "st1w { z17.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x23, #3, MUL VL]\n" + "42:" // Height 3: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 30b\n" + "b 86f\n" + "43:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "44:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 45f\n" + "mov x10, x11\n" + "45:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 46f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z20.d, z12.d\n" + "mov z17.d, z9.d\n" + "mov z21.d, z13.d\n" + "mov z18.d, z10.d\n" + "mov z22.d, z14.d\n" + "mov z19.d, z11.d\n" + "mov z23.d, z15.d\n" + "b 48f\n" + "46:" // Height 4: no bias + "tbz %x[flags], #0, 47f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" 
+ "ld1w { z9.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "zip1 z8.d, z9.d, z12.d\n" + "zip2 z12.d, z9.d, z12.d\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z13.d\n" + "zip2 z13.d, z10.d, z13.d\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z17.s }, p4/Z, [x23]\n" + "zip1 z10.d, z11.d, z14.d\n" + "zip2 z14.d, z11.d, z14.d\n" + "ld1w { z18.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z19.s }, p2/Z, [x23, #2, MUL VL]\n" + "zip1 z11.d, z16.d, z15.d\n" + "zip2 z15.d, z16.d, z15.d\n" + "ld1w { z24.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "zip1 z16.d, z17.d, z20.d\n" + "zip2 z20.d, z17.d, z20.d\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "zip1 z17.d, z18.d, z21.d\n" + "zip2 z21.d, z18.d, z21.d\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "zip1 z18.d, z19.d, z22.d\n" + "zip2 z22.d, z19.d, z22.d\n" + "zip1 z19.d, z24.d, z23.d\n" + "zip2 z23.d, z24.d, z23.d\n" + "b 48f\n" + "47:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "48:" // Height 4: setup done + "mov x27, #0x0\n" + "49:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 50f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 51f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "b 51f\n" + "50:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "51:" // Height 4: input setup done + "cmp x26, #0x8\n" + "ble 53f\n" + "52:" // Height 4: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn2 z1.d, z1.d, z2.d\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "trn2 z3.d, z3.d, z4.d\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6467e40a // bfmmla z10.s, 
z0.h, z7.h\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "cmp x26, #0x8\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "add x25, x25, #0x10\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + "add x24, x24, #0x10\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "add x23, x23, #0x10\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "add x22, x22, #0x10\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + "addvl x10, x10, #4\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + "bgt 52b\n" + "53:" // Height 4: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn2 z1.d, z1.d, z2.d\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x4\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "trn2 z3.d, z3.d, z4.d\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + "ble 54f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 
0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #2\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + "54:" // Height 4: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 49b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 z7.d, z8.d, z12.d\n" + "uzp2 z8.d, z8.d, z12.d\n" + "uzp1 z12.d, z9.d, z13.d\n" + "add x22, x23, x19, LSL #2\n" + "uzp2 z9.d, z9.d, z13.d\n" + "uzp1 z13.d, z10.d, z14.d\n" + "uzp2 z10.d, z10.d, z14.d\n" + "uzp1 z14.d, z11.d, z15.d\n" + "uzp2 z11.d, z11.d, z15.d\n" + "uzp1 z15.d, z16.d, z20.d\n" + "uzp2 z16.d, z16.d, z20.d\n" + "uzp1 z20.d, z17.d, z21.d\n" + "uzp2 z17.d, z17.d, z21.d\n" + "uzp1 z21.d, z18.d, z22.d\n" + "uzp2 z18.d, z18.d, z22.d\n" + "uzp1 z22.d, z19.d, z23.d\n" + "uzp2 z19.d, z19.d, z23.d\n" + "tbz %x[flags], #1, 55f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z7.s, p5/M, z7.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmax z7.s, p5/M, z7.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "55:" // Height 4: No activation + "st1w { z7.s }, p4, [x12]\n" + "st1w { z12.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z13.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z14.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z8.s }, p4, [x24]\n" + "st1w { z9.s }, p3, [x24, #1, MUL 
VL]\n" + "st1w { z10.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z15.s }, p4, [x23]\n" + "st1w { z20.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z21.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z22.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x22]\n" + "st1w { z17.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #3, MUL VL]\n" + "56:" // Height 4: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 44b\n" + "b 86f\n" + "57:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "58:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 59f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 59f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 59f\n" + "mov x10, x11\n" + "59:" // Height 5: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 60f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z20.d, z12.d\n" + "mov z17.d, z9.d\n" + "mov z21.d, z13.d\n" + "mov z18.d, z10.d\n" + "mov z22.d, z14.d\n" + "mov z19.d, z11.d\n" + "mov z23.d, z15.d\n" + "mov z24.d, z8.d\n" + "mov z28.d, z12.d\n" + "mov z25.d, z9.d\n" + "mov z29.d, z13.d\n" + "mov z26.d, z10.d\n" + "mov z30.d, z14.d\n" + "mov z27.d, z11.d\n" + "mov z31.d, z15.d\n" + "b 62f\n" + "60:" // Height 5: no bias + "tbz %x[flags], #0, 61f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z9.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "zip1 z8.d, z9.d, z12.d\n" + "zip2 z12.d, z9.d, z12.d\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z13.d\n" + "zip2 z13.d, z10.d, z13.d\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z17.s }, p4/Z, [x23]\n" + "zip1 z10.d, z11.d, z14.d\n" + "zip2 z14.d, z11.d, z14.d\n" + "ld1w { z18.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z19.s }, p2/Z, [x23, #2, MUL VL]\n" + "zip1 z11.d, z16.d, z15.d\n" + "zip2 z15.d, z16.d, z15.d\n" + "ld1w { z24.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "zip1 z16.d, z17.d, z20.d\n" + "zip2 z20.d, z17.d, z20.d\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "zip1 z17.d, z18.d, z21.d\n" + "zip2 
z21.d, z18.d, z21.d\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "ld1w { z25.s }, p4/Z, [x21]\n" + "zip1 z18.d, z19.d, z22.d\n" + "zip2 z22.d, z19.d, z22.d\n" + "ld1w { z26.s }, p3/Z, [x21, #1, MUL VL]\n" + "ld1w { z27.s }, p2/Z, [x21, #2, MUL VL]\n" + "zip1 z19.d, z24.d, z23.d\n" + "zip2 z23.d, z24.d, z23.d\n" + "ld1w { z6.s }, p1/Z, [x21, #3, MUL VL]\n" + "zip1 z24.d, z25.d, z28.d\n" + "zip2 z28.d, z25.d, z28.d\n" + "zip1 z25.d, z26.d, z29.d\n" + "zip2 z29.d, z26.d, z29.d\n" + "zip1 z26.d, z27.d, z30.d\n" + "zip2 z30.d, z27.d, z30.d\n" + "zip1 z27.d, z6.d, z31.d\n" + "zip2 z31.d, z6.d, z31.d\n" + "b 62f\n" + "61:" // Height 5: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "62:" // Height 5: setup done + "mov x27, #0x0\n" + "63:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 64f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 65f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "b 65f\n" + "64:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "65:" // Height 5: input setup done + "cmp x26, #0x8\n" + "ble 67f\n" + "66:" // Height 5: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn1 z0.d, z1.d, z2.d\n" + "trn2 z1.d, z1.d, z2.d\n" + "ld1rqh { z5.h }, p0/Z, [x21]\n" + "trn1 z2.d, z3.d, z4.d\n" + "trn2 z3.d, z3.d, z4.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "trn1 z4.d, z5.d, z6.d\n" + "trn2 z5.d, z5.d, z6.d\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6467e498 // bfmmla z24.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "cmp x26, #0x8\n" + "add x25, x25, #0x10\n" + ".inst 0x6466e49c // bfmmla z28.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + "add x24, x24, #0x10\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + ".inst 0x6467e499 // bfmmla z25.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "add x23, x23, #0x10\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + ".inst 0x6466e49d // 
bfmmla z29.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + ".inst 0x6467e49a // bfmmla z26.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6466e49e // bfmmla z30.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + ".inst 0x6467e49b // bfmmla z27.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6466e49f // bfmmla z31.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + "addvl x11, x11, #4\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + ".inst 0x6467e4b8 // bfmmla z24.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + ".inst 0x6466e4bc // bfmmla z28.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + "addvl x10, x10, #4\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + ".inst 0x6467e4b9 // bfmmla z25.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + ".inst 0x6466e4bd // bfmmla z29.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + "addvl x9, x9, #4\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + ".inst 0x6467e4ba // bfmmla z26.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + ".inst 0x6466e4be // bfmmla z30.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + "addvl x28, x28, #4\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6467e4bb // bfmmla z27.s, z5.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + ".inst 0x6466e4bf // bfmmla z31.s, z5.h, z6.h\n" + "bgt 66b\n" + "67:" // Height 5: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn1 z0.d, z1.d, z2.d\n" + "trn2 z1.d, z1.d, z2.d\n" + "ld1rqh { z5.h }, p0/Z, [x21]\n" + "trn1 z2.d, z3.d, z4.d\n" + "trn2 z3.d, z3.d, z4.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "trn1 z4.d, z5.d, z6.d\n" + "trn2 z5.d, z5.d, z6.d\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6467e498 // bfmmla z24.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "subs x26, x26, #0x4\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e49c // bfmmla z28.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + ".inst 0x6467e499 // bfmmla 
z25.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6466e49d // bfmmla z29.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + ".inst 0x6467e49a // bfmmla z26.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6466e49e // bfmmla z30.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + ".inst 0x6467e49b // bfmmla z27.s, z4.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6466e49f // bfmmla z31.s, z4.h, z6.h\n" + "ble 68f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + ".inst 0x6467e4b8 // bfmmla z24.s, z5.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + ".inst 0x6466e4bc // bfmmla z28.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + ".inst 0x6467e4b9 // bfmmla z25.s, z5.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + ".inst 0x6466e4bd // bfmmla z29.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + ".inst 0x6467e4ba // bfmmla z26.s, z5.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + ".inst 0x6466e4be // bfmmla z30.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6467e4bb // bfmmla z27.s, z5.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + ".inst 0x6466e4bf // bfmmla z31.s, z5.h, z6.h\n" + "68:" // Height 5: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 63b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 z7.d, z8.d, z12.d\n" + "add x22, x23, x19, LSL #2\n" + "uzp2 z8.d, z8.d, z12.d\n" + "uzp1 z12.d, z9.d, z13.d\n" + "add x21, x22, x19, LSL #2\n" + "uzp2 z9.d, z9.d, z13.d\n" + "uzp1 z13.d, z10.d, z14.d\n" + "uzp2 z10.d, z10.d, z14.d\n" + "uzp1 z14.d, z11.d, z15.d\n" + "uzp2 z11.d, z11.d, z15.d\n" + "uzp1 z15.d, z16.d, z20.d\n" + "uzp2 z16.d, z16.d, z20.d\n" + "uzp1 z20.d, z17.d, z21.d\n" + "uzp2 z17.d, z17.d, z21.d\n" + "uzp1 z21.d, z18.d, z22.d\n" + "uzp2 z18.d, z18.d, z22.d\n" + "uzp1 z22.d, z19.d, z23.d\n" + "uzp2 z19.d, z19.d, z23.d\n" + "uzp1 z24.d, z24.d, z28.d\n" + "uzp1 z25.d, z25.d, z29.d\n" + "uzp1 z26.d, z26.d, 
z30.d\n" + "uzp1 z27.d, z27.d, z31.d\n" + "tbz %x[flags], #1, 69f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z7.s, p5/M, z7.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmin z24.s, p5/M, z24.s, z1.s\n" + "fmin z25.s, p5/M, z25.s, z1.s\n" + "fmin z26.s, p5/M, z26.s, z1.s\n" + "fmin z27.s, p5/M, z27.s, z1.s\n" + "fmax z7.s, p5/M, z7.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "fmax z24.s, p5/M, z24.s, z0.s\n" + "fmax z25.s, p5/M, z25.s, z0.s\n" + "fmax z26.s, p5/M, z26.s, z0.s\n" + "fmax z27.s, p5/M, z27.s, z0.s\n" + "69:" // Height 5: No activation + "st1w { z7.s }, p4, [x12]\n" + "st1w { z12.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z13.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z14.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z8.s }, p4, [x24]\n" + "st1w { z9.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z15.s }, p4, [x23]\n" + "st1w { z20.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z21.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z22.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x22]\n" + "st1w { z17.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z24.s }, p4, [x21]\n" + "st1w { z25.s }, p3, [x21, #1, MUL VL]\n" + "st1w { z26.s }, p2, [x21, #2, MUL VL]\n" + "st1w { z27.s }, p1, [x21, #3, MUL VL]\n" + "70:" // Height 5: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 58b\n" + "b 86f\n" + "71:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x18\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "72:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 73f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 73f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 73f\n" + "mov x10, x11\n" + "73:" // Height 6: B setup done + 
"mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 74f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "zip2 z12.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "zip2 z13.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "zip2 z14.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "addvl x14, x14, #4\n" + "zip2 z15.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z20.d, z12.d\n" + "mov z17.d, z9.d\n" + "mov z21.d, z13.d\n" + "mov z18.d, z10.d\n" + "mov z22.d, z14.d\n" + "mov z19.d, z11.d\n" + "mov z23.d, z15.d\n" + "mov z24.d, z8.d\n" + "mov z28.d, z12.d\n" + "mov z25.d, z9.d\n" + "mov z29.d, z13.d\n" + "mov z26.d, z10.d\n" + "mov z30.d, z14.d\n" + "mov z27.d, z11.d\n" + "mov z31.d, z15.d\n" + "b 76f\n" + "74:" // Height 6: no bias + "tbz %x[flags], #0, 75f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z9.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z10.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p2/Z, [x12, #2, MUL VL]\n" + "add x20, x21, x19, LSL #2\n" + "ld1w { z16.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "zip1 z8.d, z9.d, z12.d\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "zip2 z12.d, z9.d, z12.d\n" + "zip1 z9.d, z10.d, z13.d\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z17.s }, p4/Z, [x23]\n" + "zip2 z13.d, z10.d, z13.d\n" + "zip1 z10.d, z11.d, z14.d\n" + "ld1w { z18.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z19.s }, p2/Z, [x23, #2, MUL VL]\n" + "zip2 z14.d, z11.d, z14.d\n" + "zip1 z11.d, z16.d, z15.d\n" + "ld1w { z24.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "zip2 z15.d, z16.d, z15.d\n" + "zip1 z16.d, z17.d, z20.d\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "zip2 z20.d, z17.d, z20.d\n" + "zip1 z17.d, z18.d, z21.d\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "ld1w { z25.s }, p4/Z, [x21]\n" + "zip2 z21.d, z18.d, z21.d\n" + "zip1 z18.d, z19.d, z22.d\n" + "ld1w { z26.s }, p3/Z, [x21, #1, MUL VL]\n" + "ld1w { z27.s }, p2/Z, [x21, #2, MUL VL]\n" + "zip2 z22.d, z19.d, z22.d\n" + "zip1 z19.d, z24.d, z23.d\n" + "ld1w { z6.s }, p1/Z, [x21, #3, MUL VL]\n" + "ld1w { z28.s }, p4/Z, [x20]\n" + "zip2 z23.d, z24.d, z23.d\n" + "zip1 z24.d, z25.d, z28.d\n" + "ld1w { z29.s }, p3/Z, [x20, #1, MUL VL]\n" + "ld1w { z30.s }, p2/Z, [x20, #2, MUL VL]\n" + "zip2 z28.d, z25.d, z28.d\n" + "zip1 z25.d, z26.d, z29.d\n" + "ld1w { z31.s }, p1/Z, [x20, #3, MUL VL]\n" + "zip2 z29.d, z26.d, z29.d\n" + "zip1 z26.d, z27.d, z30.d\n" + "zip2 z30.d, z27.d, z30.d\n" + "zip1 z27.d, z6.d, z31.d\n" + "zip2 z31.d, z6.d, z31.d\n" + "b 76f\n" + "75:" // Height 6: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, 
#0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "76:" // Height 6: setup done + "mov x27, #0x0\n" + "77:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 78f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 79f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "add x20, x20, x19, LSL #1\n" + "b 79f\n" + "78:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "79:" // Height 6: input setup done + "cmp x26, #0x8\n" + "ble 81f\n" + "80:" // Height 6: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn2 z1.d, z1.d, z2.d\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1rqh { z5.h }, p0/Z, [x21]\n" + "ld1rqh { z6.h }, p0/Z, [x20]\n" + "trn2 z3.d, z3.d, z4.d\n" + "trn1 z4.d, z5.d, z6.d\n" + "trn2 z5.d, z5.d, z6.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6467e498 // bfmmla z24.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "sub x26, x26, #0x8\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "cmp x26, #0x8\n" + "add x25, x25, #0x10\n" + ".inst 0x6466e49c // bfmmla z28.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + "add x24, x24, #0x10\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + ".inst 0x6467e499 // bfmmla z25.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + "add x23, x23, #0x10\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + ".inst 0x6466e49d // bfmmla z29.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + "add x20, x20, #0x10\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + ".inst 0x6467e49a // bfmmla z26.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6466e49e // bfmmla z30.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + ".inst 0x6467e49b // bfmmla z27.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x11, #2, MUL VL]\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6466e49f // bfmmla z31.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + "addvl x11, x11, #4\n" + ".inst 0x6467e470 // bfmmla 
z16.s, z3.h, z7.h\n" + ".inst 0x6467e4b8 // bfmmla z24.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + ".inst 0x6466e4bc // bfmmla z28.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #3, MUL VL]\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + "addvl x10, x10, #4\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + ".inst 0x6467e4b9 // bfmmla z25.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9, #2, MUL VL]\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + ".inst 0x6466e4bd // bfmmla z29.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + "addvl x9, x9, #4\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + ".inst 0x6467e4ba // bfmmla z26.s, z5.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + ".inst 0x6466e4be // bfmmla z30.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #3, MUL VL]\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + "addvl x28, x28, #4\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6467e4bb // bfmmla z27.s, z5.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + ".inst 0x6466e4bf // bfmmla z31.s, z5.h, z6.h\n" + "bgt 80b\n" + "81:" // Height 6: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z1.h }, p0/Z, [x25]\n" + "ld1rqh { z2.h }, p0/Z, [x24]\n" + "trn1 z0.d, z1.d, z2.d\n" + "ld1rqh { z3.h }, p0/Z, [x23]\n" + "ld1rqh { z4.h }, p0/Z, [x22]\n" + "trn2 z1.d, z1.d, z2.d\n" + "trn1 z2.d, z3.d, z4.d\n" + "ld1rqh { z5.h }, p0/Z, [x21]\n" + "ld1rqh { z6.h }, p0/Z, [x20]\n" + "trn2 z3.d, z3.d, z4.d\n" + "trn1 z4.d, z5.d, z6.d\n" + "trn2 z5.d, z5.d, z6.d\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e408 // bfmmla z8.s, z0.h, z7.h\n" + ".inst 0x6467e450 // bfmmla z16.s, z2.h, z7.h\n" + ".inst 0x6467e498 // bfmmla z24.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "subs x26, x26, #0x4\n" + ".inst 0x6466e40c // bfmmla z12.s, z0.h, z6.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e49c // bfmmla z28.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + ".inst 0x6467e409 // bfmmla z9.s, z0.h, z7.h\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e451 // bfmmla z17.s, z2.h, z7.h\n" + ".inst 0x6467e499 // bfmmla z25.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6466e49d // bfmmla z29.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + ".inst 0x6467e40a // bfmmla z10.s, z0.h, z7.h\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e452 // bfmmla z18.s, z2.h, z7.h\n" + ".inst 0x6467e49a // bfmmla z26.s, z4.h, z7.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e40e // bfmmla z14.s, z0.h, z6.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6466e49e // bfmmla z30.s, z4.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e453 // bfmmla z19.s, z2.h, z7.h\n" + ".inst 0x6467e49b // bfmmla z27.s, z4.h, z7.h\n" + ".inst 0x6466e40f // bfmmla z15.s, z0.h, z6.h\n" + ".inst 0x6466e457 // bfmmla z23.s, 
z2.h, z6.h\n" + ".inst 0x6466e49f // bfmmla z31.s, z4.h, z6.h\n" + "ble 82f\n" + "ld1h { z7.h }, p5/Z, [x11]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + ".inst 0x6467e428 // bfmmla z8.s, z1.h, z7.h\n" + ".inst 0x6467e470 // bfmmla z16.s, z3.h, z7.h\n" + ".inst 0x6467e4b8 // bfmmla z24.s, z5.h, z7.h\n" + ".inst 0x6466e42c // bfmmla z12.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e474 // bfmmla z20.s, z3.h, z6.h\n" + ".inst 0x6466e4bc // bfmmla z28.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x10, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e429 // bfmmla z9.s, z1.h, z7.h\n" + ".inst 0x6467e471 // bfmmla z17.s, z3.h, z7.h\n" + ".inst 0x6467e4b9 // bfmmla z25.s, z5.h, z7.h\n" + ".inst 0x6466e42d // bfmmla z13.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x9]\n" + ".inst 0x6466e475 // bfmmla z21.s, z3.h, z6.h\n" + ".inst 0x6466e4bd // bfmmla z29.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e42a // bfmmla z10.s, z1.h, z7.h\n" + ".inst 0x6467e472 // bfmmla z18.s, z3.h, z7.h\n" + ".inst 0x6467e4ba // bfmmla z26.s, z5.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + ".inst 0x6466e476 // bfmmla z22.s, z3.h, z6.h\n" + ".inst 0x6466e4be // bfmmla z30.s, z5.h, z6.h\n" + "ld1h { z6.h }, p5/Z, [x28, #1, MUL VL]\n" + "addvl x28, x28, #2\n" + ".inst 0x6467e42b // bfmmla z11.s, z1.h, z7.h\n" + ".inst 0x6467e473 // bfmmla z19.s, z3.h, z7.h\n" + ".inst 0x6467e4bb // bfmmla z27.s, z5.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6466e477 // bfmmla z23.s, z3.h, z6.h\n" + ".inst 0x6466e4bf // bfmmla z31.s, z5.h, z6.h\n" + "82:" // Height 6: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 77b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "uzp1 z7.d, z8.d, z12.d\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "uzp2 z8.d, z8.d, z12.d\n" + "uzp1 z12.d, z9.d, z13.d\n" + "uzp2 z9.d, z9.d, z13.d\n" + "uzp1 z13.d, z10.d, z14.d\n" + "add x20, x21, x19, LSL #2\n" + "uzp2 z10.d, z10.d, z14.d\n" + "uzp1 z14.d, z11.d, z15.d\n" + "uzp2 z11.d, z11.d, z15.d\n" + "uzp1 z15.d, z16.d, z20.d\n" + "uzp2 z16.d, z16.d, z20.d\n" + "uzp1 z20.d, z17.d, z21.d\n" + "uzp2 z17.d, z17.d, z21.d\n" + "uzp1 z21.d, z18.d, z22.d\n" + "uzp2 z18.d, z18.d, z22.d\n" + "uzp1 z22.d, z19.d, z23.d\n" + "uzp2 z19.d, z19.d, z23.d\n" + "uzp1 z23.d, z24.d, z28.d\n" + "uzp2 z24.d, z24.d, z28.d\n" + "uzp1 z28.d, z25.d, z29.d\n" + "uzp2 z25.d, z25.d, z29.d\n" + "uzp1 z29.d, z26.d, z30.d\n" + "uzp2 z26.d, z26.d, z30.d\n" + "uzp1 z30.d, z27.d, z31.d\n" + "uzp2 z27.d, z27.d, z31.d\n" + "tbz %x[flags], #1, 83f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z7.s, p5/M, z7.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin 
z19.s, p5/M, z19.s, z1.s\n" + "fmin z23.s, p5/M, z23.s, z1.s\n" + "fmin z28.s, p5/M, z28.s, z1.s\n" + "fmin z29.s, p5/M, z29.s, z1.s\n" + "fmin z30.s, p5/M, z30.s, z1.s\n" + "fmin z24.s, p5/M, z24.s, z1.s\n" + "fmin z25.s, p5/M, z25.s, z1.s\n" + "fmin z26.s, p5/M, z26.s, z1.s\n" + "fmin z27.s, p5/M, z27.s, z1.s\n" + "fmax z7.s, p5/M, z7.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "fmax z23.s, p5/M, z23.s, z0.s\n" + "fmax z28.s, p5/M, z28.s, z0.s\n" + "fmax z29.s, p5/M, z29.s, z0.s\n" + "fmax z30.s, p5/M, z30.s, z0.s\n" + "fmax z24.s, p5/M, z24.s, z0.s\n" + "fmax z25.s, p5/M, z25.s, z0.s\n" + "fmax z26.s, p5/M, z26.s, z0.s\n" + "fmax z27.s, p5/M, z27.s, z0.s\n" + "83:" // Height 6: No activation + "st1w { z7.s }, p4, [x12]\n" + "st1w { z12.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z13.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z14.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z8.s }, p4, [x24]\n" + "st1w { z9.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z15.s }, p4, [x23]\n" + "st1w { z20.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z21.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z22.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x22]\n" + "st1w { z17.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z23.s }, p4, [x21]\n" + "st1w { z28.s }, p3, [x21, #1, MUL VL]\n" + "st1w { z29.s }, p2, [x21, #2, MUL VL]\n" + "st1w { z30.s }, p1, [x21, #3, MUL VL]\n" + "st1w { z24.s }, p4, [x20]\n" + "st1w { z25.s }, p3, [x20, #1, MUL VL]\n" + "st1w { z26.s }, p2, [x20, #2, MUL VL]\n" + "st1w { z27.s }, p1, [x20, #3, MUL VL]\n" + "84:" // Height 6: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 72b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 86f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 85f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "85:" // Update direct input + "mov x19, #0xc\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "86:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "x9", "x10", "x11", "x12", "x13", 
"x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp new file mode 100644 index 0000000000..acbc619eed --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<__fp16>, \ + size_t, size_t, \ + const __fp16 *, \ + size_t, \ + IndirectOutputArg<__fp16>, \ + const __fp16 *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffhybrid_fp16_mla_6x4VL( ARGLIST ); +void sve_ffhybrid_fp16_mla_6x4VL_a64fx( ARGLIST ); + +class cls_sve_ffhybrid_fp16_mla_6x4VL +{ +public: + typedef __fp16 lhs_operand_type; + typedef __fp16 rhs_operand_type; + typedef __fp16 result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return get_vector_length<__fp16>() * 1; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL1VL_BL16; + } + + static unsigned int out_width() + { + return get_vector_length<__fp16>() * 4; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsSVE<rhs_operand_type, result_type, 6, 4, 1> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, __fp16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 31.51 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffhybrid_fp16_mla_6x4VL; + cls_sve_ffhybrid_fp16_mla_6x4VL(const CPUInfo *ci) + { + switch(ci->get_cpu_model()) { + default: + break; + case CPUModel::A64FX: + kernel=sve_ffhybrid_fp16_mla_6x4VL_a64fx; + break; + } + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp new file mode 100644 index 0000000000..181022bf51 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/a64fx.cpp @@ -0,0 +1,1530 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
+ */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void sve_ffhybrid_fp16_mla_6x4VL_a64fx ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<__fp16> A_arg, + size_t M, size_t N, const __fp16 *B_ptr, size_t B_stride, IndirectOutputArg<__fp16> output_arg, + const __fp16 *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + __fp16 maxval = static_cast<__fp16>(std::numeric_limits<float>::infinity()); + __fp16 minval = - static_cast<__fp16>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const __fp16 *B_ptr = {}; + const __fp16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<__fp16>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p4.b\n" + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 66f\n" + "cmp %x[M], #0x4\n" + "bgt 53f\n" + "beq 40f\n" + "cmp %x[M], #0x2\n" + "bgt 27f\n" + "beq 14f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + "cbz x14, 4f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 5f\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "b 6f\n" + "5:" // Height 1: no accumulate +
"mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "6:" // Height 1: setup done + "mov x27, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "add x25, x25, #0x2\n" + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "bne 7b\n" + "tbz %x[flags], #1, 12f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "12:" // Height 1: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "13:" // Height 1: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 80f\n" + "14:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "15:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 16f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 16f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 16f\n" + "mov x10, x11\n" + "16:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + 
"cbz x14, 17f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "addvl x14, x14, #4\n" + "b 19f\n" + "17:" // Height 2: no bias + "tbz %x[flags], #0, 18f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p3/Z, [x24]\n" + "ld1h { z13.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p0/Z, [x24, #3, MUL VL]\n" + "b 19f\n" + "18:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "19:" // Height 2: setup done + "mov x27, #0x0\n" + "20:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 21f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 22f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "b 22f\n" + "21:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "22:" // Height 2: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 24f\n" + "23:" // Height 2: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x10, x10, #1\n" + "add x25, x25, #0x2\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "add x24, x24, #0x2\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "addvl x9, x9, #1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "addvl x28, x28, #1\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 23b\n" + "24:" // Height 2: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "bne 20b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "tbz %x[flags], #1, 25f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin 
z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmin z12.h, p4/M, z12.h, z1.h\n" + "fmin z13.h, p4/M, z13.h, z1.h\n" + "fmin z14.h, p4/M, z14.h, z1.h\n" + "fmin z15.h, p4/M, z15.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "fmax z12.h, p4/M, z12.h, z0.h\n" + "fmax z13.h, p4/M, z13.h, z0.h\n" + "fmax z14.h, p4/M, z14.h, z0.h\n" + "fmax z15.h, p4/M, z15.h, z0.h\n" + "25:" // Height 2: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p3, [x24]\n" + "st1h { z13.h }, p2, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p1, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p0, [x24, #3, MUL VL]\n" + "26:" // Height 2: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 15b\n" + "b 80f\n" + "27:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "28:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 29f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 29f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 29f\n" + "mov x10, x11\n" + "29:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + "cbz x14, 30f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "b 32f\n" + "30:" // Height 3: no bias + "tbz %x[flags], #0, 31f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p3/Z, [x24]\n" + "ld1h { z13.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p0/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p3/Z, [x23]\n" + "ld1h { z17.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x23, #3, MUL VL]\n" + "b 32f\n" + "31:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "32:" // Height 3: setup done + "mov x27, #0x0\n" + "33:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, 
[%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 34f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 35f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "b 35f\n" + "34:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "35:" // Height 3: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 37f\n" + "36:" // Height 3: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "add x25, x25, #0x2\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "add x24, x24, #0x2\n" + "add x23, x23, #0x2\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 36b\n" + "37:" // Height 3: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "add x27, x27, #0x1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "cmp x27, x19\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "addvl x28, x28, #1\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "bne 33b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "tbz %x[flags], #1, 38f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmin z12.h, p4/M, z12.h, z1.h\n" + "fmin z13.h, p4/M, z13.h, z1.h\n" + "fmin z14.h, p4/M, z14.h, z1.h\n" + "fmin z15.h, p4/M, z15.h, z1.h\n" + "fmin z16.h, p4/M, z16.h, z1.h\n" + "fmin z17.h, p4/M, z17.h, z1.h\n" + "fmin z18.h, p4/M, z18.h, z1.h\n" + "fmin z19.h, p4/M, z19.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "fmax z12.h, p4/M, z12.h, z0.h\n" + "fmax z13.h, p4/M, z13.h, z0.h\n" + "fmax z14.h, p4/M, z14.h, z0.h\n" + "fmax z15.h, p4/M, z15.h, z0.h\n" + "fmax z16.h, p4/M, z16.h, z0.h\n" + "fmax z17.h, p4/M, z17.h, z0.h\n" + "fmax 
z18.h, p4/M, z18.h, z0.h\n" + "fmax z19.h, p4/M, z19.h, z0.h\n" + "38:" // Height 3: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p3, [x24]\n" + "st1h { z13.h }, p2, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p1, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p0, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p3, [x23]\n" + "st1h { z17.h }, p2, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p1, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p0, [x23, #3, MUL VL]\n" + "39:" // Height 3: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 28b\n" + "b 80f\n" + "40:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "41:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 42f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 42f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 42f\n" + "mov x10, x11\n" + "42:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + "cbz x14, 43f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "b 45f\n" + "43:" // Height 4: no bias + "tbz %x[flags], #0, 44f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + "add x22, x23, x19, LSL #1\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p3/Z, [x24]\n" + "ld1h { z13.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p0/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p3/Z, [x23]\n" + "ld1h { z17.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p3/Z, [x22]\n" + "ld1h { z21.h }, p2/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p1/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x22, #3, MUL VL]\n" + "b 45f\n" + "44:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "45:" // Height 4: setup done + "mov x27, #0x0\n" + "46:" // Height 4: 
String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 47f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 48f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "b 48f\n" + "47:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "48:" // Height 4: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 50f\n" + "49:" // Height 4: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "add x25, x25, #0x2\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "subs x26, x26, #0x1\n" + "add x24, x24, #0x2\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "add x23, x23, #0x2\n" + "add x22, x22, #0x2\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x9, x9, #1\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "addvl x28, x28, #1\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 49b\n" + "50:" // Height 4: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "add x27, x27, #0x1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "cmp x27, x19\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x28, x28, #1\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "bne 46b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "tbz %x[flags], #1, 51f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmin z12.h, p4/M, z12.h, z1.h\n" + 
"fmin z13.h, p4/M, z13.h, z1.h\n" + "fmin z14.h, p4/M, z14.h, z1.h\n" + "fmin z15.h, p4/M, z15.h, z1.h\n" + "fmin z16.h, p4/M, z16.h, z1.h\n" + "fmin z17.h, p4/M, z17.h, z1.h\n" + "fmin z18.h, p4/M, z18.h, z1.h\n" + "fmin z19.h, p4/M, z19.h, z1.h\n" + "fmin z20.h, p4/M, z20.h, z1.h\n" + "fmin z21.h, p4/M, z21.h, z1.h\n" + "fmin z22.h, p4/M, z22.h, z1.h\n" + "fmin z23.h, p4/M, z23.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "fmax z12.h, p4/M, z12.h, z0.h\n" + "fmax z13.h, p4/M, z13.h, z0.h\n" + "fmax z14.h, p4/M, z14.h, z0.h\n" + "fmax z15.h, p4/M, z15.h, z0.h\n" + "fmax z16.h, p4/M, z16.h, z0.h\n" + "fmax z17.h, p4/M, z17.h, z0.h\n" + "fmax z18.h, p4/M, z18.h, z0.h\n" + "fmax z19.h, p4/M, z19.h, z0.h\n" + "fmax z20.h, p4/M, z20.h, z0.h\n" + "fmax z21.h, p4/M, z21.h, z0.h\n" + "fmax z22.h, p4/M, z22.h, z0.h\n" + "fmax z23.h, p4/M, z23.h, z0.h\n" + "51:" // Height 4: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p3, [x24]\n" + "st1h { z13.h }, p2, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p1, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p0, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p3, [x23]\n" + "st1h { z17.h }, p2, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p1, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p0, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p3, [x22]\n" + "st1h { z21.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p0, [x22, #3, MUL VL]\n" + "52:" // Height 4: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 41b\n" + "b 80f\n" + "53:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "54:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 55f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 55f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 55f\n" + "mov x10, x11\n" + "55:" // Height 5: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + "cbz x14, 56f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "b 58f\n" + "56:" // Height 5: no bias + "tbz %x[flags], #0, 57f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + 
"add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p3/Z, [x24]\n" + "ld1h { z13.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p0/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p3/Z, [x23]\n" + "ld1h { z17.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p3/Z, [x22]\n" + "ld1h { z21.h }, p2/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p1/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x22, #3, MUL VL]\n" + "ld1h { z24.h }, p3/Z, [x21]\n" + "ld1h { z25.h }, p2/Z, [x21, #1, MUL VL]\n" + "ld1h { z26.h }, p1/Z, [x21, #2, MUL VL]\n" + "ld1h { z27.h }, p0/Z, [x21, #3, MUL VL]\n" + "b 58f\n" + "57:" // Height 5: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "58:" // Height 5: setup done + "mov x27, #0x0\n" + "59:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 60f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 61f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "b 61f\n" + "60:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "61:" // Height 5: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "ld1rh { z4.h }, p4/Z, [x21]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 63f\n" + "62:" // Height 5: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "add x25, x25, #0x2\n" + "subs x26, x26, #0x1\n" + "fmla z24.h, p4/M, z6.h, z4.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "add x24, x24, #0x2\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "add x23, x23, #0x2\n" + "add x22, x22, #0x2\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "fmla z25.h, p4/M, z7.h, z4.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "add x21, x21, #0x2\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "fmla z26.h, p4/M, z6.h, z4.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "ld1rh { z0.h }, 
p4/Z, [x25]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "fmla z27.h, p4/M, z7.h, z4.h\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "ld1rh { z4.h }, p4/Z, [x21]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 62b\n" + "63:" // Height 5: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "add x27, x27, #0x1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "cmp x27, x19\n" + "addvl x11, x11, #1\n" + "fmla z24.h, p4/M, z6.h, z4.h\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "addvl x9, x9, #1\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "fmla z25.h, p4/M, z7.h, z4.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "fmla z26.h, p4/M, z6.h, z4.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "fmla z27.h, p4/M, z7.h, z4.h\n" + "bne 59b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "tbz %x[flags], #1, 64f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmin z12.h, p4/M, z12.h, z1.h\n" + "fmin z13.h, p4/M, z13.h, z1.h\n" + "fmin z14.h, p4/M, z14.h, z1.h\n" + "fmin z15.h, p4/M, z15.h, z1.h\n" + "fmin z16.h, p4/M, z16.h, z1.h\n" + "fmin z17.h, p4/M, z17.h, z1.h\n" + "fmin z18.h, p4/M, z18.h, z1.h\n" + "fmin z19.h, p4/M, z19.h, z1.h\n" + "fmin z20.h, p4/M, z20.h, z1.h\n" + "fmin z21.h, p4/M, z21.h, z1.h\n" + "fmin z22.h, p4/M, z22.h, z1.h\n" + "fmin z23.h, p4/M, z23.h, z1.h\n" + "fmin z24.h, p4/M, z24.h, z1.h\n" + "fmin z25.h, p4/M, z25.h, z1.h\n" + "fmin z26.h, p4/M, z26.h, z1.h\n" + "fmin z27.h, p4/M, z27.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "fmax z12.h, p4/M, z12.h, z0.h\n" + "fmax z13.h, p4/M, z13.h, z0.h\n" + "fmax z14.h, p4/M, z14.h, z0.h\n" + "fmax z15.h, p4/M, z15.h, z0.h\n" + "fmax z16.h, p4/M, z16.h, z0.h\n" + "fmax z17.h, p4/M, z17.h, z0.h\n" + "fmax z18.h, p4/M, z18.h, z0.h\n" + "fmax z19.h, p4/M, z19.h, z0.h\n" + "fmax z20.h, p4/M, z20.h, z0.h\n" + "fmax z21.h, p4/M, z21.h, z0.h\n" + "fmax z22.h, p4/M, z22.h, z0.h\n" + "fmax z23.h, p4/M, z23.h, z0.h\n" + "fmax z24.h, p4/M, z24.h, z0.h\n" + "fmax z25.h, p4/M, z25.h, z0.h\n" + "fmax z26.h, p4/M, z26.h, z0.h\n" + "fmax z27.h, p4/M, z27.h, z0.h\n" + "64:" // Height 5: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p3, [x24]\n" + "st1h { z13.h }, p2, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p1, [x24, #2, MUL VL]\n" + "st1h { z15.h 
}, p0, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p3, [x23]\n" + "st1h { z17.h }, p2, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p1, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p0, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p3, [x22]\n" + "st1h { z21.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p0, [x22, #3, MUL VL]\n" + "st1h { z24.h }, p3, [x21]\n" + "st1h { z25.h }, p2, [x21, #1, MUL VL]\n" + "st1h { z26.h }, p1, [x21, #2, MUL VL]\n" + "st1h { z27.h }, p0, [x21, #3, MUL VL]\n" + "65:" // Height 5: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 54b\n" + "b 80f\n" + "66:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0xc\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "67:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 68f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 68f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 68f\n" + "mov x10, x11\n" + "68:" // Height 6: B setup done + "mov x19, #0x0\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "inch x19\n" + "whilelt p0.h, x19, x13\n" + "cbz x14, 69f\n" + "ld1h { z8.h }, p4/Z, [x14]\n" + "ld1h { z9.h }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p4/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "mov z28.d, z8.d\n" + "mov z29.d, z9.d\n" + "mov z30.d, z10.d\n" + "mov z31.d, z11.d\n" + "b 71f\n" + "69:" // Height 6: no bias + "tbz %x[flags], #0, 70f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p3/Z, [x12]\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "ld1h { z9.h }, p2/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p1/Z, [x12, #2, MUL VL]\n" + "add x20, x21, x19, LSL #1\n" + "ld1h { z11.h }, p0/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p3/Z, [x24]\n" + "ld1h { z13.h }, p2/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p1/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p0/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p3/Z, [x23]\n" + "ld1h { z17.h }, p2/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p1/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p0/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p3/Z, [x22]\n" + "ld1h { z21.h }, p2/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p1/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p0/Z, [x22, #3, MUL VL]\n" + "ld1h { z24.h }, p3/Z, [x21]\n" + "ld1h { z25.h }, p2/Z, [x21, #1, MUL VL]\n" + "ld1h { z26.h }, p1/Z, [x21, #2, MUL VL]\n" + "ld1h { z27.h }, p0/Z, [x21, #3, MUL VL]\n" + "ld1h { z28.h }, 
p3/Z, [x20]\n" + "ld1h { z29.h }, p2/Z, [x20, #1, MUL VL]\n" + "ld1h { z30.h }, p1/Z, [x20, #2, MUL VL]\n" + "ld1h { z31.h }, p0/Z, [x20, #3, MUL VL]\n" + "b 71f\n" + "70:" // Height 6: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "71:" // Height 6: setup done + "mov x27, #0x0\n" + "72:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 73f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 74f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "add x20, x20, x19, LSL #1\n" + "b 74f\n" + "73:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "74:" // Height 6: input setup done + "subs x26, x26, #0x1\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "ld1rh { z4.h }, p4/Z, [x21]\n" + "ld1rh { z5.h }, p4/Z, [x20]\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "ble 76f\n" + "75:" // Height 6: Multiply loop: Main loop + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "add x25, x25, #0x2\n" + "subs x26, x26, #0x1\n" + "fmla z24.h, p4/M, z6.h, z4.h\n" + "fmla z28.h, p4/M, z6.h, z5.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "add x24, x24, #0x2\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "add x23, x23, #0x2\n" + "add x22, x22, #0x2\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "add x21, x21, #0x2\n" + "add x20, x20, #0x2\n" + "fmla z25.h, p4/M, z7.h, z4.h\n" + "fmla z29.h, p4/M, z7.h, z5.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "addvl x28, x28, #1\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "fmla z26.h, p4/M, z6.h, z4.h\n" + "fmla z30.h, p4/M, z6.h, z5.h\n" + "ld1h { z6.h }, p4/Z, [x11]\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "ld1rh { z0.h }, p4/Z, [x25]\n" + "ld1rh { z1.h }, p4/Z, [x24]\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "ld1rh { z2.h }, p4/Z, [x23]\n" + "ld1rh { z3.h }, p4/Z, [x22]\n" + "fmla z27.h, p4/M, z7.h, z4.h\n" + "fmla z31.h, p4/M, z7.h, z5.h\n" + "ld1rh { z4.h }, p4/Z, [x21]\n" + "ld1rh { z5.h }, p4/Z, 
[x20]\n" + "ld1h { z7.h }, p4/Z, [x10]\n" + "bgt 75b\n" + "76:" // Height 6: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.h, p4/M, z6.h, z0.h\n" + "fmla z12.h, p4/M, z6.h, z1.h\n" + "add x27, x27, #0x1\n" + "fmla z16.h, p4/M, z6.h, z2.h\n" + "fmla z20.h, p4/M, z6.h, z3.h\n" + "cmp x27, x19\n" + "addvl x11, x11, #1\n" + "fmla z24.h, p4/M, z6.h, z4.h\n" + "fmla z28.h, p4/M, z6.h, z5.h\n" + "ld1h { z6.h }, p4/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, p4/M, z7.h, z0.h\n" + "fmla z13.h, p4/M, z7.h, z1.h\n" + "addvl x9, x9, #1\n" + "fmla z17.h, p4/M, z7.h, z2.h\n" + "fmla z21.h, p4/M, z7.h, z3.h\n" + "fmla z25.h, p4/M, z7.h, z4.h\n" + "fmla z29.h, p4/M, z7.h, z5.h\n" + "ld1h { z7.h }, p4/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, p4/M, z6.h, z0.h\n" + "fmla z14.h, p4/M, z6.h, z1.h\n" + "fmla z18.h, p4/M, z6.h, z2.h\n" + "fmla z22.h, p4/M, z6.h, z3.h\n" + "fmla z26.h, p4/M, z6.h, z4.h\n" + "fmla z30.h, p4/M, z6.h, z5.h\n" + "fmla z11.h, p4/M, z7.h, z0.h\n" + "fmla z15.h, p4/M, z7.h, z1.h\n" + "fmla z19.h, p4/M, z7.h, z2.h\n" + "fmla z23.h, p4/M, z7.h, z3.h\n" + "fmla z27.h, p4/M, z7.h, z4.h\n" + "fmla z31.h, p4/M, z7.h, z5.h\n" + "bne 72b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "tbz %x[flags], #1, 77f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p4/Z, [x19]\n" + "fmin z8.h, p4/M, z8.h, z1.h\n" + "fmin z9.h, p4/M, z9.h, z1.h\n" + "fmin z10.h, p4/M, z10.h, z1.h\n" + "fmin z11.h, p4/M, z11.h, z1.h\n" + "fmin z12.h, p4/M, z12.h, z1.h\n" + "fmin z13.h, p4/M, z13.h, z1.h\n" + "fmin z14.h, p4/M, z14.h, z1.h\n" + "fmin z15.h, p4/M, z15.h, z1.h\n" + "fmin z16.h, p4/M, z16.h, z1.h\n" + "fmin z17.h, p4/M, z17.h, z1.h\n" + "fmin z18.h, p4/M, z18.h, z1.h\n" + "fmin z19.h, p4/M, z19.h, z1.h\n" + "fmin z20.h, p4/M, z20.h, z1.h\n" + "fmin z21.h, p4/M, z21.h, z1.h\n" + "fmin z22.h, p4/M, z22.h, z1.h\n" + "fmin z23.h, p4/M, z23.h, z1.h\n" + "fmin z24.h, p4/M, z24.h, z1.h\n" + "fmin z25.h, p4/M, z25.h, z1.h\n" + "fmin z26.h, p4/M, z26.h, z1.h\n" + "fmin z27.h, p4/M, z27.h, z1.h\n" + "fmin z28.h, p4/M, z28.h, z1.h\n" + "fmin z29.h, p4/M, z29.h, z1.h\n" + "fmin z30.h, p4/M, z30.h, z1.h\n" + "fmin z31.h, p4/M, z31.h, z1.h\n" + "fmax z8.h, p4/M, z8.h, z0.h\n" + "fmax z9.h, p4/M, z9.h, z0.h\n" + "fmax z10.h, p4/M, z10.h, z0.h\n" + "fmax z11.h, p4/M, z11.h, z0.h\n" + "fmax z12.h, p4/M, z12.h, z0.h\n" + "fmax z13.h, p4/M, z13.h, z0.h\n" + "fmax z14.h, p4/M, z14.h, z0.h\n" + "fmax z15.h, p4/M, z15.h, z0.h\n" + "fmax z16.h, p4/M, z16.h, z0.h\n" + "fmax z17.h, p4/M, z17.h, z0.h\n" + "fmax z18.h, p4/M, z18.h, z0.h\n" + "fmax z19.h, p4/M, z19.h, z0.h\n" + "fmax z20.h, p4/M, z20.h, z0.h\n" + "fmax z21.h, p4/M, z21.h, z0.h\n" + "fmax z22.h, p4/M, z22.h, z0.h\n" + "fmax z23.h, p4/M, z23.h, z0.h\n" + "fmax z24.h, p4/M, z24.h, z0.h\n" + "fmax z25.h, p4/M, z25.h, z0.h\n" + "fmax z26.h, p4/M, z26.h, z0.h\n" + "fmax z27.h, p4/M, z27.h, z0.h\n" + "fmax z28.h, p4/M, z28.h, z0.h\n" + "fmax z29.h, p4/M, z29.h, z0.h\n" + "fmax z30.h, p4/M, z30.h, z0.h\n" + "fmax z31.h, p4/M, z31.h, z0.h\n" + "77:" // Height 6: No activation + "st1h { z8.h }, p3, [x12]\n" + "st1h { z9.h }, p2, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p1, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p0, [x12, #3, MUL VL]\n" + "addvl x12, 
x12, #4\n" + "st1h { z12.h }, p3, [x24]\n" + "st1h { z13.h }, p2, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p1, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p0, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p3, [x23]\n" + "st1h { z17.h }, p2, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p1, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p0, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p3, [x22]\n" + "st1h { z21.h }, p2, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p1, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p0, [x22, #3, MUL VL]\n" + "st1h { z24.h }, p3, [x21]\n" + "st1h { z25.h }, p2, [x21, #1, MUL VL]\n" + "st1h { z26.h }, p1, [x21, #2, MUL VL]\n" + "st1h { z27.h }, p0, [x21, #3, MUL VL]\n" + "st1h { z28.h }, p3, [x20]\n" + "st1h { z29.h }, p2, [x20, #1, MUL VL]\n" + "st1h { z30.h }, p1, [x20, #2, MUL VL]\n" + "st1h { z31.h }, p0, [x20, #3, MUL VL]\n" + "78:" // Height 6: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 67b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 80f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 79f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "79:" // Update direct input + "mov x19, #0xc\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "80:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp new file mode 100644 index 0000000000..0f995812d8 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp16_mla_6x4VL/generic.cpp @@ -0,0 +1,3318 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void sve_ffhybrid_fp16_mla_6x4VL ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<__fp16> A_arg, + size_t M, size_t N, const __fp16 *B_ptr, size_t B_stride, IndirectOutputArg<__fp16> output_arg, + const __fp16 *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + __fp16 maxval = static_cast<__fp16>(std::numeric_limits<float>::infinity()); + __fp16 minval = - static_cast<__fp16>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const __fp16 *B_ptr = {}; + const __fp16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<__fp16>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p5.b\n" + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 71f\n" + "cmp %x[M], #0x4\n" + "bgt 57f\n" + "beq 43f\n" + "cmp %x[M], #0x2\n" + "bgt 29f\n" + "beq 15f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz 
x14, 4f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 5f\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "b 6f\n" + "5:" // Height 1: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "6:" // Height 1: setup done + "mov x27, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "cmp x26, #0x8\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #4, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "fmla z9.h, z7.h, 
z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "sub x26, x26, #0x8\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "cmp x26, #0x8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "add x25, x25, #0x10\n" + "addvl x11, x11, #8\n" + "addvl x10, x10, #8\n" + "addvl x9, x9, #8\n" + "addvl x28, x28, #8\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, 
x9, #1\n" + "addvl x28, x28, #1\n" + "12:" // Height 1: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 7b\n" + "tbz %x[flags], #1, 13f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "13:" // Height 1: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "14:" // Height 1: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 86f\n" + "15:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "16:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 17f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 17f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 17f\n" + "mov x10, x11\n" + "17:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz x14, 18f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "addvl x14, x14, #4\n" + "b 20f\n" + "18:" // Height 2: no bias + "tbz %x[flags], #0, 19f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p4/Z, [x24]\n" + "ld1h { z13.h }, p3/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p1/Z, [x24, #3, MUL VL]\n" + "b 20f\n" + "19:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "20:" // Height 2: setup done + "mov x27, #0x0\n" + "21:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 22f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 23f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "b 23f\n" + "22:" // Height 
2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "23:" // Height 2: input setup done + "cmp x26, #0x8\n" + "ble 25f\n" + "24:" // Height 2: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "sub x26, x26, #0x8\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "cmp x26, #0x8\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "add x25, x25, #0x10\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "add x24, x24, #0x10\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "ld1h { z7.h }, p5/Z, [x10, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "addvl x11, x11, #8\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "addvl x10, x10, #8\n" + "fmla z8.h, z6.h, 
z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "addvl x9, x9, #8\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "bgt 24b\n" + "25:" // Height 2: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, 
z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "addvl x11, x11, #1\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "26:" // Height 2: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 21b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "tbz %x[flags], #1, 27f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmin z12.h, p5/M, z12.h, z1.h\n" + "fmin z13.h, p5/M, z13.h, z1.h\n" + "fmin z14.h, p5/M, z14.h, z1.h\n" + "fmin z15.h, p5/M, z15.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "fmax z12.h, p5/M, z12.h, z0.h\n" + "fmax z13.h, p5/M, z13.h, z0.h\n" + "fmax z14.h, p5/M, z14.h, z0.h\n" + "fmax z15.h, p5/M, z15.h, z0.h\n" + "27:" // Height 2: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p4, [x24]\n" + "st1h { z13.h }, p3, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p2, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p1, [x24, #3, MUL VL]\n" + "28:" // Height 2: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 16b\n" + "b 86f\n" + "29:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "30:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 31f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 31f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 31f\n" + "mov x10, x11\n" + "31:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, 
x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz x14, 32f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "b 34f\n" + "32:" // Height 3: no bias + "tbz %x[flags], #0, 33f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p4/Z, [x24]\n" + "ld1h { z13.h }, p3/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p1/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x23]\n" + "ld1h { z17.h }, p3/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p1/Z, [x23, #3, MUL VL]\n" + "b 34f\n" + "33:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "34:" // Height 3: setup done + "mov x27, #0x0\n" + "35:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 36f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 37f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "b 37f\n" + "36:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "37:" // Height 3: input setup done + "cmp x26, #0x8\n" + "ble 39f\n" + "38:" // Height 3: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "sub x26, x26, #0x8\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "cmp x26, #0x8\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "add x23, x23, #0x10\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, 
z2.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "ld1h { z7.h }, p5/Z, [x10, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "addvl x11, x11, #8\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "addvl x10, x10, #8\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "addvl x9, x9, #8\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, 
z7.h, z2.h[7]\n" + "bgt 38b\n" + "39:" // Height 3: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla 
z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "addvl x28, x28, #1\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "ble 40f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x10, x10, #1\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "40:" // Height 3: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 35b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "tbz %x[flags], #1, 41f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmin z12.h, p5/M, z12.h, z1.h\n" + "fmin z13.h, p5/M, z13.h, z1.h\n" + "fmin z14.h, p5/M, z14.h, z1.h\n" + "fmin z15.h, p5/M, z15.h, z1.h\n" + "fmin z16.h, p5/M, z16.h, z1.h\n" + "fmin z17.h, p5/M, z17.h, z1.h\n" + "fmin z18.h, p5/M, z18.h, z1.h\n" + "fmin z19.h, p5/M, z19.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "fmax z12.h, p5/M, z12.h, z0.h\n" + "fmax z13.h, p5/M, z13.h, z0.h\n" + "fmax z14.h, p5/M, z14.h, z0.h\n" + "fmax z15.h, p5/M, z15.h, z0.h\n" + "fmax z16.h, p5/M, z16.h, z0.h\n" + "fmax z17.h, p5/M, z17.h, z0.h\n" + "fmax z18.h, p5/M, z18.h, z0.h\n" + "fmax z19.h, p5/M, z19.h, z0.h\n" + "41:" // Height 3: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p4, [x24]\n" + "st1h { z13.h }, p3, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p2, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p1, [x24, #3, MUL VL]\n" + 
"st1h { z16.h }, p4, [x23]\n" + "st1h { z17.h }, p3, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p1, [x23, #3, MUL VL]\n" + "42:" // Height 3: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 30b\n" + "b 86f\n" + "43:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "44:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 45f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 45f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 45f\n" + "mov x10, x11\n" + "45:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz x14, 46f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "b 48f\n" + "46:" // Height 4: no bias + "tbz %x[flags], #0, 47f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #1\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p4/Z, [x24]\n" + "ld1h { z13.h }, p3/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p1/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x23]\n" + "ld1h { z17.h }, p3/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p1/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p4/Z, [x22]\n" + "ld1h { z21.h }, p3/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p2/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p1/Z, [x22, #3, MUL VL]\n" + "b 48f\n" + "47:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "48:" // Height 4: setup done + "mov x27, #0x0\n" + "49:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 50f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 51f\n" + "ldr x19, [%x[args_ptr], 
%[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "b 51f\n" + "50:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "51:" // Height 4: input setup done + "cmp x26, #0x8\n" + "ble 53f\n" + "52:" // Height 4: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "sub x26, x26, #0x8\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "cmp x26, #0x8\n" + "add x25, x25, #0x10\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "add x24, x24, #0x10\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "fmla z23.h, z7.h, z3.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z20.h, z6.h, z3.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "ld1h { z7.h }, 
p5/Z, [x10, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, z6.h, z3.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z20.h, z6.h, z3.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "addvl x11, x11, #8\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "addvl x10, x10, #8\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "addvl x9, x9, #8\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "bgt 52b\n" + "53:" // Height 4: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla 
z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "fmla z23.h, z7.h, z3.h[1]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z20.h, z6.h, z3.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, 
z6.h, z3.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z20.h, z6.h, z3.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "addvl x28, x28, #1\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "ble 54f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "54:" // Height 4: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 49b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "tbz %x[flags], #1, 55f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmin z12.h, p5/M, z12.h, z1.h\n" + "fmin z13.h, p5/M, z13.h, z1.h\n" + "fmin z14.h, p5/M, z14.h, z1.h\n" + "fmin z15.h, p5/M, z15.h, z1.h\n" + "fmin z16.h, p5/M, z16.h, z1.h\n" + "fmin z17.h, p5/M, z17.h, z1.h\n" + "fmin z18.h, p5/M, z18.h, z1.h\n" + "fmin z19.h, p5/M, z19.h, z1.h\n" + "fmin z20.h, p5/M, z20.h, z1.h\n" + "fmin z21.h, p5/M, z21.h, z1.h\n" + "fmin z22.h, p5/M, z22.h, z1.h\n" + "fmin z23.h, p5/M, z23.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "fmax z12.h, p5/M, z12.h, z0.h\n" + "fmax z13.h, p5/M, z13.h, z0.h\n" + "fmax z14.h, p5/M, z14.h, z0.h\n" + "fmax z15.h, p5/M, z15.h, z0.h\n" + "fmax z16.h, p5/M, z16.h, z0.h\n" + "fmax z17.h, p5/M, z17.h, z0.h\n" + 
"fmax z18.h, p5/M, z18.h, z0.h\n" + "fmax z19.h, p5/M, z19.h, z0.h\n" + "fmax z20.h, p5/M, z20.h, z0.h\n" + "fmax z21.h, p5/M, z21.h, z0.h\n" + "fmax z22.h, p5/M, z22.h, z0.h\n" + "fmax z23.h, p5/M, z23.h, z0.h\n" + "55:" // Height 4: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p4, [x24]\n" + "st1h { z13.h }, p3, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p2, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p1, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p4, [x23]\n" + "st1h { z17.h }, p3, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p1, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p4, [x22]\n" + "st1h { z21.h }, p3, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p1, [x22, #3, MUL VL]\n" + "56:" // Height 4: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 44b\n" + "b 86f\n" + "57:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "58:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 59f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 59f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 59f\n" + "mov x10, x11\n" + "59:" // Height 5: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz x14, 60f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "b 62f\n" + "60:" // Height 5: no bias + "tbz %x[flags], #0, 61f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "ld1h { z12.h }, p4/Z, [x24]\n" + "ld1h { z13.h }, p3/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p1/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x23]\n" + "ld1h { z17.h }, p3/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p1/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p4/Z, [x22]\n" + "ld1h { z21.h }, p3/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p2/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p1/Z, [x22, #3, MUL VL]\n" + "ld1h { z24.h }, p4/Z, [x21]\n" + 
"ld1h { z25.h }, p3/Z, [x21, #1, MUL VL]\n" + "ld1h { z26.h }, p2/Z, [x21, #2, MUL VL]\n" + "ld1h { z27.h }, p1/Z, [x21, #3, MUL VL]\n" + "b 62f\n" + "61:" // Height 5: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "62:" // Height 5: setup done + "mov x27, #0x0\n" + "63:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 64f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 65f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "b 65f\n" + "64:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "65:" // Height 5: input setup done + "cmp x26, #0x8\n" + "ble 67f\n" + "66:" // Height 5: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "sub x26, x26, #0x8\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "cmp x26, #0x8\n" + "add x25, x25, #0x10\n" + "ld1rqh { z4.h }, p0/Z, [x21]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "add x24, x24, #0x10\n" + "fmla z24.h, z6.h, z4.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "add x23, x23, #0x10\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "fmla z25.h, z7.h, z4.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "fmla z26.h, z6.h, z4.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "fmla z27.h, z7.h, z4.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "fmla z24.h, z6.h, z4.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "fmla z25.h, z7.h, z4.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "fmla z26.h, z6.h, z4.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL 
VL]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "fmla z23.h, z7.h, z3.h[1]\n" + "fmla z27.h, z7.h, z4.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "fmla z24.h, z6.h, z4.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "fmla z25.h, z7.h, z4.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "fmla z26.h, z6.h, z4.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "fmla z27.h, z7.h, z4.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z20.h, z6.h, z3.h[3]\n" + "fmla z24.h, z6.h, z4.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "fmla z25.h, z7.h, z4.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "fmla z26.h, z6.h, z4.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "fmla z27.h, z7.h, z4.h[3]\n" + "ld1h { z7.h }, p5/Z, [x10, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "fmla z24.h, z6.h, z4.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "fmla z25.h, z7.h, z4.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "fmla z26.h, z6.h, z4.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "fmla z27.h, z7.h, z4.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, z6.h, z3.h[5]\n" + "fmla z24.h, z6.h, z4.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "fmla z25.h, z7.h, z4.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "fmla z26.h, z6.h, z4.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "fmla z27.h, z7.h, z4.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla 
z20.h, z6.h, z3.h[6]\n" + "fmla z24.h, z6.h, z4.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "fmla z25.h, z7.h, z4.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "fmla z26.h, z6.h, z4.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "addvl x11, x11, #8\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "fmla z27.h, z7.h, z4.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "addvl x10, x10, #8\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "fmla z24.h, z6.h, z4.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "addvl x9, x9, #8\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "fmla z25.h, z7.h, z4.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z26.h, z6.h, z4.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "fmla z27.h, z7.h, z4.h[7]\n" + "bgt 66b\n" + "67:" // Height 5: Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "ld1rqh { z4.h }, p0/Z, [x21]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[0]\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "fmla z25.h, z7.h, z4.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "fmla z26.h, z6.h, z4.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "fmla z27.h, z7.h, z4.h[0]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[1]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "fmla z25.h, z7.h, z4.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "fmla z26.h, z6.h, z4.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + 
"fmla z23.h, z7.h, z3.h[1]\n" + "fmla z27.h, z7.h, z4.h[1]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[2]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "fmla z25.h, z7.h, z4.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "fmla z26.h, z6.h, z4.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "fmla z27.h, z7.h, z4.h[2]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z20.h, z6.h, z3.h[3]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[3]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "fmla z25.h, z7.h, z4.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "fmla z26.h, z6.h, z4.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "fmla z27.h, z7.h, z4.h[3]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[4]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "fmla z25.h, z7.h, z4.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "fmla z26.h, z6.h, z4.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "fmla z27.h, z7.h, z4.h[4]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, z6.h, z3.h[5]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[5]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "fmla z25.h, z7.h, z4.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "fmla z26.h, z6.h, z4.h[5]\n" + "fmla z11.h, z7.h, 
z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "fmla z27.h, z7.h, z4.h[5]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z20.h, z6.h, z3.h[6]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[6]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "addvl x9, x9, #1\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "fmla z25.h, z7.h, z4.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "fmla z26.h, z6.h, z4.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "fmla z27.h, z7.h, z4.h[6]\n" + "ble 68f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.h, z6.h, z4.h[7]\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "fmla z25.h, z7.h, z4.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z26.h, z6.h, z4.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "fmla z27.h, z7.h, z4.h[7]\n" + "68:" // Height 5: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 63b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "tbz %x[flags], #1, 69f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmin z12.h, p5/M, z12.h, z1.h\n" + "fmin z13.h, p5/M, z13.h, z1.h\n" + "fmin z14.h, p5/M, z14.h, z1.h\n" + "fmin z15.h, p5/M, z15.h, z1.h\n" + "fmin z16.h, p5/M, z16.h, z1.h\n" + "fmin z17.h, p5/M, z17.h, z1.h\n" + "fmin z18.h, p5/M, z18.h, z1.h\n" + "fmin z19.h, p5/M, z19.h, z1.h\n" + "fmin z20.h, p5/M, z20.h, z1.h\n" + "fmin z21.h, p5/M, z21.h, z1.h\n" + "fmin z22.h, p5/M, z22.h, z1.h\n" + "fmin z23.h, p5/M, z23.h, z1.h\n" + "fmin z24.h, p5/M, z24.h, z1.h\n" + "fmin z25.h, p5/M, z25.h, z1.h\n" + "fmin z26.h, p5/M, z26.h, z1.h\n" + "fmin z27.h, p5/M, z27.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "fmax z12.h, p5/M, z12.h, z0.h\n" + "fmax z13.h, p5/M, z13.h, z0.h\n" + "fmax z14.h, p5/M, z14.h, z0.h\n" + "fmax z15.h, p5/M, z15.h, z0.h\n" + "fmax z16.h, p5/M, z16.h, z0.h\n" + "fmax z17.h, p5/M, z17.h, z0.h\n" + "fmax z18.h, p5/M, 
z18.h, z0.h\n" + "fmax z19.h, p5/M, z19.h, z0.h\n" + "fmax z20.h, p5/M, z20.h, z0.h\n" + "fmax z21.h, p5/M, z21.h, z0.h\n" + "fmax z22.h, p5/M, z22.h, z0.h\n" + "fmax z23.h, p5/M, z23.h, z0.h\n" + "fmax z24.h, p5/M, z24.h, z0.h\n" + "fmax z25.h, p5/M, z25.h, z0.h\n" + "fmax z26.h, p5/M, z26.h, z0.h\n" + "fmax z27.h, p5/M, z27.h, z0.h\n" + "69:" // Height 5: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p4, [x24]\n" + "st1h { z13.h }, p3, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p2, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p1, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p4, [x23]\n" + "st1h { z17.h }, p3, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p1, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p4, [x22]\n" + "st1h { z21.h }, p3, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p1, [x22, #3, MUL VL]\n" + "st1h { z24.h }, p4, [x21]\n" + "st1h { z25.h }, p3, [x21, #1, MUL VL]\n" + "st1h { z26.h }, p2, [x21, #2, MUL VL]\n" + "st1h { z27.h }, p1, [x21, #3, MUL VL]\n" + "70:" // Height 5: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 58b\n" + "b 86f\n" + "71:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0xc\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "72:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "cnth x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "add x19, x28, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 73f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 73f\n" + "dech x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 73f\n" + "mov x10, x11\n" + "73:" // Height 6: B setup done + "mov x19, #0x0\n" + "whilelt p4.h, x19, x13\n" + "inch x19\n" + "whilelt p3.h, x19, x13\n" + "inch x19\n" + "whilelt p2.h, x19, x13\n" + "inch x19\n" + "whilelt p1.h, x19, x13\n" + "cbz x14, 74f\n" + "ld1h { z8.h }, p5/Z, [x14]\n" + "ld1h { z9.h }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1h { z10.h }, p5/Z, [x14, #2, MUL VL]\n" + "ld1h { z11.h }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "mov z28.d, z8.d\n" + "mov z29.d, z9.d\n" + "mov z30.d, z10.d\n" + "mov z31.d, z11.d\n" + "b 76f\n" + "74:" // Height 6: no bias + "tbz %x[flags], #0, 75f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "ld1h { z8.h }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "ld1h { z9.h }, p3/Z, [x12, #1, MUL VL]\n" + "ld1h { z10.h }, p2/Z, [x12, #2, MUL VL]\n" + "add x20, x21, x19, LSL #1\n" + "ld1h { z11.h }, p1/Z, [x12, #3, MUL VL]\n" + "ld1h { 
z12.h }, p4/Z, [x24]\n" + "ld1h { z13.h }, p3/Z, [x24, #1, MUL VL]\n" + "ld1h { z14.h }, p2/Z, [x24, #2, MUL VL]\n" + "ld1h { z15.h }, p1/Z, [x24, #3, MUL VL]\n" + "ld1h { z16.h }, p4/Z, [x23]\n" + "ld1h { z17.h }, p3/Z, [x23, #1, MUL VL]\n" + "ld1h { z18.h }, p2/Z, [x23, #2, MUL VL]\n" + "ld1h { z19.h }, p1/Z, [x23, #3, MUL VL]\n" + "ld1h { z20.h }, p4/Z, [x22]\n" + "ld1h { z21.h }, p3/Z, [x22, #1, MUL VL]\n" + "ld1h { z22.h }, p2/Z, [x22, #2, MUL VL]\n" + "ld1h { z23.h }, p1/Z, [x22, #3, MUL VL]\n" + "ld1h { z24.h }, p4/Z, [x21]\n" + "ld1h { z25.h }, p3/Z, [x21, #1, MUL VL]\n" + "ld1h { z26.h }, p2/Z, [x21, #2, MUL VL]\n" + "ld1h { z27.h }, p1/Z, [x21, #3, MUL VL]\n" + "ld1h { z28.h }, p4/Z, [x20]\n" + "ld1h { z29.h }, p3/Z, [x20, #1, MUL VL]\n" + "ld1h { z30.h }, p2/Z, [x20, #2, MUL VL]\n" + "ld1h { z31.h }, p1/Z, [x20, #3, MUL VL]\n" + "b 76f\n" + "75:" // Height 6: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "76:" // Height 6: setup done + "mov x27, #0x0\n" + "77:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 78f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 79f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #1\n" + "add x24, x24, x19, LSL #1\n" + "add x23, x23, x19, LSL #1\n" + "add x22, x22, x19, LSL #1\n" + "add x21, x21, x19, LSL #1\n" + "add x20, x20, x19, LSL #1\n" + "b 79f\n" + "78:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "79:" // Height 6: input setup done + "cmp x26, #0x8\n" + "ble 81f\n" + "80:" // Height 6: Multiply loop: Main loop head + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "sub x26, x26, #0x8\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "cmp x26, #0x8\n" + "add x25, x25, #0x10\n" + "ld1rqh { z4.h }, p0/Z, [x21]\n" + "ld1rqh { z5.h }, p0/Z, [x20]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla z24.h, z6.h, z4.h[0]\n" + "fmla z28.h, z6.h, z5.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "add x20, x20, #0x10\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "fmla z25.h, z7.h, z4.h[0]\n" + "fmla z29.h, z7.h, z5.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "fmla z10.h, z6.h, z0.h[0]\n" + 
"fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "fmla z26.h, z6.h, z4.h[0]\n" + "fmla z30.h, z6.h, z5.h[0]\n" + "ld1h { z6.h }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "fmla z27.h, z7.h, z4.h[0]\n" + "fmla z31.h, z7.h, z5.h[0]\n" + "ld1h { z7.h }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "fmla z24.h, z6.h, z4.h[1]\n" + "fmla z28.h, z6.h, z5.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "fmla z25.h, z7.h, z4.h[1]\n" + "fmla z29.h, z7.h, z5.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "fmla z26.h, z6.h, z4.h[1]\n" + "fmla z30.h, z6.h, z5.h[1]\n" + "ld1h { z6.h }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "fmla z23.h, z7.h, z3.h[1]\n" + "fmla z27.h, z7.h, z4.h[1]\n" + "fmla z31.h, z7.h, z5.h[1]\n" + "ld1h { z7.h }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "fmla z24.h, z6.h, z4.h[2]\n" + "fmla z28.h, z6.h, z5.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "fmla z25.h, z7.h, z4.h[2]\n" + "fmla z29.h, z7.h, z5.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "fmla z26.h, z6.h, z4.h[2]\n" + "fmla z30.h, z6.h, z5.h[2]\n" + "ld1h { z6.h }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "fmla z27.h, z7.h, z4.h[2]\n" + "fmla z31.h, z7.h, z5.h[2]\n" + "ld1h { z7.h }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla z20.h, z6.h, z3.h[3]\n" + "fmla z24.h, z6.h, z4.h[3]\n" + "fmla z28.h, z6.h, z5.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9, #3, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "fmla z25.h, z7.h, z4.h[3]\n" + "fmla z29.h, z7.h, z5.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28, #3, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "fmla z26.h, z6.h, z4.h[3]\n" + "fmla z30.h, z6.h, z5.h[3]\n" + "ld1h { z6.h }, p5/Z, [x11, #4, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "fmla z27.h, z7.h, z4.h[3]\n" + "fmla z31.h, z7.h, z5.h[3]\n" + "ld1h { z7.h }, p5/Z, [x10, #4, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "fmla z24.h, z6.h, z4.h[4]\n" + "fmla z28.h, z6.h, z5.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9, #4, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla 
z13.h, z7.h, z1.h[4]\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "fmla z25.h, z7.h, z4.h[4]\n" + "fmla z29.h, z7.h, z5.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28, #4, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "fmla z26.h, z6.h, z4.h[4]\n" + "fmla z30.h, z6.h, z5.h[4]\n" + "ld1h { z6.h }, p5/Z, [x11, #5, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "fmla z27.h, z7.h, z4.h[4]\n" + "fmla z31.h, z7.h, z5.h[4]\n" + "ld1h { z7.h }, p5/Z, [x10, #5, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, z6.h, z3.h[5]\n" + "fmla z24.h, z6.h, z4.h[5]\n" + "fmla z28.h, z6.h, z5.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9, #5, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "fmla z25.h, z7.h, z4.h[5]\n" + "fmla z29.h, z7.h, z5.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28, #5, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "fmla z26.h, z6.h, z4.h[5]\n" + "fmla z30.h, z6.h, z5.h[5]\n" + "ld1h { z6.h }, p5/Z, [x11, #6, MUL VL]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "fmla z27.h, z7.h, z4.h[5]\n" + "fmla z31.h, z7.h, z5.h[5]\n" + "ld1h { z7.h }, p5/Z, [x10, #6, MUL VL]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z20.h, z6.h, z3.h[6]\n" + "fmla z24.h, z6.h, z4.h[6]\n" + "fmla z28.h, z6.h, z5.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9, #6, MUL VL]\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "fmla z25.h, z7.h, z4.h[6]\n" + "fmla z29.h, z7.h, z5.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28, #6, MUL VL]\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "fmla z26.h, z6.h, z4.h[6]\n" + "fmla z30.h, z6.h, z5.h[6]\n" + "ld1h { z6.h }, p5/Z, [x11, #7, MUL VL]\n" + "addvl x11, x11, #8\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "fmla z27.h, z7.h, z4.h[6]\n" + "fmla z31.h, z7.h, z5.h[6]\n" + "ld1h { z7.h }, p5/Z, [x10, #7, MUL VL]\n" + "addvl x10, x10, #8\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "fmla z24.h, z6.h, z4.h[7]\n" + "fmla z28.h, z6.h, z5.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9, #7, MUL VL]\n" + "addvl x9, x9, #8\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "fmla z25.h, z7.h, z4.h[7]\n" + "fmla z29.h, z7.h, z5.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28, #7, MUL VL]\n" + "addvl x28, x28, #8\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z26.h, z6.h, z4.h[7]\n" + "fmla z30.h, z6.h, z5.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "fmla z27.h, z7.h, z4.h[7]\n" + "fmla z31.h, z7.h, z5.h[7]\n" + "bgt 80b\n" + "81:" // Height 6: 
Multiply loop: Single iteration only + "whilelt p0.h, XZR, x26\n" + "ld1rqh { z0.h }, p0/Z, [x25]\n" + "ld1rqh { z1.h }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqh { z2.h }, p0/Z, [x23]\n" + "ld1rqh { z3.h }, p0/Z, [x22]\n" + "ld1rqh { z4.h }, p0/Z, [x21]\n" + "ld1rqh { z5.h }, p0/Z, [x20]\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z1.h[0]\n" + "fmla z16.h, z6.h, z2.h[0]\n" + "fmla z20.h, z6.h, z3.h[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.h, z6.h, z4.h[0]\n" + "fmla z28.h, z6.h, z5.h[0]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z9.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z1.h[0]\n" + "fmla z17.h, z7.h, z2.h[0]\n" + "fmla z21.h, z7.h, z3.h[0]\n" + "fmla z25.h, z7.h, z4.h[0]\n" + "fmla z29.h, z7.h, z5.h[0]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[0]\n" + "fmla z14.h, z6.h, z1.h[0]\n" + "fmla z18.h, z6.h, z2.h[0]\n" + "fmla z22.h, z6.h, z3.h[0]\n" + "fmla z26.h, z6.h, z4.h[0]\n" + "fmla z30.h, z6.h, z5.h[0]\n" + "fmla z11.h, z7.h, z0.h[0]\n" + "fmla z15.h, z7.h, z1.h[0]\n" + "fmla z19.h, z7.h, z2.h[0]\n" + "fmla z23.h, z7.h, z3.h[0]\n" + "fmla z27.h, z7.h, z4.h[0]\n" + "fmla z31.h, z7.h, z5.h[0]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[1]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z16.h, z6.h, z2.h[1]\n" + "fmla z20.h, z6.h, z3.h[1]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[1]\n" + "fmla z28.h, z6.h, z5.h[1]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[1]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[1]\n" + "fmla z21.h, z7.h, z3.h[1]\n" + "fmla z25.h, z7.h, z4.h[1]\n" + "fmla z29.h, z7.h, z5.h[1]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[1]\n" + "fmla z14.h, z6.h, z1.h[1]\n" + "fmla z18.h, z6.h, z2.h[1]\n" + "fmla z22.h, z6.h, z3.h[1]\n" + "fmla z26.h, z6.h, z4.h[1]\n" + "fmla z30.h, z6.h, z5.h[1]\n" + "fmla z11.h, z7.h, z0.h[1]\n" + "fmla z15.h, z7.h, z1.h[1]\n" + "fmla z19.h, z7.h, z2.h[1]\n" + "fmla z23.h, z7.h, z3.h[1]\n" + "fmla z27.h, z7.h, z4.h[1]\n" + "fmla z31.h, z7.h, z5.h[1]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[2]\n" + "fmla z12.h, z6.h, z1.h[2]\n" + "fmla z16.h, z6.h, z2.h[2]\n" + "fmla z20.h, z6.h, z3.h[2]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[2]\n" + "fmla z28.h, z6.h, z5.h[2]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[2]\n" + "fmla z13.h, z7.h, z1.h[2]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[2]\n" + "fmla z21.h, z7.h, z3.h[2]\n" + "fmla z25.h, z7.h, z4.h[2]\n" + "fmla z29.h, z7.h, z5.h[2]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[2]\n" + "fmla z14.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z2.h[2]\n" + "fmla z22.h, z6.h, z3.h[2]\n" + "fmla z26.h, z6.h, z4.h[2]\n" + "fmla z30.h, z6.h, z5.h[2]\n" + "fmla z11.h, z7.h, z0.h[2]\n" + "fmla z15.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z2.h[2]\n" + "fmla z23.h, z7.h, z3.h[2]\n" + "fmla z27.h, z7.h, z4.h[2]\n" + "fmla z31.h, z7.h, z5.h[2]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[3]\n" + "fmla z12.h, z6.h, z1.h[3]\n" + "fmla z16.h, z6.h, z2.h[3]\n" + "fmla 
z20.h, z6.h, z3.h[3]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[3]\n" + "fmla z28.h, z6.h, z5.h[3]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[3]\n" + "fmla z13.h, z7.h, z1.h[3]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[3]\n" + "fmla z21.h, z7.h, z3.h[3]\n" + "fmla z25.h, z7.h, z4.h[3]\n" + "fmla z29.h, z7.h, z5.h[3]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[3]\n" + "fmla z14.h, z6.h, z1.h[3]\n" + "fmla z18.h, z6.h, z2.h[3]\n" + "fmla z22.h, z6.h, z3.h[3]\n" + "fmla z26.h, z6.h, z4.h[3]\n" + "fmla z30.h, z6.h, z5.h[3]\n" + "fmla z11.h, z7.h, z0.h[3]\n" + "fmla z15.h, z7.h, z1.h[3]\n" + "fmla z19.h, z7.h, z2.h[3]\n" + "fmla z23.h, z7.h, z3.h[3]\n" + "fmla z27.h, z7.h, z4.h[3]\n" + "fmla z31.h, z7.h, z5.h[3]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[4]\n" + "fmla z12.h, z6.h, z1.h[4]\n" + "fmla z16.h, z6.h, z2.h[4]\n" + "fmla z20.h, z6.h, z3.h[4]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[4]\n" + "fmla z28.h, z6.h, z5.h[4]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[4]\n" + "fmla z13.h, z7.h, z1.h[4]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[4]\n" + "fmla z21.h, z7.h, z3.h[4]\n" + "fmla z25.h, z7.h, z4.h[4]\n" + "fmla z29.h, z7.h, z5.h[4]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[4]\n" + "fmla z14.h, z6.h, z1.h[4]\n" + "fmla z18.h, z6.h, z2.h[4]\n" + "fmla z22.h, z6.h, z3.h[4]\n" + "fmla z26.h, z6.h, z4.h[4]\n" + "fmla z30.h, z6.h, z5.h[4]\n" + "fmla z11.h, z7.h, z0.h[4]\n" + "fmla z15.h, z7.h, z1.h[4]\n" + "fmla z19.h, z7.h, z2.h[4]\n" + "fmla z23.h, z7.h, z3.h[4]\n" + "fmla z27.h, z7.h, z4.h[4]\n" + "fmla z31.h, z7.h, z5.h[4]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[5]\n" + "fmla z12.h, z6.h, z1.h[5]\n" + "fmla z16.h, z6.h, z2.h[5]\n" + "fmla z20.h, z6.h, z3.h[5]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[5]\n" + "fmla z28.h, z6.h, z5.h[5]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[5]\n" + "fmla z13.h, z7.h, z1.h[5]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[5]\n" + "fmla z21.h, z7.h, z3.h[5]\n" + "fmla z25.h, z7.h, z4.h[5]\n" + "fmla z29.h, z7.h, z5.h[5]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[5]\n" + "fmla z14.h, z6.h, z1.h[5]\n" + "fmla z18.h, z6.h, z2.h[5]\n" + "fmla z22.h, z6.h, z3.h[5]\n" + "fmla z26.h, z6.h, z4.h[5]\n" + "fmla z30.h, z6.h, z5.h[5]\n" + "fmla z11.h, z7.h, z0.h[5]\n" + "fmla z15.h, z7.h, z1.h[5]\n" + "fmla z19.h, z7.h, z2.h[5]\n" + "fmla z23.h, z7.h, z3.h[5]\n" + "fmla z27.h, z7.h, z4.h[5]\n" + "fmla z31.h, z7.h, z5.h[5]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[6]\n" + "fmla z12.h, z6.h, z1.h[6]\n" + "fmla z16.h, z6.h, z2.h[6]\n" + "fmla z20.h, z6.h, z3.h[6]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.h, z6.h, z4.h[6]\n" + "fmla z28.h, z6.h, z5.h[6]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.h, z7.h, z0.h[6]\n" + "fmla z13.h, z7.h, z1.h[6]\n" + "addvl x9, x9, #1\n" + "fmla z17.h, z7.h, z2.h[6]\n" + "fmla z21.h, z7.h, z3.h[6]\n" + "fmla z25.h, z7.h, z4.h[6]\n" + "fmla z29.h, z7.h, z5.h[6]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + 
"addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[6]\n" + "fmla z14.h, z6.h, z1.h[6]\n" + "fmla z18.h, z6.h, z2.h[6]\n" + "fmla z22.h, z6.h, z3.h[6]\n" + "fmla z26.h, z6.h, z4.h[6]\n" + "fmla z30.h, z6.h, z5.h[6]\n" + "fmla z11.h, z7.h, z0.h[6]\n" + "fmla z15.h, z7.h, z1.h[6]\n" + "fmla z19.h, z7.h, z2.h[6]\n" + "fmla z23.h, z7.h, z3.h[6]\n" + "fmla z27.h, z7.h, z4.h[6]\n" + "fmla z31.h, z7.h, z5.h[6]\n" + "ble 82f\n" + "ld1h { z6.h }, p5/Z, [x11]\n" + "ld1h { z7.h }, p5/Z, [x10]\n" + "fmla z8.h, z6.h, z0.h[7]\n" + "fmla z12.h, z6.h, z1.h[7]\n" + "fmla z16.h, z6.h, z2.h[7]\n" + "fmla z20.h, z6.h, z3.h[7]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.h, z6.h, z4.h[7]\n" + "fmla z28.h, z6.h, z5.h[7]\n" + "ld1h { z6.h }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z9.h, z7.h, z0.h[7]\n" + "fmla z13.h, z7.h, z1.h[7]\n" + "fmla z17.h, z7.h, z2.h[7]\n" + "fmla z21.h, z7.h, z3.h[7]\n" + "fmla z25.h, z7.h, z4.h[7]\n" + "fmla z29.h, z7.h, z5.h[7]\n" + "ld1h { z7.h }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.h, z6.h, z0.h[7]\n" + "fmla z14.h, z6.h, z1.h[7]\n" + "fmla z18.h, z6.h, z2.h[7]\n" + "fmla z22.h, z6.h, z3.h[7]\n" + "fmla z26.h, z6.h, z4.h[7]\n" + "fmla z30.h, z6.h, z5.h[7]\n" + "fmla z11.h, z7.h, z0.h[7]\n" + "fmla z15.h, z7.h, z1.h[7]\n" + "fmla z19.h, z7.h, z2.h[7]\n" + "fmla z23.h, z7.h, z3.h[7]\n" + "fmla z27.h, z7.h, z4.h[7]\n" + "fmla z31.h, z7.h, z5.h[7]\n" + "82:" // Height 6: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 77b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #1\n" + "add x23, x24, x19, LSL #1\n" + "add x22, x23, x19, LSL #1\n" + "add x21, x22, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "tbz %x[flags], #1, 83f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rh { z1.h }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rh { z0.h }, p5/Z, [x19]\n" + "fmin z8.h, p5/M, z8.h, z1.h\n" + "fmin z9.h, p5/M, z9.h, z1.h\n" + "fmin z10.h, p5/M, z10.h, z1.h\n" + "fmin z11.h, p5/M, z11.h, z1.h\n" + "fmin z12.h, p5/M, z12.h, z1.h\n" + "fmin z13.h, p5/M, z13.h, z1.h\n" + "fmin z14.h, p5/M, z14.h, z1.h\n" + "fmin z15.h, p5/M, z15.h, z1.h\n" + "fmin z16.h, p5/M, z16.h, z1.h\n" + "fmin z17.h, p5/M, z17.h, z1.h\n" + "fmin z18.h, p5/M, z18.h, z1.h\n" + "fmin z19.h, p5/M, z19.h, z1.h\n" + "fmin z20.h, p5/M, z20.h, z1.h\n" + "fmin z21.h, p5/M, z21.h, z1.h\n" + "fmin z22.h, p5/M, z22.h, z1.h\n" + "fmin z23.h, p5/M, z23.h, z1.h\n" + "fmin z24.h, p5/M, z24.h, z1.h\n" + "fmin z25.h, p5/M, z25.h, z1.h\n" + "fmin z26.h, p5/M, z26.h, z1.h\n" + "fmin z27.h, p5/M, z27.h, z1.h\n" + "fmin z28.h, p5/M, z28.h, z1.h\n" + "fmin z29.h, p5/M, z29.h, z1.h\n" + "fmin z30.h, p5/M, z30.h, z1.h\n" + "fmin z31.h, p5/M, z31.h, z1.h\n" + "fmax z8.h, p5/M, z8.h, z0.h\n" + "fmax z9.h, p5/M, z9.h, z0.h\n" + "fmax z10.h, p5/M, z10.h, z0.h\n" + "fmax z11.h, p5/M, z11.h, z0.h\n" + "fmax z12.h, p5/M, z12.h, z0.h\n" + "fmax z13.h, p5/M, z13.h, z0.h\n" + "fmax z14.h, p5/M, z14.h, z0.h\n" + "fmax z15.h, p5/M, z15.h, z0.h\n" + "fmax z16.h, p5/M, z16.h, z0.h\n" + "fmax z17.h, p5/M, z17.h, z0.h\n" + "fmax z18.h, p5/M, z18.h, z0.h\n" + "fmax z19.h, p5/M, z19.h, z0.h\n" + "fmax z20.h, p5/M, z20.h, z0.h\n" + "fmax z21.h, p5/M, z21.h, z0.h\n" + "fmax z22.h, p5/M, z22.h, z0.h\n" + "fmax z23.h, p5/M, z23.h, z0.h\n" + "fmax z24.h, p5/M, z24.h, z0.h\n" + "fmax z25.h, p5/M, z25.h, z0.h\n" + "fmax z26.h, p5/M, z26.h, z0.h\n" + "fmax z27.h, p5/M, z27.h, z0.h\n" 
+ "fmax z28.h, p5/M, z28.h, z0.h\n" + "fmax z29.h, p5/M, z29.h, z0.h\n" + "fmax z30.h, p5/M, z30.h, z0.h\n" + "fmax z31.h, p5/M, z31.h, z0.h\n" + "83:" // Height 6: No activation + "st1h { z8.h }, p4, [x12]\n" + "st1h { z9.h }, p3, [x12, #1, MUL VL]\n" + "st1h { z10.h }, p2, [x12, #2, MUL VL]\n" + "st1h { z11.h }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1h { z12.h }, p4, [x24]\n" + "st1h { z13.h }, p3, [x24, #1, MUL VL]\n" + "st1h { z14.h }, p2, [x24, #2, MUL VL]\n" + "st1h { z15.h }, p1, [x24, #3, MUL VL]\n" + "st1h { z16.h }, p4, [x23]\n" + "st1h { z17.h }, p3, [x23, #1, MUL VL]\n" + "st1h { z18.h }, p2, [x23, #2, MUL VL]\n" + "st1h { z19.h }, p1, [x23, #3, MUL VL]\n" + "st1h { z20.h }, p4, [x22]\n" + "st1h { z21.h }, p3, [x22, #1, MUL VL]\n" + "st1h { z22.h }, p2, [x22, #2, MUL VL]\n" + "st1h { z23.h }, p1, [x22, #3, MUL VL]\n" + "st1h { z24.h }, p4, [x21]\n" + "st1h { z25.h }, p3, [x21, #1, MUL VL]\n" + "st1h { z26.h }, p2, [x21, #2, MUL VL]\n" + "st1h { z27.h }, p1, [x21, #3, MUL VL]\n" + "st1h { z28.h }, p4, [x20]\n" + "st1h { z29.h }, p3, [x20, #1, MUL VL]\n" + "st1h { z30.h }, p2, [x20, #2, MUL VL]\n" + "st1h { z31.h }, p1, [x20, #3, MUL VL]\n" + "84:" // Height 6: Writeback done + "dech x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 72b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 86f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 85f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "85:" // Update direct input + "mov x19, #0xc\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "86:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp new file mode 100644 index 0000000000..b4c124c1e3 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg<float>, \ + size_t, size_t, \ + const float *, \ + size_t, \ + IndirectOutputArg<float>, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffhybrid_fp32_mla_6x4VL( ARGLIST ); +void sve_ffhybrid_fp32_mla_6x4VL_a64fx( ARGLIST ); + +class cls_sve_ffhybrid_fp32_mla_6x4VL +{ +public: + typedef float lhs_operand_type; + typedef float rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 6; + } + static unsigned int stripe_width() + { + return get_vector_length<float>() * 1; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL1VL_BL32; + } + + static unsigned int out_width() + { + return get_vector_length<float>() * 4; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsSVE<lhs_operand_type, rhs_operand_type, result_type, 6, 4, 1> transforms = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + default: + return { 15.27 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffhybrid_fp32_mla_6x4VL; + cls_sve_ffhybrid_fp32_mla_6x4VL(const CPUInfo *ci) + { + switch(ci->get_cpu_model()) { + default: + break; + case CPUModel::A64FX: + kernel=sve_ffhybrid_fp32_mla_6x4VL_a64fx; + break; + } + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp new file mode 100644 index 0000000000..7dd4e234d5 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/a64fx.cpp @@ -0,0 +1,1530 @@ +/* + * Copyright (c) 2022 Arm Limited.
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void sve_ffhybrid_fp32_mla_6x4VL_a64fx ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg, + size_t M, size_t N, const float *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const float *B_ptr = {}; + const float *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p4.b\n" + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 66f\n" + "cmp %x[M], #0x4\n" + "bgt 53f\n" + "beq 40f\n" + "cmp %x[M], #0x2\n" + "bgt 27f\n" + "beq 14f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11,
x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 4f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 5f\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "ld1w { z9.s }, p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "b 6f\n" + "5:" // Height 1: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "6:" // Height 1: setup done + "mov x27, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "add x25, x25, #0x4\n" + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "bne 7b\n" + "tbz %x[flags], #1, 12f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "12:" // Height 1: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, 
x12, #4\n" + "13:" // Height 1: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 80f\n" + "14:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "15:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 16f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 16f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 16f\n" + "mov x10, x11\n" + "16:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 17f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "addvl x14, x14, #4\n" + "b 19f\n" + "17:" // Height 2: no bias + "tbz %x[flags], #0, 18f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "ld1w { z9.s }, p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x24]\n" + "ld1w { z13.s }, p2/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p1/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p0/Z, [x24, #3, MUL VL]\n" + "b 19f\n" + "18:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "19:" // Height 2: setup done + "mov x27, #0x0\n" + "20:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 21f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 22f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "b 22f\n" + "21:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "22:" // Height 2: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 24f\n" + "23:" // Height 2: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x10, x10, #1\n" + "add x25, x25, #0x4\n" + "subs x26, x26, #0x1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "add x24, x24, #0x4\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "addvl x9, x9, #1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + 
"ld1rw { z1.s }, p4/Z, [x24]\n" + "addvl x28, x28, #1\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 23b\n" + "24:" // Height 2: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "bne 20b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "tbz %x[flags], #1, 25f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmin z12.s, p4/M, z12.s, z1.s\n" + "fmin z13.s, p4/M, z13.s, z1.s\n" + "fmin z14.s, p4/M, z14.s, z1.s\n" + "fmin z15.s, p4/M, z15.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "fmax z12.s, p4/M, z12.s, z0.s\n" + "fmax z13.s, p4/M, z13.s, z0.s\n" + "fmax z14.s, p4/M, z14.s, z0.s\n" + "fmax z15.s, p4/M, z15.s, z0.s\n" + "25:" // Height 2: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p3, [x24]\n" + "st1w { z13.s }, p2, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p1, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p0, [x24, #3, MUL VL]\n" + "26:" // Height 2: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 15b\n" + "b 80f\n" + "27:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "28:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 29f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 29f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 29f\n" + "mov x10, x11\n" + "29:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 30f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "b 32f\n" + "30:" // Height 3: no bias + "tbz %x[flags], #0, 31f\n" + "ldr x19, [%x[args_ptr], 
%[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "ld1w { z9.s }, p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x24]\n" + "ld1w { z13.s }, p2/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p1/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p0/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + "ld1w { z17.s }, p2/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p0/Z, [x23, #3, MUL VL]\n" + "b 32f\n" + "31:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "32:" // Height 3: setup done + "mov x27, #0x0\n" + "33:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 34f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 35f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "b 35f\n" + "34:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "35:" // Height 3: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 37f\n" + "36:" // Height 3: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "add x25, x25, #0x4\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "add x24, x24, #0x4\n" + "add x23, x23, #0x4\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 36b\n" + "37:" // Height 3: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "add x27, x27, #0x1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "cmp x27, x19\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "addvl x28, x28, #1\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + 
"bne 33b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "tbz %x[flags], #1, 38f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmin z12.s, p4/M, z12.s, z1.s\n" + "fmin z13.s, p4/M, z13.s, z1.s\n" + "fmin z14.s, p4/M, z14.s, z1.s\n" + "fmin z15.s, p4/M, z15.s, z1.s\n" + "fmin z16.s, p4/M, z16.s, z1.s\n" + "fmin z17.s, p4/M, z17.s, z1.s\n" + "fmin z18.s, p4/M, z18.s, z1.s\n" + "fmin z19.s, p4/M, z19.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "fmax z12.s, p4/M, z12.s, z0.s\n" + "fmax z13.s, p4/M, z13.s, z0.s\n" + "fmax z14.s, p4/M, z14.s, z0.s\n" + "fmax z15.s, p4/M, z15.s, z0.s\n" + "fmax z16.s, p4/M, z16.s, z0.s\n" + "fmax z17.s, p4/M, z17.s, z0.s\n" + "fmax z18.s, p4/M, z18.s, z0.s\n" + "fmax z19.s, p4/M, z19.s, z0.s\n" + "38:" // Height 3: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p3, [x24]\n" + "st1w { z13.s }, p2, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p1, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p0, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p3, [x23]\n" + "st1w { z17.s }, p2, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p1, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p0, [x23, #3, MUL VL]\n" + "39:" // Height 3: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 28b\n" + "b 80f\n" + "40:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "41:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 42f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 42f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 42f\n" + "mov x10, x11\n" + "42:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 43f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "b 45f\n" + "43:" // Height 4: no bias + "tbz %x[flags], #0, 44f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "ld1w { z9.s }, 
p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x24]\n" + "ld1w { z13.s }, p2/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p1/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p0/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + "ld1w { z17.s }, p2/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p0/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p3/Z, [x22]\n" + "ld1w { z21.s }, p2/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p1/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p0/Z, [x22, #3, MUL VL]\n" + "b 45f\n" + "44:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "45:" // Height 4: setup done + "mov x27, #0x0\n" + "46:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 47f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 48f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "b 48f\n" + "47:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "48:" // Height 4: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 50f\n" + "49:" // Height 4: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "add x25, x25, #0x4\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "subs x26, x26, #0x1\n" + "add x24, x24, #0x4\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "add x23, x23, #0x4\n" + "add x22, x22, #0x4\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x9, x9, #1\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "addvl x28, x28, #1\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 49b\n" + "50:" // Height 4: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "add x27, x27, #0x1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "ld1w { z6.s }, p4/Z, 
[x9]\n" + "cmp x27, x19\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x28, x28, #1\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "bne 46b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "tbz %x[flags], #1, 51f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmin z12.s, p4/M, z12.s, z1.s\n" + "fmin z13.s, p4/M, z13.s, z1.s\n" + "fmin z14.s, p4/M, z14.s, z1.s\n" + "fmin z15.s, p4/M, z15.s, z1.s\n" + "fmin z16.s, p4/M, z16.s, z1.s\n" + "fmin z17.s, p4/M, z17.s, z1.s\n" + "fmin z18.s, p4/M, z18.s, z1.s\n" + "fmin z19.s, p4/M, z19.s, z1.s\n" + "fmin z20.s, p4/M, z20.s, z1.s\n" + "fmin z21.s, p4/M, z21.s, z1.s\n" + "fmin z22.s, p4/M, z22.s, z1.s\n" + "fmin z23.s, p4/M, z23.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "fmax z12.s, p4/M, z12.s, z0.s\n" + "fmax z13.s, p4/M, z13.s, z0.s\n" + "fmax z14.s, p4/M, z14.s, z0.s\n" + "fmax z15.s, p4/M, z15.s, z0.s\n" + "fmax z16.s, p4/M, z16.s, z0.s\n" + "fmax z17.s, p4/M, z17.s, z0.s\n" + "fmax z18.s, p4/M, z18.s, z0.s\n" + "fmax z19.s, p4/M, z19.s, z0.s\n" + "fmax z20.s, p4/M, z20.s, z0.s\n" + "fmax z21.s, p4/M, z21.s, z0.s\n" + "fmax z22.s, p4/M, z22.s, z0.s\n" + "fmax z23.s, p4/M, z23.s, z0.s\n" + "51:" // Height 4: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p3, [x24]\n" + "st1w { z13.s }, p2, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p1, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p0, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p3, [x23]\n" + "st1w { z17.s }, p2, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p1, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p0, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p3, [x22]\n" + "st1w { z21.s }, p2, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p0, [x22, #3, MUL VL]\n" + "52:" // Height 4: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 41b\n" + "b 80f\n" + "53:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "54:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 55f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 
55f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 55f\n" + "mov x10, x11\n" + "55:" // Height 5: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 56f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "b 58f\n" + "56:" // Height 5: no bias + "tbz %x[flags], #0, 57f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x24]\n" + "ld1w { z13.s }, p2/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p1/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p0/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + "ld1w { z17.s }, p2/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p0/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p3/Z, [x22]\n" + "ld1w { z21.s }, p2/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p1/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p0/Z, [x22, #3, MUL VL]\n" + "ld1w { z24.s }, p3/Z, [x21]\n" + "ld1w { z25.s }, p2/Z, [x21, #1, MUL VL]\n" + "ld1w { z26.s }, p1/Z, [x21, #2, MUL VL]\n" + "ld1w { z27.s }, p0/Z, [x21, #3, MUL VL]\n" + "b 58f\n" + "57:" // Height 5: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "58:" // Height 5: setup done + "mov x27, #0x0\n" + "59:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 60f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 61f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "b 61f\n" + "60:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "61:" // Height 5: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { 
z2.s }, p4/Z, [x23]\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "ld1rw { z4.s }, p4/Z, [x21]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 63f\n" + "62:" // Height 5: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "add x25, x25, #0x4\n" + "subs x26, x26, #0x1\n" + "fmla z24.s, p4/M, z6.s, z4.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "add x24, x24, #0x4\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "add x23, x23, #0x4\n" + "add x22, x22, #0x4\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "fmla z25.s, p4/M, z7.s, z4.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "add x21, x21, #0x4\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "fmla z26.s, p4/M, z6.s, z4.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "fmla z27.s, p4/M, z7.s, z4.s\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "ld1rw { z4.s }, p4/Z, [x21]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 62b\n" + "63:" // Height 5: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "add x27, x27, #0x1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "cmp x27, x19\n" + "addvl x11, x11, #1\n" + "fmla z24.s, p4/M, z6.s, z4.s\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "addvl x9, x9, #1\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "fmla z25.s, p4/M, z7.s, z4.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "fmla z26.s, p4/M, z6.s, z4.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "fmla z27.s, p4/M, z7.s, z4.s\n" + "bne 59b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "tbz %x[flags], #1, 64f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmin z12.s, p4/M, z12.s, z1.s\n" + "fmin z13.s, p4/M, z13.s, z1.s\n" + "fmin z14.s, p4/M, z14.s, z1.s\n" + "fmin z15.s, p4/M, z15.s, z1.s\n" + "fmin z16.s, p4/M, z16.s, z1.s\n" + "fmin z17.s, p4/M, z17.s, z1.s\n" + "fmin z18.s, p4/M, z18.s, z1.s\n" + "fmin z19.s, p4/M, z19.s, z1.s\n" + "fmin z20.s, p4/M, z20.s, z1.s\n" + "fmin z21.s, p4/M, z21.s, z1.s\n" + "fmin z22.s, p4/M, z22.s, z1.s\n" + "fmin z23.s, p4/M, z23.s, z1.s\n" + "fmin z24.s, p4/M, z24.s, z1.s\n" + "fmin z25.s, p4/M, z25.s, z1.s\n" + "fmin z26.s, p4/M, z26.s, z1.s\n" + 
"fmin z27.s, p4/M, z27.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "fmax z12.s, p4/M, z12.s, z0.s\n" + "fmax z13.s, p4/M, z13.s, z0.s\n" + "fmax z14.s, p4/M, z14.s, z0.s\n" + "fmax z15.s, p4/M, z15.s, z0.s\n" + "fmax z16.s, p4/M, z16.s, z0.s\n" + "fmax z17.s, p4/M, z17.s, z0.s\n" + "fmax z18.s, p4/M, z18.s, z0.s\n" + "fmax z19.s, p4/M, z19.s, z0.s\n" + "fmax z20.s, p4/M, z20.s, z0.s\n" + "fmax z21.s, p4/M, z21.s, z0.s\n" + "fmax z22.s, p4/M, z22.s, z0.s\n" + "fmax z23.s, p4/M, z23.s, z0.s\n" + "fmax z24.s, p4/M, z24.s, z0.s\n" + "fmax z25.s, p4/M, z25.s, z0.s\n" + "fmax z26.s, p4/M, z26.s, z0.s\n" + "fmax z27.s, p4/M, z27.s, z0.s\n" + "64:" // Height 5: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p3, [x24]\n" + "st1w { z13.s }, p2, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p1, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p0, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p3, [x23]\n" + "st1w { z17.s }, p2, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p1, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p0, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p3, [x22]\n" + "st1w { z21.s }, p2, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p0, [x22, #3, MUL VL]\n" + "st1w { z24.s }, p3, [x21]\n" + "st1w { z25.s }, p2, [x21, #1, MUL VL]\n" + "st1w { z26.s }, p1, [x21, #2, MUL VL]\n" + "st1w { z27.s }, p0, [x21, #3, MUL VL]\n" + "65:" // Height 5: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 54b\n" + "b 80f\n" + "66:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x18\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "67:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 68f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 68f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 68f\n" + "mov x10, x11\n" + "68:" // Height 6: B setup done + "mov x19, #0x0\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "incw x19\n" + "whilelt p0.s, x19, x13\n" + "cbz x14, 69f\n" + "ld1w { z8.s }, p4/Z, [x14]\n" + "ld1w { z9.s }, p4/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p4/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "mov z28.d, z8.d\n" + "mov z29.d, z9.d\n" + "mov z30.d, z10.d\n" + "mov z31.d, z11.d\n" + "b 71f\n" + "69:" // Height 6: no bias + "tbz 
%x[flags], #0, 70f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p3/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p2/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p1/Z, [x12, #2, MUL VL]\n" + "add x20, x21, x19, LSL #2\n" + "ld1w { z11.s }, p0/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x24]\n" + "ld1w { z13.s }, p2/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p1/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p0/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p3/Z, [x23]\n" + "ld1w { z17.s }, p2/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p1/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p0/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p3/Z, [x22]\n" + "ld1w { z21.s }, p2/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p1/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p0/Z, [x22, #3, MUL VL]\n" + "ld1w { z24.s }, p3/Z, [x21]\n" + "ld1w { z25.s }, p2/Z, [x21, #1, MUL VL]\n" + "ld1w { z26.s }, p1/Z, [x21, #2, MUL VL]\n" + "ld1w { z27.s }, p0/Z, [x21, #3, MUL VL]\n" + "ld1w { z28.s }, p3/Z, [x20]\n" + "ld1w { z29.s }, p2/Z, [x20, #1, MUL VL]\n" + "ld1w { z30.s }, p1/Z, [x20, #2, MUL VL]\n" + "ld1w { z31.s }, p0/Z, [x20, #3, MUL VL]\n" + "b 71f\n" + "70:" // Height 6: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "71:" // Height 6: setup done + "mov x27, #0x0\n" + "72:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 73f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, [x20, #0x28]\n" + "cbnz x27, 74f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "add x20, x20, x19, LSL #2\n" + "b 74f\n" + "73:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "74:" // Height 6: input setup done + "subs x26, x26, #0x1\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "ld1rw { z4.s }, p4/Z, [x21]\n" + "ld1rw { z5.s }, p4/Z, [x20]\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "ble 76f\n" + "75:" // Height 6: Multiply loop: Main loop + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "add x25, x25, #0x4\n" + "subs x26, x26, #0x1\n" + "fmla z24.s, p4/M, z6.s, z4.s\n" + "fmla z28.s, p4/M, 
z6.s, z5.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "add x24, x24, #0x4\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "add x23, x23, #0x4\n" + "add x22, x22, #0x4\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "add x21, x21, #0x4\n" + "add x20, x20, #0x4\n" + "fmla z25.s, p4/M, z7.s, z4.s\n" + "fmla z29.s, p4/M, z7.s, z5.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "addvl x28, x28, #1\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "fmla z26.s, p4/M, z6.s, z4.s\n" + "fmla z30.s, p4/M, z6.s, z5.s\n" + "ld1w { z6.s }, p4/Z, [x11]\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "ld1rw { z0.s }, p4/Z, [x25]\n" + "ld1rw { z1.s }, p4/Z, [x24]\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "ld1rw { z2.s }, p4/Z, [x23]\n" + "ld1rw { z3.s }, p4/Z, [x22]\n" + "fmla z27.s, p4/M, z7.s, z4.s\n" + "fmla z31.s, p4/M, z7.s, z5.s\n" + "ld1rw { z4.s }, p4/Z, [x21]\n" + "ld1rw { z5.s }, p4/Z, [x20]\n" + "ld1w { z7.s }, p4/Z, [x10]\n" + "bgt 75b\n" + "76:" // Height 6: Multiply loop: Main loop skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "fmla z8.s, p4/M, z6.s, z0.s\n" + "fmla z12.s, p4/M, z6.s, z1.s\n" + "add x27, x27, #0x1\n" + "fmla z16.s, p4/M, z6.s, z2.s\n" + "fmla z20.s, p4/M, z6.s, z3.s\n" + "cmp x27, x19\n" + "addvl x11, x11, #1\n" + "fmla z24.s, p4/M, z6.s, z4.s\n" + "fmla z28.s, p4/M, z6.s, z5.s\n" + "ld1w { z6.s }, p4/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.s, p4/M, z7.s, z0.s\n" + "fmla z13.s, p4/M, z7.s, z1.s\n" + "addvl x9, x9, #1\n" + "fmla z17.s, p4/M, z7.s, z2.s\n" + "fmla z21.s, p4/M, z7.s, z3.s\n" + "fmla z25.s, p4/M, z7.s, z4.s\n" + "fmla z29.s, p4/M, z7.s, z5.s\n" + "ld1w { z7.s }, p4/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, p4/M, z6.s, z0.s\n" + "fmla z14.s, p4/M, z6.s, z1.s\n" + "fmla z18.s, p4/M, z6.s, z2.s\n" + "fmla z22.s, p4/M, z6.s, z3.s\n" + "fmla z26.s, p4/M, z6.s, z4.s\n" + "fmla z30.s, p4/M, z6.s, z5.s\n" + "fmla z11.s, p4/M, z7.s, z0.s\n" + "fmla z15.s, p4/M, z7.s, z1.s\n" + "fmla z19.s, p4/M, z7.s, z2.s\n" + "fmla z23.s, p4/M, z7.s, z3.s\n" + "fmla z27.s, p4/M, z7.s, z4.s\n" + "fmla z31.s, p4/M, z7.s, z5.s\n" + "bne 72b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "tbz %x[flags], #1, 77f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p4/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p4/Z, [x19]\n" + "fmin z8.s, p4/M, z8.s, z1.s\n" + "fmin z9.s, p4/M, z9.s, z1.s\n" + "fmin z10.s, p4/M, z10.s, z1.s\n" + "fmin z11.s, p4/M, z11.s, z1.s\n" + "fmin z12.s, p4/M, z12.s, z1.s\n" + "fmin z13.s, p4/M, z13.s, z1.s\n" + "fmin z14.s, p4/M, z14.s, z1.s\n" + "fmin z15.s, p4/M, z15.s, z1.s\n" + "fmin z16.s, p4/M, z16.s, z1.s\n" + "fmin z17.s, p4/M, z17.s, z1.s\n" + "fmin z18.s, p4/M, z18.s, z1.s\n" + "fmin z19.s, p4/M, z19.s, z1.s\n" + "fmin z20.s, p4/M, z20.s, z1.s\n" + "fmin z21.s, p4/M, z21.s, z1.s\n" + "fmin z22.s, p4/M, z22.s, z1.s\n" + "fmin z23.s, p4/M, z23.s, z1.s\n" + "fmin z24.s, p4/M, z24.s, z1.s\n" + "fmin z25.s, p4/M, z25.s, z1.s\n" + "fmin z26.s, p4/M, z26.s, z1.s\n" + "fmin z27.s, p4/M, z27.s, z1.s\n" + "fmin z28.s, p4/M, z28.s, z1.s\n" + "fmin z29.s, p4/M, z29.s, z1.s\n" + "fmin z30.s, p4/M, z30.s, z1.s\n" 
+ "fmin z31.s, p4/M, z31.s, z1.s\n" + "fmax z8.s, p4/M, z8.s, z0.s\n" + "fmax z9.s, p4/M, z9.s, z0.s\n" + "fmax z10.s, p4/M, z10.s, z0.s\n" + "fmax z11.s, p4/M, z11.s, z0.s\n" + "fmax z12.s, p4/M, z12.s, z0.s\n" + "fmax z13.s, p4/M, z13.s, z0.s\n" + "fmax z14.s, p4/M, z14.s, z0.s\n" + "fmax z15.s, p4/M, z15.s, z0.s\n" + "fmax z16.s, p4/M, z16.s, z0.s\n" + "fmax z17.s, p4/M, z17.s, z0.s\n" + "fmax z18.s, p4/M, z18.s, z0.s\n" + "fmax z19.s, p4/M, z19.s, z0.s\n" + "fmax z20.s, p4/M, z20.s, z0.s\n" + "fmax z21.s, p4/M, z21.s, z0.s\n" + "fmax z22.s, p4/M, z22.s, z0.s\n" + "fmax z23.s, p4/M, z23.s, z0.s\n" + "fmax z24.s, p4/M, z24.s, z0.s\n" + "fmax z25.s, p4/M, z25.s, z0.s\n" + "fmax z26.s, p4/M, z26.s, z0.s\n" + "fmax z27.s, p4/M, z27.s, z0.s\n" + "fmax z28.s, p4/M, z28.s, z0.s\n" + "fmax z29.s, p4/M, z29.s, z0.s\n" + "fmax z30.s, p4/M, z30.s, z0.s\n" + "fmax z31.s, p4/M, z31.s, z0.s\n" + "77:" // Height 6: No activation + "st1w { z8.s }, p3, [x12]\n" + "st1w { z9.s }, p2, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p1, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p0, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p3, [x24]\n" + "st1w { z13.s }, p2, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p1, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p0, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p3, [x23]\n" + "st1w { z17.s }, p2, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p1, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p0, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p3, [x22]\n" + "st1w { z21.s }, p2, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p1, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p0, [x22, #3, MUL VL]\n" + "st1w { z24.s }, p3, [x21]\n" + "st1w { z25.s }, p2, [x21, #1, MUL VL]\n" + "st1w { z26.s }, p1, [x21, #2, MUL VL]\n" + "st1w { z27.s }, p0, [x21, #3, MUL VL]\n" + "st1w { z28.s }, p3, [x20]\n" + "st1w { z29.s }, p2, [x20, #1, MUL VL]\n" + "st1w { z30.s }, p1, [x20, #2, MUL VL]\n" + "st1w { z31.s }, p0, [x20, #3, MUL VL]\n" + "78:" // Height 6: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 67b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 80f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 79f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "79:" // Update direct input + "mov x19, #0x18\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "80:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", 
"z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp new file mode 100644 index 0000000000..3c7e562c89 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32_mla_6x4VL/generic.cpp @@ -0,0 +1,2310 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" + +#include +#include + +namespace arm_gemm { + +void sve_ffhybrid_fp32_mla_6x4VL ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg A_arg, + size_t M, size_t N, const float *B_ptr, size_t B_stride, IndirectOutputArg output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast(std::numeric_limits::infinity()); + float minval = - static_cast(std::numeric_limits::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const float *B_ptr = {}; + const float *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p5.b\n" + "1:" // Row loop + "cmp %x[M], #0x6\n" + "bge 
71f\n" + "cmp %x[M], #0x4\n" + "bgt 57f\n" + "beq 43f\n" + "cmp %x[M], #0x2\n" + "bgt 29f\n" + "beq 15f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 4f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "addvl x14, x14, #4\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 5f\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "b 6f\n" + "5:" // Height 1: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "6:" // Height 1: setup done + "mov x27, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "cbnz x27, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x25, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "cmp x26, #0x4\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "ld1w { z7.s }, p5/Z, [x10, #3, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "sub x26, x26, #0x4\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "cmp x26, #0x4\n" + "fmla z10.s, z6.s, 
z0.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "add x25, x25, #0x10\n" + "addvl x11, x11, #4\n" + "addvl x10, x10, #4\n" + "addvl x9, x9, #4\n" + "addvl x28, x28, #4\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 12f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "12:" // Height 1: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 7b\n" + "tbz %x[flags], #1, 13f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "13:" // Height 1: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "14:" // Height 1: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 86f\n" + "15:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "16:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 17f\n" + "mov x10, x11\n" + 
"17:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 18f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "addvl x14, x14, #4\n" + "b 20f\n" + "18:" // Height 2: no bias + "tbz %x[flags], #0, 19f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "b 20f\n" + "19:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "20:" // Height 2: setup done + "mov x27, #0x0\n" + "21:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 22f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "cbnz x27, 23f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "b 23f\n" + "22:" // Height 2: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "23:" // Height 2: input setup done + "cmp x26, #0x4\n" + "ble 25f\n" + "24:" // Height 2: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "sub x26, x26, #0x4\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "cmp x26, #0x4\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "add x25, x25, #0x10\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "add x24, x24, #0x10\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "ld1w { z7.s }, 
p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "bgt 24b\n" + "25:" // Height 2: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "addvl x11, x11, #1\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "subs x26, x26, #0x1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "addvl x11, x11, #1\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "ble 26f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "26:" // Height 2: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 21b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "tbz %x[flags], #1, 27f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, 
z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "27:" // Height 2: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p4, [x24]\n" + "st1w { z13.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p1, [x24, #3, MUL VL]\n" + "28:" // Height 2: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 16b\n" + "b 86f\n" + "29:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "30:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 31f\n" + "mov x10, x11\n" + "31:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 32f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "b 34f\n" + "32:" // Height 3: no bias + "tbz %x[flags], #0, 33f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x23]\n" + "ld1w { z17.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x23, #3, MUL VL]\n" + "b 34f\n" + "33:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "34:" // Height 3: setup done + "mov x27, #0x0\n" + "35:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 36f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "cbnz x27, 37f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" 
+ "add x23, x23, x19, LSL #2\n" + "b 37f\n" + "36:" // Height 3: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "37:" // Height 3: input setup done + "cmp x26, #0x4\n" + "ble 39f\n" + "38:" // Height 3: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "sub x26, x26, #0x4\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "cmp x26, #0x4\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "add x25, x25, #0x10\n" + "add x24, x24, #0x10\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "add x23, x23, #0x10\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "ld1w { z7.s }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "bgt 38b\n" + "39:" // Height 3: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla 
z18.s, z6.s, z2.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "addvl x28, x28, #1\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "ble 40f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "addvl x28, x28, #1\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "ble 40f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x11, x11, #1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "addvl x28, x28, #1\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "ble 40f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x10, x10, #1\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "addvl x9, x9, #1\n" + "addvl x28, x28, #1\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "40:" // Height 3: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 35b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "tbz %x[flags], #1, 41f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "41:" // Height 3: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + 
"st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p4, [x24]\n" + "st1w { z13.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x23]\n" + "st1w { z17.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x23, #3, MUL VL]\n" + "42:" // Height 3: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 30b\n" + "b 86f\n" + "43:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "44:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 45f\n" + "mov x10, x11\n" + "45:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 46f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "b 48f\n" + "46:" // Height 4: no bias + "tbz %x[flags], #0, 47f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x23]\n" + "ld1w { z17.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "b 48f\n" + "47:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "48:" // Height 4: setup done + "mov x27, #0x0\n" + "49:" // Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 50f\n" + "ldr x20, [%x[input_ptr], 
x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "cbnz x27, 51f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "b 51f\n" + "50:" // Height 4: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "51:" // Height 4: input setup done + "cmp x26, #0x4\n" + "ble 53f\n" + "52:" // Height 4: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "sub x26, x26, #0x4\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "cmp x26, #0x4\n" + "add x25, x25, #0x10\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z20.s, z6.s, z3.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "add x24, x24, #0x10\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "ld1w { z7.s }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + "fmla z10.s, 
z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, z3.s[3]\n" + "bgt 52b\n" + "53:" // Height 4: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z20.s, z6.s, z3.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "ble 54f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "addvl x28, x28, #1\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "ble 54f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "subs x26, x26, #0x1\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x9, x9, #1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "addvl x28, x28, #1\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "ble 54f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x11, x11, #1\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "addvl x10, x10, #1\n" + "addvl x9, x9, #1\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, 
z3.s[3]\n" + "54:" // Height 4: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 49b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "tbz %x[flags], #1, 55f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z23.s, p5/M, z23.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z23.s, p5/M, z23.s, z0.s\n" + "55:" // Height 4: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p4, [x24]\n" + "st1w { z13.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x23]\n" + "st1w { z17.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p4, [x22]\n" + "st1w { z21.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p1, [x22, #3, MUL VL]\n" + "56:" // Height 4: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 44b\n" + "b 86f\n" + "57:" // Height 5 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "58:" // Height 5: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 59f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 59f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 59f\n" + "mov x10, x11\n" + "59:" // Height 5: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 60f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, 
z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "b 62f\n" + "60:" // Height 5: no bias + "tbz %x[flags], #0, 61f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x23]\n" + "ld1w { z17.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "ld1w { z24.s }, p4/Z, [x21]\n" + "ld1w { z25.s }, p3/Z, [x21, #1, MUL VL]\n" + "ld1w { z26.s }, p2/Z, [x21, #2, MUL VL]\n" + "ld1w { z27.s }, p1/Z, [x21, #3, MUL VL]\n" + "b 62f\n" + "61:" // Height 5: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "62:" // Height 5: setup done + "mov x27, #0x0\n" + "63:" // Height 5: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 64f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "cbnz x27, 65f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "b 65f\n" + "64:" // Height 5: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "65:" // Height 5: input setup done + "cmp x26, #0x4\n" + "ble 67f\n" + "66:" // Height 5: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "sub x26, x26, #0x4\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "cmp x26, #0x4\n" + "add x25, x25, #0x10\n" + "ld1rqw { z4.s }, p0/Z, [x21]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + 
"fmla z20.s, z6.s, z3.s[0]\n" + "add x24, x24, #0x10\n" + "fmla z24.s, z6.s, z4.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "add x23, x23, #0x10\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "fmla z25.s, z7.s, z4.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z26.s, z6.s, z4.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "fmla z27.s, z7.s, z4.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "fmla z24.s, z6.s, z4.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "fmla z25.s, z7.s, z4.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "fmla z26.s, z6.s, z4.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "fmla z27.s, z7.s, z4.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "fmla z24.s, z6.s, z4.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "fmla z25.s, z7.s, z4.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "fmla z26.s, z6.s, z4.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "fmla z27.s, z7.s, z4.s[2]\n" + "ld1w { z7.s }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "fmla z24.s, z6.s, z4.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "fmla z25.s, z7.s, z4.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z26.s, z6.s, z4.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, z3.s[3]\n" + "fmla z27.s, z7.s, z4.s[3]\n" + "bgt 66b\n" + "67:" // Height 5: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "ld1rqw { z4.s }, p0/Z, 
[x21]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z20.s, z6.s, z3.s[0]\n" + "addvl x11, x11, #1\n" + "fmla z24.s, z6.s, z4.s[0]\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "addvl x9, x9, #1\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "fmla z25.s, z7.s, z4.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z26.s, z6.s, z4.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "fmla z27.s, z7.s, z4.s[0]\n" + "ble 68f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.s, z6.s, z4.s[1]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "addvl x9, x9, #1\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "fmla z25.s, z7.s, z4.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "fmla z26.s, z6.s, z4.s[1]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "fmla z27.s, z7.s, z4.s[1]\n" + "ble 68f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.s, z6.s, z4.s[2]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "addvl x9, x9, #1\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "fmla z25.s, z7.s, z4.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "fmla z26.s, z6.s, z4.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "fmla z27.s, z7.s, z4.s[2]\n" + "ble 68f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.s, z6.s, z4.s[3]\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "fmla z25.s, z7.s, z4.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z26.s, z6.s, z4.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, z3.s[3]\n" + "fmla z27.s, 
z7.s, z4.s[3]\n" + "68:" // Height 5: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 63b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "tbz %x[flags], #1, 69f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z23.s, p5/M, z23.s, z1.s\n" + "fmin z24.s, p5/M, z24.s, z1.s\n" + "fmin z25.s, p5/M, z25.s, z1.s\n" + "fmin z26.s, p5/M, z26.s, z1.s\n" + "fmin z27.s, p5/M, z27.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z23.s, p5/M, z23.s, z0.s\n" + "fmax z24.s, p5/M, z24.s, z0.s\n" + "fmax z25.s, p5/M, z25.s, z0.s\n" + "fmax z26.s, p5/M, z26.s, z0.s\n" + "fmax z27.s, p5/M, z27.s, z0.s\n" + "69:" // Height 5: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p4, [x24]\n" + "st1w { z13.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x23]\n" + "st1w { z17.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p4, [x22]\n" + "st1w { z21.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z24.s }, p4, [x21]\n" + "st1w { z25.s }, p3, [x21, #1, MUL VL]\n" + "st1w { z26.s }, p2, [x21, #2, MUL VL]\n" + "st1w { z27.s }, p1, [x21, #3, MUL VL]\n" + "70:" // Height 5: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 58b\n" + "b 86f\n" + "71:" // Height 6 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x18\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "72:" // Height 6: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #2\n" + "cntw x20, ALL, MUL #3\n" + "add x9, x10, x19, LSL #2\n" + "add 
x28, x9, x19, LSL #2\n" + "add x19, x28, x19, LSL #2\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 73f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 73f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 73f\n" + "mov x10, x11\n" + "73:" // Height 6: B setup done + "mov x19, #0x0\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 74f\n" + "ld1w { z8.s }, p5/Z, [x14]\n" + "ld1w { z9.s }, p5/Z, [x14, #1, MUL VL]\n" + "mov z12.d, z8.d\n" + "mov z13.d, z9.d\n" + "ld1w { z10.s }, p5/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p5/Z, [x14, #3, MUL VL]\n" + "mov z14.d, z10.d\n" + "mov z15.d, z11.d\n" + "mov z16.d, z8.d\n" + "mov z17.d, z9.d\n" + "addvl x14, x14, #4\n" + "mov z18.d, z10.d\n" + "mov z19.d, z11.d\n" + "mov z20.d, z8.d\n" + "mov z21.d, z9.d\n" + "mov z22.d, z10.d\n" + "mov z23.d, z11.d\n" + "mov z24.d, z8.d\n" + "mov z25.d, z9.d\n" + "mov z26.d, z10.d\n" + "mov z27.d, z11.d\n" + "mov z28.d, z8.d\n" + "mov z29.d, z9.d\n" + "mov z30.d, z10.d\n" + "mov z31.d, z11.d\n" + "b 76f\n" + "74:" // Height 6: no bias + "tbz %x[flags], #0, 75f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "ld1w { z8.s }, p4/Z, [x12]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p3/Z, [x12, #1, MUL VL]\n" + "ld1w { z10.s }, p2/Z, [x12, #2, MUL VL]\n" + "add x20, x21, x19, LSL #2\n" + "ld1w { z11.s }, p1/Z, [x12, #3, MUL VL]\n" + "ld1w { z12.s }, p4/Z, [x24]\n" + "ld1w { z13.s }, p3/Z, [x24, #1, MUL VL]\n" + "ld1w { z14.s }, p2/Z, [x24, #2, MUL VL]\n" + "ld1w { z15.s }, p1/Z, [x24, #3, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x23]\n" + "ld1w { z17.s }, p3/Z, [x23, #1, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x23, #2, MUL VL]\n" + "ld1w { z19.s }, p1/Z, [x23, #3, MUL VL]\n" + "ld1w { z20.s }, p4/Z, [x22]\n" + "ld1w { z21.s }, p3/Z, [x22, #1, MUL VL]\n" + "ld1w { z22.s }, p2/Z, [x22, #2, MUL VL]\n" + "ld1w { z23.s }, p1/Z, [x22, #3, MUL VL]\n" + "ld1w { z24.s }, p4/Z, [x21]\n" + "ld1w { z25.s }, p3/Z, [x21, #1, MUL VL]\n" + "ld1w { z26.s }, p2/Z, [x21, #2, MUL VL]\n" + "ld1w { z27.s }, p1/Z, [x21, #3, MUL VL]\n" + "ld1w { z28.s }, p4/Z, [x20]\n" + "ld1w { z29.s }, p3/Z, [x20, #1, MUL VL]\n" + "ld1w { z30.s }, p2/Z, [x20, #2, MUL VL]\n" + "ld1w { z31.s }, p1/Z, [x20, #3, MUL VL]\n" + "b 76f\n" + "75:" // Height 6: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "76:" // Height 6: setup done + "mov x27, #0x0\n" + "77:" // Height 6: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w26, [x19, x27, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 78f\n" + "ldr x20, [%x[input_ptr], x27, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x25, [x20, #0x0]\n" + "ldr x24, [x20, #0x8]\n" + "ldr x23, [x20, #0x10]\n" + "ldr x22, [x20, #0x18]\n" + "ldr x21, [x20, #0x20]\n" + "ldr x20, 
[x20, #0x28]\n" + "cbnz x27, 79f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x25, x25, x19, LSL #2\n" + "add x24, x24, x19, LSL #2\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "add x20, x20, x19, LSL #2\n" + "b 79f\n" + "78:" // Height 6: setup direct input + "mov x25, %x[input_ptr]\n" + "add x24, x25, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "79:" // Height 6: input setup done + "cmp x26, #0x4\n" + "ble 81f\n" + "80:" // Height 6: Multiply loop: Main loop head + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "sub x26, x26, #0x4\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "cmp x26, #0x4\n" + "add x25, x25, #0x10\n" + "ld1rqw { z4.s }, p0/Z, [x21]\n" + "ld1rqw { z5.s }, p0/Z, [x20]\n" + "add x24, x24, #0x10\n" + "add x23, x23, #0x10\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z20.s, z6.s, z3.s[0]\n" + "add x22, x22, #0x10\n" + "add x21, x21, #0x10\n" + "fmla z24.s, z6.s, z4.s[0]\n" + "fmla z28.s, z6.s, z5.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "add x20, x20, #0x10\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "fmla z25.s, z7.s, z4.s[0]\n" + "fmla z29.s, z7.s, z5.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z26.s, z6.s, z4.s[0]\n" + "fmla z30.s, z6.s, z5.s[0]\n" + "ld1w { z6.s }, p5/Z, [x11, #1, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "fmla z27.s, z7.s, z4.s[0]\n" + "fmla z31.s, z7.s, z5.s[0]\n" + "ld1w { z7.s }, p5/Z, [x10, #1, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "fmla z24.s, z6.s, z4.s[1]\n" + "fmla z28.s, z6.s, z5.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9, #1, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "fmla z25.s, z7.s, z4.s[1]\n" + "fmla z29.s, z7.s, z5.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "fmla z26.s, z6.s, z4.s[1]\n" + "fmla z30.s, z6.s, z5.s[1]\n" + "ld1w { z6.s }, p5/Z, [x11, #2, MUL VL]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "fmla z27.s, z7.s, z4.s[1]\n" + "fmla z31.s, z7.s, z5.s[1]\n" + "ld1w { z7.s }, p5/Z, [x10, #2, MUL VL]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "fmla z24.s, z6.s, z4.s[2]\n" + "fmla z28.s, z6.s, z5.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9, #2, MUL VL]\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "fmla z25.s, z7.s, z4.s[2]\n" + "fmla z29.s, z7.s, z5.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28, #2, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, 
z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "fmla z26.s, z6.s, z4.s[2]\n" + "fmla z30.s, z6.s, z5.s[2]\n" + "ld1w { z6.s }, p5/Z, [x11, #3, MUL VL]\n" + "addvl x11, x11, #4\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "fmla z27.s, z7.s, z4.s[2]\n" + "fmla z31.s, z7.s, z5.s[2]\n" + "ld1w { z7.s }, p5/Z, [x10, #3, MUL VL]\n" + "addvl x10, x10, #4\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "fmla z24.s, z6.s, z4.s[3]\n" + "fmla z28.s, z6.s, z5.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9, #3, MUL VL]\n" + "addvl x9, x9, #4\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "fmla z25.s, z7.s, z4.s[3]\n" + "fmla z29.s, z7.s, z5.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28, #3, MUL VL]\n" + "addvl x28, x28, #4\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z26.s, z6.s, z4.s[3]\n" + "fmla z30.s, z6.s, z5.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, z3.s[3]\n" + "fmla z27.s, z7.s, z4.s[3]\n" + "fmla z31.s, z7.s, z5.s[3]\n" + "bgt 80b\n" + "81:" // Height 6: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x26\n" + "ld1rqw { z0.s }, p0/Z, [x25]\n" + "ld1rqw { z1.s }, p0/Z, [x24]\n" + "subs x26, x26, #0x1\n" + "ld1rqw { z2.s }, p0/Z, [x23]\n" + "ld1rqw { z3.s }, p0/Z, [x22]\n" + "ld1rqw { z4.s }, p0/Z, [x21]\n" + "ld1rqw { z5.s }, p0/Z, [x20]\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[0]\n" + "fmla z12.s, z6.s, z1.s[0]\n" + "fmla z16.s, z6.s, z2.s[0]\n" + "fmla z20.s, z6.s, z3.s[0]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.s, z6.s, z4.s[0]\n" + "fmla z28.s, z6.s, z5.s[0]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z9.s, z7.s, z0.s[0]\n" + "fmla z13.s, z7.s, z1.s[0]\n" + "fmla z17.s, z7.s, z2.s[0]\n" + "fmla z21.s, z7.s, z3.s[0]\n" + "fmla z25.s, z7.s, z4.s[0]\n" + "fmla z29.s, z7.s, z5.s[0]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z14.s, z6.s, z1.s[0]\n" + "fmla z18.s, z6.s, z2.s[0]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z26.s, z6.s, z4.s[0]\n" + "fmla z30.s, z6.s, z5.s[0]\n" + "fmla z11.s, z7.s, z0.s[0]\n" + "fmla z15.s, z7.s, z1.s[0]\n" + "fmla z19.s, z7.s, z2.s[0]\n" + "fmla z23.s, z7.s, z3.s[0]\n" + "fmla z27.s, z7.s, z4.s[0]\n" + "fmla z31.s, z7.s, z5.s[0]\n" + "ble 82f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[1]\n" + "fmla z12.s, z6.s, z1.s[1]\n" + "fmla z16.s, z6.s, z2.s[1]\n" + "fmla z20.s, z6.s, z3.s[1]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.s, z6.s, z4.s[1]\n" + "fmla z28.s, z6.s, z5.s[1]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.s, z7.s, z0.s[1]\n" + "fmla z13.s, z7.s, z1.s[1]\n" + "addvl x9, x9, #1\n" + "fmla z17.s, z7.s, z2.s[1]\n" + "fmla z21.s, z7.s, z3.s[1]\n" + "fmla z25.s, z7.s, z4.s[1]\n" + "fmla z29.s, z7.s, z5.s[1]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[1]\n" + "fmla z14.s, z6.s, z1.s[1]\n" + "fmla z18.s, z6.s, z2.s[1]\n" + "fmla z22.s, z6.s, z3.s[1]\n" + "fmla z26.s, z6.s, z4.s[1]\n" + "fmla z30.s, z6.s, z5.s[1]\n" + 
"fmla z11.s, z7.s, z0.s[1]\n" + "fmla z15.s, z7.s, z1.s[1]\n" + "fmla z19.s, z7.s, z2.s[1]\n" + "fmla z23.s, z7.s, z3.s[1]\n" + "fmla z27.s, z7.s, z4.s[1]\n" + "fmla z31.s, z7.s, z5.s[1]\n" + "ble 82f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[2]\n" + "fmla z12.s, z6.s, z1.s[2]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z20.s, z6.s, z3.s[2]\n" + "subs x26, x26, #0x1\n" + "addvl x11, x11, #1\n" + "fmla z24.s, z6.s, z4.s[2]\n" + "fmla z28.s, z6.s, z5.s[2]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x10, x10, #1\n" + "fmla z9.s, z7.s, z0.s[2]\n" + "fmla z13.s, z7.s, z1.s[2]\n" + "addvl x9, x9, #1\n" + "fmla z17.s, z7.s, z2.s[2]\n" + "fmla z21.s, z7.s, z3.s[2]\n" + "fmla z25.s, z7.s, z4.s[2]\n" + "fmla z29.s, z7.s, z5.s[2]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[2]\n" + "fmla z14.s, z6.s, z1.s[2]\n" + "fmla z18.s, z6.s, z2.s[2]\n" + "fmla z22.s, z6.s, z3.s[2]\n" + "fmla z26.s, z6.s, z4.s[2]\n" + "fmla z30.s, z6.s, z5.s[2]\n" + "fmla z11.s, z7.s, z0.s[2]\n" + "fmla z15.s, z7.s, z1.s[2]\n" + "fmla z19.s, z7.s, z2.s[2]\n" + "fmla z23.s, z7.s, z3.s[2]\n" + "fmla z27.s, z7.s, z4.s[2]\n" + "fmla z31.s, z7.s, z5.s[2]\n" + "ble 82f\n" + "ld1w { z6.s }, p5/Z, [x11]\n" + "ld1w { z7.s }, p5/Z, [x10]\n" + "fmla z8.s, z6.s, z0.s[3]\n" + "fmla z12.s, z6.s, z1.s[3]\n" + "fmla z16.s, z6.s, z2.s[3]\n" + "fmla z20.s, z6.s, z3.s[3]\n" + "addvl x11, x11, #1\n" + "addvl x10, x10, #1\n" + "fmla z24.s, z6.s, z4.s[3]\n" + "fmla z28.s, z6.s, z5.s[3]\n" + "ld1w { z6.s }, p5/Z, [x9]\n" + "addvl x9, x9, #1\n" + "fmla z9.s, z7.s, z0.s[3]\n" + "fmla z13.s, z7.s, z1.s[3]\n" + "fmla z17.s, z7.s, z2.s[3]\n" + "fmla z21.s, z7.s, z3.s[3]\n" + "fmla z25.s, z7.s, z4.s[3]\n" + "fmla z29.s, z7.s, z5.s[3]\n" + "ld1w { z7.s }, p5/Z, [x28]\n" + "addvl x28, x28, #1\n" + "fmla z10.s, z6.s, z0.s[3]\n" + "fmla z14.s, z6.s, z1.s[3]\n" + "fmla z18.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[3]\n" + "fmla z26.s, z6.s, z4.s[3]\n" + "fmla z30.s, z6.s, z5.s[3]\n" + "fmla z11.s, z7.s, z0.s[3]\n" + "fmla z15.s, z7.s, z1.s[3]\n" + "fmla z19.s, z7.s, z2.s[3]\n" + "fmla z23.s, z7.s, z3.s[3]\n" + "fmla z27.s, z7.s, z4.s[3]\n" + "fmla z31.s, z7.s, z5.s[3]\n" + "82:" // Height 6: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x27, x27, #0x1\n" + "cmp x27, x19\n" + "bne 77b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x24, x12, x19, LSL #2\n" + "add x23, x24, x19, LSL #2\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "tbz %x[flags], #1, 83f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p5/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p5/Z, [x19]\n" + "fmin z8.s, p5/M, z8.s, z1.s\n" + "fmin z9.s, p5/M, z9.s, z1.s\n" + "fmin z10.s, p5/M, z10.s, z1.s\n" + "fmin z11.s, p5/M, z11.s, z1.s\n" + "fmin z12.s, p5/M, z12.s, z1.s\n" + "fmin z13.s, p5/M, z13.s, z1.s\n" + "fmin z14.s, p5/M, z14.s, z1.s\n" + "fmin z15.s, p5/M, z15.s, z1.s\n" + "fmin z16.s, p5/M, z16.s, z1.s\n" + "fmin z17.s, p5/M, z17.s, z1.s\n" + "fmin z18.s, p5/M, z18.s, z1.s\n" + "fmin z19.s, p5/M, z19.s, z1.s\n" + "fmin z20.s, p5/M, z20.s, z1.s\n" + "fmin z21.s, p5/M, z21.s, z1.s\n" + "fmin z22.s, p5/M, z22.s, z1.s\n" + "fmin z23.s, p5/M, z23.s, z1.s\n" + "fmin z24.s, p5/M, z24.s, z1.s\n" + "fmin z25.s, p5/M, z25.s, z1.s\n" + "fmin z26.s, p5/M, z26.s, z1.s\n" + "fmin z27.s, p5/M, z27.s, z1.s\n" + "fmin z28.s, p5/M, z28.s, z1.s\n" + 
"fmin z29.s, p5/M, z29.s, z1.s\n" + "fmin z30.s, p5/M, z30.s, z1.s\n" + "fmin z31.s, p5/M, z31.s, z1.s\n" + "fmax z8.s, p5/M, z8.s, z0.s\n" + "fmax z9.s, p5/M, z9.s, z0.s\n" + "fmax z10.s, p5/M, z10.s, z0.s\n" + "fmax z11.s, p5/M, z11.s, z0.s\n" + "fmax z12.s, p5/M, z12.s, z0.s\n" + "fmax z13.s, p5/M, z13.s, z0.s\n" + "fmax z14.s, p5/M, z14.s, z0.s\n" + "fmax z15.s, p5/M, z15.s, z0.s\n" + "fmax z16.s, p5/M, z16.s, z0.s\n" + "fmax z17.s, p5/M, z17.s, z0.s\n" + "fmax z18.s, p5/M, z18.s, z0.s\n" + "fmax z19.s, p5/M, z19.s, z0.s\n" + "fmax z20.s, p5/M, z20.s, z0.s\n" + "fmax z21.s, p5/M, z21.s, z0.s\n" + "fmax z22.s, p5/M, z22.s, z0.s\n" + "fmax z23.s, p5/M, z23.s, z0.s\n" + "fmax z24.s, p5/M, z24.s, z0.s\n" + "fmax z25.s, p5/M, z25.s, z0.s\n" + "fmax z26.s, p5/M, z26.s, z0.s\n" + "fmax z27.s, p5/M, z27.s, z0.s\n" + "fmax z28.s, p5/M, z28.s, z0.s\n" + "fmax z29.s, p5/M, z29.s, z0.s\n" + "fmax z30.s, p5/M, z30.s, z0.s\n" + "fmax z31.s, p5/M, z31.s, z0.s\n" + "83:" // Height 6: No activation + "st1w { z8.s }, p4, [x12]\n" + "st1w { z9.s }, p3, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p2, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p1, [x12, #3, MUL VL]\n" + "addvl x12, x12, #4\n" + "st1w { z12.s }, p4, [x24]\n" + "st1w { z13.s }, p3, [x24, #1, MUL VL]\n" + "st1w { z14.s }, p2, [x24, #2, MUL VL]\n" + "st1w { z15.s }, p1, [x24, #3, MUL VL]\n" + "st1w { z16.s }, p4, [x23]\n" + "st1w { z17.s }, p3, [x23, #1, MUL VL]\n" + "st1w { z18.s }, p2, [x23, #2, MUL VL]\n" + "st1w { z19.s }, p1, [x23, #3, MUL VL]\n" + "st1w { z20.s }, p4, [x22]\n" + "st1w { z21.s }, p3, [x22, #1, MUL VL]\n" + "st1w { z22.s }, p2, [x22, #2, MUL VL]\n" + "st1w { z23.s }, p1, [x22, #3, MUL VL]\n" + "st1w { z24.s }, p4, [x21]\n" + "st1w { z25.s }, p3, [x21, #1, MUL VL]\n" + "st1w { z26.s }, p2, [x21, #2, MUL VL]\n" + "st1w { z27.s }, p1, [x21, #3, MUL VL]\n" + "st1w { z28.s }, p4, [x20]\n" + "st1w { z29.s }, p3, [x20, #1, MUL VL]\n" + "st1w { z30.s }, p2, [x20, #2, MUL VL]\n" + "st1w { z31.s }, p1, [x20, #3, MUL VL]\n" + "84:" // Height 6: Writeback done + "decw x13, ALL, MUL #4\n" + "cmp x13, XZR\n" + "bgt 72b\n" + "subs %x[M], %x[M], #0x6\n" + "beq 86f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 85f\n" + "add x20, x20, #0x6\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "85:" // Update direct input + "mov x19, #0x18\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "86:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", 
"z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp new file mode 100644 index 0000000000..3ee3e31206 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + unsigned int, const unsigned int *, \ + IndirectInputArg, \ + size_t, size_t, \ + const bfloat16 *, \ + size_t, \ + IndirectOutputArg, \ + const float *, Activation, bool + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffhybrid_fp32bf16fp32_mmla_4x6VL( ARGLIST ); + +class cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL +{ +public: + typedef float lhs_operand_type; + typedef bfloat16 rhs_operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 4; + } + static unsigned int stripe_width() + { + return get_vector_length() * 1; + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL2VL_BL64_BF16; + } + + static unsigned int out_width() + { + return get_vector_length() * 6; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + static constexpr bool supports_accumulate() + { + return true; + } + + StdTransformsSVE transforms = {}; + template + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + if (std::is_same::value) { + switch (ci->get_cpu_model()) { + default: + return { 32.35 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffhybrid_fp32bf16fp32_mmla_4x6VL; + cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp 
b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp new file mode 100644 index 0000000000..8e3676a007 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL/generic.cpp @@ -0,0 +1,1464 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "arm_gemm.hpp" +#include "../../utils.hpp" +#include "../../bfloat.hpp" + +#include <cassert> +#include <limits> + +namespace arm_gemm { + +void sve_ffhybrid_fp32bf16fp32_mmla_4x6VL ( + unsigned int num_strings, const unsigned int *string_lengths, IndirectInputArg<float> A_arg, + size_t M, size_t N, const bfloat16 *B_ptr, size_t B_stride, IndirectOutputArg<float> output_arg, + const float *bias, Activation act, bool accumulate +) +{ + struct KernelArgs { + float maxval = static_cast<float>(std::numeric_limits<float>::infinity()); + float minval = - static_cast<float>(std::numeric_limits<float>::infinity()); + unsigned int num_strings = {}; + const unsigned int *string_lengths = {}; + size_t N = {}; + const bfloat16 *B_ptr = {}; + const bfloat16 *cur_B_ptr = {}; + size_t B_stride = {}; + size_t output_offset = {}; + size_t input_initial_col = {}; + size_t input_offset = {}; + } ka; + + unsigned long flags=0; // bit 0: accumulate, bit 1: min/max clamp, bit 2: indirect output, bit 3: indirect input + void *output_ptr; + void *input_ptr; + + if (output_arg.is_indirect) { + output_ptr=(void *)(output_arg.indirect.ptr); + ka.output_offset=output_arg.indirect.offset; + flags |= 0x4; + } else { + output_ptr=(void *)(output_arg.direct.base); + ka.output_offset=output_arg.direct.stride; + } + + if (A_arg.is_indirect) { + input_ptr=(void *)(A_arg.indirect.ptr); + ka.input_offset=A_arg.indirect.start_row; + ka.input_initial_col=A_arg.indirect.start_col; + flags |= 0x8; + } else { + assert(num_strings==1); + input_ptr=(void *)(A_arg.direct.base); + ka.input_offset=A_arg.direct.stride; + } + if (accumulate) { + flags |= 0x1; + } + ka.num_strings = num_strings; + ka.string_lengths = string_lengths; + ka.N = N; + ka.B_ptr = B_ptr; + ka.B_stride = B_stride; + switch(act.type) { + default: + case Activation::Type::None: + break; + case Activation::Type::BoundedReLU: + ka.maxval = static_cast<float>(act.param1); + /* fall through */ + case Activation::Type::ReLU: + ka.minval = 0; + flags |= 0x2; + break; + } + __asm__ __volatile__( + "ptrue p7.b\n" + "1:" // Row loop + "cmp %x[M], #0x4\n" + "bge 43f\n" + "cmp %x[M], #0x2\n" + "bgt 29f\n" + "beq 15f\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov
x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "2:" // Height 1: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "cntw x20, ALL, MUL #5\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x26, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x27, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 3f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 3f\n" + "mov x10, x11\n" + "3:" // Height 1: B setup done + "mov x19, #0x0\n" + "whilelt p6.s, x19, x13\n" + "incw x19\n" + "whilelt p5.s, x19, x13\n" + "incw x19\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 4f\n" + "ld1w { z8.s }, p7/Z, [x14]\n" + "ld1w { z9.s }, p7/Z, [x14, #1, MUL VL]\n" + "zip2 z14.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p7/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p7/Z, [x14, #3, MUL VL]\n" + "zip2 z15.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "ld1w { z12.s }, p7/Z, [x14, #4, MUL VL]\n" + "ld1w { z13.s }, p7/Z, [x14, #5, MUL VL]\n" + "zip2 z16.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "zip2 z17.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "addvl x14, x14, #6\n" + "zip2 z18.d, z12.d, z12.d\n" + "zip1 z12.d, z12.d, z12.d\n" + "zip2 z19.d, z13.d, z13.d\n" + "zip1 z13.d, z13.d, z13.d\n" + "b 6f\n" + "4:" // Height 1: no bias + "tbz %x[flags], #0, 5f\n" + "ld1w { z9.s }, p6/Z, [x12]\n" + "ld1w { z10.s }, p5/Z, [x12, #1, MUL VL]\n" + "zip1 z8.d, z9.d, z14.d\n" + "zip2 z14.d, z9.d, z14.d\n" + "ld1w { z11.s }, p4/Z, [x12, #2, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x12, #3, MUL VL]\n" + "zip1 z9.d, z10.d, z15.d\n" + "zip2 z15.d, z10.d, z15.d\n" + "ld1w { z13.s }, p2/Z, [x12, #4, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x12, #5, MUL VL]\n" + "zip1 z10.d, z11.d, z16.d\n" + "zip2 z16.d, z11.d, z16.d\n" + "zip1 z11.d, z12.d, z17.d\n" + "zip2 z17.d, z12.d, z17.d\n" + "zip1 z12.d, z13.d, z18.d\n" + "zip2 z18.d, z13.d, z18.d\n" + "zip1 z13.d, z20.d, z19.d\n" + "zip2 z19.d, z20.d, z19.d\n" + "b 6f\n" + "5:" // Height 1: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "6:" // Height 1: setup done + "mov x25, #0x0\n" + "7:" // Height 1: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 8f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "cbnz x25, 9f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "b 9f\n" + "8:" // Height 1: setup direct input + "mov x23, %x[input_ptr]\n" + "9:" // Height 1: input setup done + "cmp x24, #0x4\n" + "ble 11f\n" + "10:" // Height 1: Multiply loop: Main loop head + "whilelt p0.s, XZR, x24\n" + 
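+ // Main-loop body: each iteration loads up to four fp32 A values (ld1rqw), narrows them to bf16 (bfcvt + uzp1), and multiplies them against two-vector-wide bf16 B panels from the six column-block pointers (x11, x10, x9, x28, x27, x26) with bfmmla; z8-z13 and z14-z19 hold the paired accumulator halves that are recombined with uzp1 after the loop.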
"ld1rqw { z0.s }, p0/Z, [x23]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + "add x23, x23, #0x10\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + "addvl x26, x26, #2\n" + "bgt 10b\n" + "11:" // Height 1: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + "addvl x26, x26, #2\n" + "12:" // Height 1: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 7b\n" + "uzp1 z8.d, z8.d, z14.d\n" + "uzp1 z9.d, z9.d, z15.d\n" + "uzp1 z10.d, z10.d, z16.d\n" + "uzp1 z11.d, z11.d, z17.d\n" + "uzp1 z12.d, z12.d, z18.d\n" + "uzp1 z13.d, z13.d, z19.d\n" + "tbz %x[flags], #1, 13f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p7/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p7/Z, [x19]\n" + "fmin z8.s, p7/M, z8.s, z1.s\n" + "fmin z9.s, p7/M, z9.s, z1.s\n" + "fmin z10.s, p7/M, z10.s, z1.s\n" + "fmin z11.s, p7/M, z11.s, z1.s\n" + "fmin z12.s, p7/M, z12.s, z1.s\n" + "fmin 
z13.s, p7/M, z13.s, z1.s\n" + "fmax z8.s, p7/M, z8.s, z0.s\n" + "fmax z9.s, p7/M, z9.s, z0.s\n" + "fmax z10.s, p7/M, z10.s, z0.s\n" + "fmax z11.s, p7/M, z11.s, z0.s\n" + "fmax z12.s, p7/M, z12.s, z0.s\n" + "fmax z13.s, p7/M, z13.s, z0.s\n" + "13:" // Height 1: No activation + "st1w { z8.s }, p6, [x12]\n" + "st1w { z9.s }, p5, [x12, #1, MUL VL]\n" + "st1w { z10.s }, p4, [x12, #2, MUL VL]\n" + "st1w { z11.s }, p3, [x12, #3, MUL VL]\n" + "st1w { z12.s }, p2, [x12, #4, MUL VL]\n" + "st1w { z13.s }, p1, [x12, #5, MUL VL]\n" + "addvl x12, x12, #6\n" + "14:" // Height 1: Writeback done + "decw x13, ALL, MUL #6\n" + "cmp x13, XZR\n" + "bgt 2b\n" + "b 58f\n" + "15:" // Height 2 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "16:" // Height 2: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "cntw x20, ALL, MUL #5\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x26, x11\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x27, x11\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 17f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 17f\n" + "mov x10, x11\n" + "17:" // Height 2: B setup done + "mov x19, #0x0\n" + "whilelt p6.s, x19, x13\n" + "incw x19\n" + "whilelt p5.s, x19, x13\n" + "incw x19\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 18f\n" + "ld1w { z8.s }, p7/Z, [x14]\n" + "ld1w { z9.s }, p7/Z, [x14, #1, MUL VL]\n" + "zip2 z14.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p7/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p7/Z, [x14, #3, MUL VL]\n" + "zip2 z15.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "ld1w { z12.s }, p7/Z, [x14, #4, MUL VL]\n" + "ld1w { z13.s }, p7/Z, [x14, #5, MUL VL]\n" + "zip2 z16.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "zip2 z17.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "addvl x14, x14, #6\n" + "zip2 z18.d, z12.d, z12.d\n" + "zip1 z12.d, z12.d, z12.d\n" + "zip2 z19.d, z13.d, z13.d\n" + "zip1 z13.d, z13.d, z13.d\n" + "b 20f\n" + "18:" // Height 2: no bias + "tbz %x[flags], #0, 19f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "ld1w { z9.s }, p6/Z, [x12]\n" + "ld1w { z10.s }, p5/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x12, #2, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x12, #3, MUL VL]\n" + "ld1w { z13.s }, p2/Z, [x12, #4, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x12, #5, MUL VL]\n" + "ld1w { z14.s }, p6/Z, [x22]\n" + "zip1 z8.d, z9.d, z14.d\n" + "zip2 z14.d, z9.d, z14.d\n" + "ld1w { z15.s }, p5/Z, [x22, #1, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x22, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z15.d\n" + "zip2 z15.d, z10.d, z15.d\n" + "ld1w { z17.s }, p3/Z, [x22, #3, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x22, #4, MUL VL]\n" + "zip1 z10.d, z11.d, z16.d\n" + "zip2 z16.d, z11.d, z16.d\n" + "ld1w { z19.s }, p1/Z, [x22, #5, MUL VL]\n" + "zip1 z11.d, z12.d, z17.d\n" + "zip2 z17.d, z12.d, z17.d\n" + "zip1 z12.d, z13.d, z18.d\n" + "zip2 z18.d, 
z13.d, z18.d\n" + "zip1 z13.d, z20.d, z19.d\n" + "zip2 z19.d, z20.d, z19.d\n" + "b 20f\n" + "19:" // Height 2: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "20:" // Height 2: setup done + "mov x25, #0x0\n" + "21:" // Height 2: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 22f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "cbnz x25, 23f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "b 23f\n" + "22:" // Height 2: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "23:" // Height 2: input setup done + "cmp x24, #0x4\n" + "ble 25f\n" + "24:" // Height 2: Multiply loop: Main loop head + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "uzp1 z1.h, z1.h, z1.h\n" + "trn1 z0.d, z0.d, z1.d\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + "sub x24, x24, #0x4\n" + "cmp x24, #0x4\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + "add x23, x23, #0x10\n" + "add x22, x22, #0x10\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + "addvl x26, x26, #2\n" + "bgt 24b\n" + "25:" // Height 2: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "uzp1 z1.h, z1.h, z1.h\n" + "trn1 z0.d, z0.d, z1.d\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + 
"ld1h { z6.h }, p7/Z, [x28]\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + "addvl x11, x11, #2\n" + "addvl x10, x10, #2\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + "addvl x26, x26, #2\n" + "26:" // Height 2: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 21b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "uzp1 z4.d, z8.d, z14.d\n" + "uzp2 z8.d, z8.d, z14.d\n" + "add x22, x12, x19, LSL #2\n" + "uzp1 z14.d, z9.d, z15.d\n" + "uzp2 z9.d, z9.d, z15.d\n" + "uzp1 z15.d, z10.d, z16.d\n" + "uzp2 z10.d, z10.d, z16.d\n" + "uzp1 z16.d, z11.d, z17.d\n" + "uzp2 z11.d, z11.d, z17.d\n" + "uzp1 z17.d, z12.d, z18.d\n" + "uzp2 z12.d, z12.d, z18.d\n" + "uzp1 z18.d, z13.d, z19.d\n" + "uzp2 z13.d, z13.d, z19.d\n" + "tbz %x[flags], #1, 27f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p7/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p7/Z, [x19]\n" + "fmin z4.s, p7/M, z4.s, z1.s\n" + "fmin z14.s, p7/M, z14.s, z1.s\n" + "fmin z15.s, p7/M, z15.s, z1.s\n" + "fmin z16.s, p7/M, z16.s, z1.s\n" + "fmin z17.s, p7/M, z17.s, z1.s\n" + "fmin z18.s, p7/M, z18.s, z1.s\n" + "fmin z8.s, p7/M, z8.s, z1.s\n" + "fmin z9.s, p7/M, z9.s, z1.s\n" + "fmin z10.s, p7/M, z10.s, z1.s\n" + "fmin z11.s, p7/M, z11.s, z1.s\n" + "fmin z12.s, p7/M, z12.s, z1.s\n" + "fmin z13.s, p7/M, z13.s, z1.s\n" + "fmax z4.s, p7/M, z4.s, z0.s\n" + "fmax z14.s, p7/M, z14.s, z0.s\n" + "fmax z15.s, p7/M, z15.s, z0.s\n" + "fmax z16.s, p7/M, z16.s, z0.s\n" + "fmax z17.s, p7/M, z17.s, z0.s\n" + "fmax z18.s, p7/M, z18.s, z0.s\n" + "fmax z8.s, p7/M, z8.s, z0.s\n" + "fmax z9.s, p7/M, z9.s, z0.s\n" + "fmax z10.s, p7/M, z10.s, z0.s\n" + "fmax z11.s, p7/M, z11.s, z0.s\n" + "fmax z12.s, p7/M, z12.s, z0.s\n" + "fmax z13.s, p7/M, z13.s, z0.s\n" + "27:" // Height 2: No activation + "st1w { z4.s }, p6, [x12]\n" + "st1w { z14.s }, p5, [x12, #1, MUL VL]\n" + "st1w { z15.s }, p4, [x12, #2, MUL VL]\n" + "st1w { z16.s }, p3, [x12, #3, MUL VL]\n" + "st1w { z17.s }, p2, [x12, #4, MUL VL]\n" + "st1w { z18.s }, p1, [x12, #5, MUL VL]\n" + "addvl x12, x12, #6\n" + "st1w { z8.s }, p6, [x22]\n" + "st1w { z9.s }, p5, [x22, #1, MUL VL]\n" + "st1w { z10.s }, p4, [x22, #2, MUL VL]\n" + "st1w { z11.s }, p3, [x22, #3, MUL VL]\n" + "st1w { z12.s }, p2, [x22, #4, MUL VL]\n" + "st1w { z13.s }, p1, [x22, #5, MUL VL]\n" + "28:" // Height 2: Writeback done + "decw x13, ALL, MUL #6\n" + "cmp x13, XZR\n" + "bgt 16b\n" + "b 58f\n" + "29:" // Height 3 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "mov x14, %x[bias]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x12, %x[output_ptr]\n" + "30:" // Height 3: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, 
LSL #1\n" + "cntw x20, ALL, MUL #5\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x26, x11\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x27, x11\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 31f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 31f\n" + "mov x10, x11\n" + "31:" // Height 3: B setup done + "mov x19, #0x0\n" + "whilelt p6.s, x19, x13\n" + "incw x19\n" + "whilelt p5.s, x19, x13\n" + "incw x19\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 32f\n" + "ld1w { z8.s }, p7/Z, [x14]\n" + "ld1w { z9.s }, p7/Z, [x14, #1, MUL VL]\n" + "zip2 z14.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p7/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p7/Z, [x14, #3, MUL VL]\n" + "zip2 z15.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "ld1w { z12.s }, p7/Z, [x14, #4, MUL VL]\n" + "ld1w { z13.s }, p7/Z, [x14, #5, MUL VL]\n" + "zip2 z16.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "zip2 z17.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "addvl x14, x14, #6\n" + "zip2 z18.d, z12.d, z12.d\n" + "zip1 z12.d, z12.d, z12.d\n" + "zip2 z19.d, z13.d, z13.d\n" + "zip1 z13.d, z13.d, z13.d\n" + "mov z20.d, z8.d\n" + "mov z26.d, z14.d\n" + "mov z21.d, z9.d\n" + "mov z27.d, z15.d\n" + "mov z22.d, z10.d\n" + "mov z28.d, z16.d\n" + "mov z23.d, z11.d\n" + "mov z29.d, z17.d\n" + "mov z24.d, z12.d\n" + "mov z30.d, z18.d\n" + "mov z25.d, z13.d\n" + "mov z31.d, z19.d\n" + "b 34f\n" + "32:" // Height 3: no bias + "tbz %x[flags], #0, 33f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p6/Z, [x12]\n" + "ld1w { z10.s }, p5/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x12, #2, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x12, #3, MUL VL]\n" + "ld1w { z13.s }, p2/Z, [x12, #4, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x12, #5, MUL VL]\n" + "ld1w { z14.s }, p6/Z, [x22]\n" + "zip1 z8.d, z9.d, z14.d\n" + "zip2 z14.d, z9.d, z14.d\n" + "ld1w { z15.s }, p5/Z, [x22, #1, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x22, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z15.d\n" + "zip2 z15.d, z10.d, z15.d\n" + "ld1w { z17.s }, p3/Z, [x22, #3, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x22, #4, MUL VL]\n" + "zip1 z10.d, z11.d, z16.d\n" + "zip2 z16.d, z11.d, z16.d\n" + "ld1w { z19.s }, p1/Z, [x22, #5, MUL VL]\n" + "ld1w { z21.s }, p6/Z, [x21]\n" + "zip1 z11.d, z12.d, z17.d\n" + "zip2 z17.d, z12.d, z17.d\n" + "ld1w { z22.s }, p5/Z, [x21, #1, MUL VL]\n" + "ld1w { z23.s }, p4/Z, [x21, #2, MUL VL]\n" + "zip1 z12.d, z13.d, z18.d\n" + "zip2 z18.d, z13.d, z18.d\n" + "ld1w { z24.s }, p3/Z, [x21, #3, MUL VL]\n" + "ld1w { z25.s }, p2/Z, [x21, #4, MUL VL]\n" + "zip1 z13.d, z20.d, z19.d\n" + "zip2 z19.d, z20.d, z19.d\n" + "ld1w { z4.s }, p1/Z, [x21, #5, MUL VL]\n" + "zip1 z20.d, z21.d, z26.d\n" + "zip2 z26.d, z21.d, z26.d\n" + "zip1 z21.d, z22.d, z27.d\n" + "zip2 z27.d, z22.d, z27.d\n" + "zip1 z22.d, z23.d, z28.d\n" + "zip2 z28.d, z23.d, z28.d\n" + "zip1 z23.d, z24.d, z29.d\n" + "zip2 z29.d, z24.d, z29.d\n" + "zip1 z24.d, z25.d, z30.d\n" + "zip2 z30.d, z25.d, z30.d\n" + "zip1 z25.d, z4.d, z31.d\n" + "zip2 z31.d, z4.d, z31.d\n" + "b 34f\n" + "33:" // Height 3: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, 
#0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "34:" // Height 3: setup done + "mov x25, #0x0\n" + "35:" // Height 3: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 36f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "ldr x21, [x20, #0x10]\n" + "cbnz x25, 37f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "b 37f\n" + "36:" // Height 3: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "37:" // Height 3: input setup done + "cmp x24, #0x4\n" + "ble 39f\n" + "38:" // Height 3: Multiply loop: Main loop head + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "ld1rqw { z2.s }, p0/Z, [x21]\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "uzp1 z1.h, z1.h, z1.h\n" + ".inst 0x658abc42 // bfcvt z2.h, p7/M, z2.s\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "trn1 z0.d, z0.d, z1.d\n" + "uzp1 z2.h, z2.h, z2.h\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "sub x24, x24, #0x4\n" + ".inst 0x6465e45a // bfmmla z26.s, z2.h, z5.h\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + "cmp x24, #0x4\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "add x23, x23, #0x10\n" + ".inst 0x6467e45b // bfmmla z27.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + "add x22, x22, #0x10\n" + ".inst 0x6464e456 // bfmmla z22.s, z2.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "add x21, x21, #0x10\n" + ".inst 0x6465e45c // bfmmla z28.s, z2.h, z5.h\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + "addvl x11, x11, #2\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "addvl x10, x10, #2\n" + ".inst 0x6467e45d // bfmmla z29.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + "addvl x9, x9, #2\n" + ".inst 0x6464e458 // bfmmla z24.s, z2.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + "addvl x28, x28, #2\n" + "addvl x27, x27, #2\n" + ".inst 0x6465e45e // bfmmla z30.s, z2.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "addvl x26, x26, 
#2\n" + ".inst 0x6466e459 // bfmmla z25.s, z2.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + ".inst 0x6467e45f // bfmmla z31.s, z2.h, z7.h\n" + "bgt 38b\n" + "39:" // Height 3: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "ld1rqw { z2.s }, p0/Z, [x21]\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "uzp1 z1.h, z1.h, z1.h\n" + ".inst 0x658abc42 // bfcvt z2.h, p7/M, z2.s\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "trn1 z0.d, z0.d, z1.d\n" + "uzp1 z2.h, z2.h, z2.h\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + "addvl x11, x11, #2\n" + ".inst 0x6465e45a // bfmmla z26.s, z2.h, z5.h\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e45b // bfmmla z27.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + "addvl x28, x28, #2\n" + ".inst 0x6464e456 // bfmmla z22.s, z2.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + ".inst 0x6465e45c // bfmmla z28.s, z2.h, z5.h\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + "addvl x27, x27, #2\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + ".inst 0x6467e45d // bfmmla z29.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + "addvl x26, x26, #2\n" + ".inst 0x6464e458 // bfmmla z24.s, z2.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + ".inst 0x6465e45e // bfmmla z30.s, z2.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e459 // bfmmla z25.s, z2.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + ".inst 0x6467e45f // bfmmla z31.s, z2.h, z7.h\n" + "40:" // Height 3: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 35b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "uzp1 z4.d, z8.d, z14.d\n" + "uzp2 z8.d, z8.d, z14.d\n" + "uzp1 z14.d, z9.d, z15.d\n" + "uzp2 z9.d, z9.d, z15.d\n" + "add x21, x22, x19, LSL #2\n" + "uzp1 z15.d, z10.d, z16.d\n" + "uzp2 z10.d, z10.d, z16.d\n" + "uzp1 z16.d, z11.d, z17.d\n" + "uzp2 z11.d, z11.d, z17.d\n" + "uzp1 z17.d, z12.d, z18.d\n" + "uzp2 z12.d, z12.d, z18.d\n" + "uzp1 z18.d, z13.d, z19.d\n" + "uzp2 z13.d, z13.d, z19.d\n" + "uzp1 z20.d, z20.d, z26.d\n" + "uzp1 z21.d, z21.d, z27.d\n" + "uzp1 z22.d, z22.d, z28.d\n" + "uzp1 z23.d, z23.d, z29.d\n" + "uzp1 z24.d, z24.d, z30.d\n" + "uzp1 z25.d, z25.d, z31.d\n" + "tbz %x[flags], #1, 41f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p7/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p7/Z, [x19]\n" + "fmin z4.s, p7/M, z4.s, z1.s\n" + "fmin z14.s, p7/M, z14.s, z1.s\n" + 
"fmin z15.s, p7/M, z15.s, z1.s\n" + "fmin z16.s, p7/M, z16.s, z1.s\n" + "fmin z17.s, p7/M, z17.s, z1.s\n" + "fmin z18.s, p7/M, z18.s, z1.s\n" + "fmin z8.s, p7/M, z8.s, z1.s\n" + "fmin z9.s, p7/M, z9.s, z1.s\n" + "fmin z10.s, p7/M, z10.s, z1.s\n" + "fmin z11.s, p7/M, z11.s, z1.s\n" + "fmin z12.s, p7/M, z12.s, z1.s\n" + "fmin z13.s, p7/M, z13.s, z1.s\n" + "fmin z20.s, p7/M, z20.s, z1.s\n" + "fmin z21.s, p7/M, z21.s, z1.s\n" + "fmin z22.s, p7/M, z22.s, z1.s\n" + "fmin z23.s, p7/M, z23.s, z1.s\n" + "fmin z24.s, p7/M, z24.s, z1.s\n" + "fmin z25.s, p7/M, z25.s, z1.s\n" + "fmax z4.s, p7/M, z4.s, z0.s\n" + "fmax z14.s, p7/M, z14.s, z0.s\n" + "fmax z15.s, p7/M, z15.s, z0.s\n" + "fmax z16.s, p7/M, z16.s, z0.s\n" + "fmax z17.s, p7/M, z17.s, z0.s\n" + "fmax z18.s, p7/M, z18.s, z0.s\n" + "fmax z8.s, p7/M, z8.s, z0.s\n" + "fmax z9.s, p7/M, z9.s, z0.s\n" + "fmax z10.s, p7/M, z10.s, z0.s\n" + "fmax z11.s, p7/M, z11.s, z0.s\n" + "fmax z12.s, p7/M, z12.s, z0.s\n" + "fmax z13.s, p7/M, z13.s, z0.s\n" + "fmax z20.s, p7/M, z20.s, z0.s\n" + "fmax z21.s, p7/M, z21.s, z0.s\n" + "fmax z22.s, p7/M, z22.s, z0.s\n" + "fmax z23.s, p7/M, z23.s, z0.s\n" + "fmax z24.s, p7/M, z24.s, z0.s\n" + "fmax z25.s, p7/M, z25.s, z0.s\n" + "41:" // Height 3: No activation + "st1w { z4.s }, p6, [x12]\n" + "st1w { z14.s }, p5, [x12, #1, MUL VL]\n" + "st1w { z15.s }, p4, [x12, #2, MUL VL]\n" + "st1w { z16.s }, p3, [x12, #3, MUL VL]\n" + "st1w { z17.s }, p2, [x12, #4, MUL VL]\n" + "st1w { z18.s }, p1, [x12, #5, MUL VL]\n" + "addvl x12, x12, #6\n" + "st1w { z8.s }, p6, [x22]\n" + "st1w { z9.s }, p5, [x22, #1, MUL VL]\n" + "st1w { z10.s }, p4, [x22, #2, MUL VL]\n" + "st1w { z11.s }, p3, [x22, #3, MUL VL]\n" + "st1w { z12.s }, p2, [x22, #4, MUL VL]\n" + "st1w { z13.s }, p1, [x22, #5, MUL VL]\n" + "st1w { z20.s }, p6, [x21]\n" + "st1w { z21.s }, p5, [x21, #1, MUL VL]\n" + "st1w { z22.s }, p4, [x21, #2, MUL VL]\n" + "st1w { z23.s }, p3, [x21, #3, MUL VL]\n" + "st1w { z24.s }, p2, [x21, #4, MUL VL]\n" + "st1w { z25.s }, p1, [x21, #5, MUL VL]\n" + "42:" // Height 3: Writeback done + "decw x13, ALL, MUL #6\n" + "cmp x13, XZR\n" + "bgt 30b\n" + "b 58f\n" + "43:" // Height 4 + "ldr x19, [%x[args_ptr], %[offsetof_B_ptr]]\n" + "ldr x13, [%x[args_ptr], %[offsetof_N]]\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x20, #0x10\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "mov x14, %x[bias]\n" + "mov x12, %x[output_ptr]\n" + "madd %x[output_ptr], x19, x20, %x[output_ptr]\n" + "44:" // Height 4: Column loop + "ldr x11, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "add x10, x11, x19, LSL #1\n" + "add x9, x10, x19, LSL #1\n" + "add x28, x9, x19, LSL #1\n" + "cntw x20, ALL, MUL #5\n" + "add x27, x28, x19, LSL #1\n" + "add x26, x27, x19, LSL #1\n" + "add x19, x26, x19, LSL #1\n" + "cmp x13, x20\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x26, x11\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x27, x11\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x28, x11\n" + "bgt 45f\n" + "decw x20\n" + "cmp x13, x20\n" + "mov x9, x11\n" + "bgt 45f\n" + "mov x10, x11\n" + "45:" // Height 4: B setup done + "mov x19, #0x0\n" + "whilelt p6.s, x19, x13\n" + "incw x19\n" + "whilelt p5.s, x19, x13\n" + "incw x19\n" + "whilelt p4.s, x19, x13\n" + "incw x19\n" + "whilelt p3.s, x19, x13\n" + "incw x19\n" + "whilelt p2.s, x19, x13\n" + "incw x19\n" + "whilelt p1.s, x19, x13\n" + "cbz x14, 46f\n" + "ld1w { z8.s }, 
p7/Z, [x14]\n" + "ld1w { z9.s }, p7/Z, [x14, #1, MUL VL]\n" + "zip2 z14.d, z8.d, z8.d\n" + "zip1 z8.d, z8.d, z8.d\n" + "ld1w { z10.s }, p7/Z, [x14, #2, MUL VL]\n" + "ld1w { z11.s }, p7/Z, [x14, #3, MUL VL]\n" + "zip2 z15.d, z9.d, z9.d\n" + "zip1 z9.d, z9.d, z9.d\n" + "ld1w { z12.s }, p7/Z, [x14, #4, MUL VL]\n" + "ld1w { z13.s }, p7/Z, [x14, #5, MUL VL]\n" + "zip2 z16.d, z10.d, z10.d\n" + "zip1 z10.d, z10.d, z10.d\n" + "zip2 z17.d, z11.d, z11.d\n" + "zip1 z11.d, z11.d, z11.d\n" + "addvl x14, x14, #6\n" + "zip2 z18.d, z12.d, z12.d\n" + "zip1 z12.d, z12.d, z12.d\n" + "zip2 z19.d, z13.d, z13.d\n" + "zip1 z13.d, z13.d, z13.d\n" + "mov z20.d, z8.d\n" + "mov z26.d, z14.d\n" + "mov z21.d, z9.d\n" + "mov z27.d, z15.d\n" + "mov z22.d, z10.d\n" + "mov z28.d, z16.d\n" + "mov z23.d, z11.d\n" + "mov z29.d, z17.d\n" + "mov z24.d, z12.d\n" + "mov z30.d, z18.d\n" + "mov z25.d, z13.d\n" + "mov z31.d, z19.d\n" + "b 48f\n" + "46:" // Height 4: no bias + "tbz %x[flags], #0, 47f\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "ld1w { z9.s }, p6/Z, [x12]\n" + "add x20, x21, x19, LSL #2\n" + "ld1w { z10.s }, p5/Z, [x12, #1, MUL VL]\n" + "ld1w { z11.s }, p4/Z, [x12, #2, MUL VL]\n" + "ld1w { z12.s }, p3/Z, [x12, #3, MUL VL]\n" + "ld1w { z13.s }, p2/Z, [x12, #4, MUL VL]\n" + "ld1w { z20.s }, p1/Z, [x12, #5, MUL VL]\n" + "ld1w { z14.s }, p6/Z, [x22]\n" + "zip1 z8.d, z9.d, z14.d\n" + "zip2 z14.d, z9.d, z14.d\n" + "ld1w { z15.s }, p5/Z, [x22, #1, MUL VL]\n" + "ld1w { z16.s }, p4/Z, [x22, #2, MUL VL]\n" + "zip1 z9.d, z10.d, z15.d\n" + "zip2 z15.d, z10.d, z15.d\n" + "ld1w { z17.s }, p3/Z, [x22, #3, MUL VL]\n" + "ld1w { z18.s }, p2/Z, [x22, #4, MUL VL]\n" + "zip1 z10.d, z11.d, z16.d\n" + "zip2 z16.d, z11.d, z16.d\n" + "ld1w { z19.s }, p1/Z, [x22, #5, MUL VL]\n" + "ld1w { z21.s }, p6/Z, [x21]\n" + "zip1 z11.d, z12.d, z17.d\n" + "zip2 z17.d, z12.d, z17.d\n" + "ld1w { z22.s }, p5/Z, [x21, #1, MUL VL]\n" + "ld1w { z23.s }, p4/Z, [x21, #2, MUL VL]\n" + "zip1 z12.d, z13.d, z18.d\n" + "zip2 z18.d, z13.d, z18.d\n" + "ld1w { z24.s }, p3/Z, [x21, #3, MUL VL]\n" + "ld1w { z25.s }, p2/Z, [x21, #4, MUL VL]\n" + "zip1 z13.d, z20.d, z19.d\n" + "zip2 z19.d, z20.d, z19.d\n" + "ld1w { z4.s }, p1/Z, [x21, #5, MUL VL]\n" + "ld1w { z26.s }, p6/Z, [x20]\n" + "zip1 z20.d, z21.d, z26.d\n" + "zip2 z26.d, z21.d, z26.d\n" + "ld1w { z27.s }, p5/Z, [x20, #1, MUL VL]\n" + "ld1w { z28.s }, p4/Z, [x20, #2, MUL VL]\n" + "zip1 z21.d, z22.d, z27.d\n" + "zip2 z27.d, z22.d, z27.d\n" + "ld1w { z29.s }, p3/Z, [x20, #3, MUL VL]\n" + "ld1w { z30.s }, p2/Z, [x20, #4, MUL VL]\n" + "zip1 z22.d, z23.d, z28.d\n" + "zip2 z28.d, z23.d, z28.d\n" + "ld1w { z31.s }, p1/Z, [x20, #5, MUL VL]\n" + "zip1 z23.d, z24.d, z29.d\n" + "zip2 z29.d, z24.d, z29.d\n" + "zip1 z24.d, z25.d, z30.d\n" + "zip2 z30.d, z25.d, z30.d\n" + "zip1 z25.d, z4.d, z31.d\n" + "zip2 z31.d, z4.d, z31.d\n" + "b 48f\n" + "47:" // Height 4: no accumulate + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "48:" // Height 4: setup done + "mov x25, #0x0\n" + "49:" 
// Height 4: String loop + "ldr x19, [%x[args_ptr], %[offsetof_string_lengths]]\n" + "ldr w24, [x19, x25, LSL #0x2]\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 50f\n" + "ldr x20, [%x[input_ptr], x25, LSL #0x3]\n" + "add x20, x20, x19, LSL #3\n" + "ldr x23, [x20, #0x0]\n" + "ldr x22, [x20, #0x8]\n" + "ldr x21, [x20, #0x10]\n" + "ldr x20, [x20, #0x18]\n" + "cbnz x25, 51f\n" + "ldr x19, [%x[args_ptr], %[offsetof_input_initial_col]]\n" + "add x23, x23, x19, LSL #2\n" + "add x22, x22, x19, LSL #2\n" + "add x21, x21, x19, LSL #2\n" + "add x20, x20, x19, LSL #2\n" + "b 51f\n" + "50:" // Height 4: setup direct input + "mov x23, %x[input_ptr]\n" + "add x22, x23, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "51:" // Height 4: input setup done + "cmp x24, #0x4\n" + "ble 53f\n" + "52:" // Height 4: Multiply loop: Main loop head + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" + "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "ld1rqw { z2.s }, p0/Z, [x21]\n" + "ld1rqw { z3.s }, p0/Z, [x20]\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + ".inst 0x658abc42 // bfcvt z2.h, p7/M, z2.s\n" + ".inst 0x658abc63 // bfcvt z3.h, p7/M, z3.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "uzp1 z1.h, z1.h, z1.h\n" + "uzp1 z2.h, z2.h, z2.h\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + "uzp1 z3.h, z3.h, z3.h\n" + "trn1 z0.d, z0.d, z1.d\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + "sub x24, x24, #0x4\n" + "trn1 z2.d, z2.d, z3.d\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + ".inst 0x6465e45a // bfmmla z26.s, z2.h, z5.h\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + "cmp x24, #0x4\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "add x23, x23, #0x10\n" + ".inst 0x6467e45b // bfmmla z27.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + "add x22, x22, #0x10\n" + ".inst 0x6464e456 // bfmmla z22.s, z2.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + "add x21, x21, #0x10\n" + ".inst 0x6465e45c // bfmmla z28.s, z2.h, z5.h\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + "add x20, x20, #0x10\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + "addvl x11, x11, #2\n" + ".inst 0x6467e45d // bfmmla z29.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + "addvl x10, x10, #2\n" + ".inst 0x6464e458 // bfmmla z24.s, z2.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + "addvl x9, x9, #2\n" + "addvl x28, x28, #2\n" + ".inst 0x6465e45e // bfmmla z30.s, z2.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + "addvl x27, x27, #2\n" + "addvl x26, x26, #2\n" + ".inst 0x6466e459 // bfmmla z25.s, z2.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + ".inst 0x6467e45f // bfmmla z31.s, z2.h, z7.h\n" + "bgt 52b\n" + "53:" // Height 4: Multiply loop: Single iteration only + "whilelt p0.s, XZR, x24\n" + "ld1rqw { z0.s }, p0/Z, [x23]\n" 
+ "ld1rqw { z1.s }, p0/Z, [x22]\n" + ".inst 0x658abc00 // bfcvt z0.h, p7/M, z0.s\n" + "ld1rqw { z2.s }, p0/Z, [x21]\n" + "ld1rqw { z3.s }, p0/Z, [x20]\n" + ".inst 0x658abc21 // bfcvt z1.h, p7/M, z1.s\n" + ".inst 0x658abc42 // bfcvt z2.h, p7/M, z2.s\n" + ".inst 0x658abc63 // bfcvt z3.h, p7/M, z3.s\n" + "uzp1 z0.h, z0.h, z0.h\n" + "ld1h { z4.h }, p7/Z, [x11]\n" + "ld1h { z5.h }, p7/Z, [x11, #1, MUL VL]\n" + "uzp1 z1.h, z1.h, z1.h\n" + "uzp1 z2.h, z2.h, z2.h\n" + "ld1h { z6.h }, p7/Z, [x10]\n" + "ld1h { z7.h }, p7/Z, [x10, #1, MUL VL]\n" + "uzp1 z3.h, z3.h, z3.h\n" + "trn1 z0.d, z0.d, z1.d\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + "addvl x11, x11, #2\n" + "trn1 z2.d, z2.d, z3.d\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e40e // bfmmla z14.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x9]\n" + ".inst 0x6465e45a // bfmmla z26.s, z2.h, z5.h\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x9, #1, MUL VL]\n" + "addvl x10, x10, #2\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e40f // bfmmla z15.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x28]\n" + "addvl x9, x9, #2\n" + ".inst 0x6467e45b // bfmmla z27.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x28, #1, MUL VL]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + "addvl x28, x28, #2\n" + ".inst 0x6464e456 // bfmmla z22.s, z2.h, z4.h\n" + ".inst 0x6465e410 // bfmmla z16.s, z0.h, z5.h\n" + "ld1h { z4.h }, p7/Z, [x27]\n" + ".inst 0x6465e45c // bfmmla z28.s, z2.h, z5.h\n" + ".inst 0x6466e40b // bfmmla z11.s, z0.h, z6.h\n" + "ld1h { z5.h }, p7/Z, [x27, #1, MUL VL]\n" + "addvl x27, x27, #2\n" + ".inst 0x6466e457 // bfmmla z23.s, z2.h, z6.h\n" + ".inst 0x6467e411 // bfmmla z17.s, z0.h, z7.h\n" + "ld1h { z6.h }, p7/Z, [x26]\n" + ".inst 0x6467e45d // bfmmla z29.s, z2.h, z7.h\n" + "ld1h { z7.h }, p7/Z, [x26, #1, MUL VL]\n" + ".inst 0x6464e40c // bfmmla z12.s, z0.h, z4.h\n" + "addvl x26, x26, #2\n" + ".inst 0x6464e458 // bfmmla z24.s, z2.h, z4.h\n" + ".inst 0x6465e412 // bfmmla z18.s, z0.h, z5.h\n" + ".inst 0x6465e45e // bfmmla z30.s, z2.h, z5.h\n" + ".inst 0x6466e40d // bfmmla z13.s, z0.h, z6.h\n" + ".inst 0x6466e459 // bfmmla z25.s, z2.h, z6.h\n" + ".inst 0x6467e413 // bfmmla z19.s, z0.h, z7.h\n" + ".inst 0x6467e45f // bfmmla z31.s, z2.h, z7.h\n" + "54:" // Height 4: Multiply loop: multiply skip + "ldr w19, [%x[args_ptr], %[offsetof_num_strings]]\n" + "add x25, x25, #0x1\n" + "cmp x25, x19\n" + "bne 49b\n" + "ldr x19, [%x[args_ptr], %[offsetof_output_offset]]\n" + "add x22, x12, x19, LSL #2\n" + "add x21, x22, x19, LSL #2\n" + "uzp1 z4.d, z8.d, z14.d\n" + "uzp2 z8.d, z8.d, z14.d\n" + "uzp1 z14.d, z9.d, z15.d\n" + "add x20, x21, x19, LSL #2\n" + "uzp2 z9.d, z9.d, z15.d\n" + "uzp1 z15.d, z10.d, z16.d\n" + "uzp2 z10.d, z10.d, z16.d\n" + "uzp1 z16.d, z11.d, z17.d\n" + "uzp2 z11.d, z11.d, z17.d\n" + "uzp1 z17.d, z12.d, z18.d\n" + "uzp2 z12.d, z12.d, z18.d\n" + "uzp1 z18.d, z13.d, z19.d\n" + "uzp2 z13.d, z13.d, z19.d\n" + "uzp1 z19.d, z20.d, z26.d\n" + "uzp2 z20.d, z20.d, z26.d\n" + "uzp1 z26.d, z21.d, z27.d\n" + "uzp2 z21.d, z21.d, z27.d\n" + "uzp1 z27.d, z22.d, z28.d\n" + "uzp2 z22.d, z22.d, z28.d\n" + "uzp1 z28.d, z23.d, z29.d\n" + "uzp2 z23.d, z23.d, z29.d\n" + "uzp1 z29.d, z24.d, z30.d\n" + "uzp2 z24.d, z24.d, z30.d\n" + "uzp1 z30.d, z25.d, z31.d\n" + "uzp2 z25.d, z25.d, z31.d\n" + "tbz %x[flags], #1, 55f\n" + "add x19, %x[args_ptr], %[offset_max]\n" + "ld1rw { z1.s }, p7/Z, [x19]\n" + "add x19, %x[args_ptr], %[offset_min]\n" + "ld1rw { z0.s }, p7/Z, [x19]\n" + "fmin z4.s, 
p7/M, z4.s, z1.s\n" + "fmin z14.s, p7/M, z14.s, z1.s\n" + "fmin z15.s, p7/M, z15.s, z1.s\n" + "fmin z16.s, p7/M, z16.s, z1.s\n" + "fmin z17.s, p7/M, z17.s, z1.s\n" + "fmin z18.s, p7/M, z18.s, z1.s\n" + "fmin z8.s, p7/M, z8.s, z1.s\n" + "fmin z9.s, p7/M, z9.s, z1.s\n" + "fmin z10.s, p7/M, z10.s, z1.s\n" + "fmin z11.s, p7/M, z11.s, z1.s\n" + "fmin z12.s, p7/M, z12.s, z1.s\n" + "fmin z13.s, p7/M, z13.s, z1.s\n" + "fmin z19.s, p7/M, z19.s, z1.s\n" + "fmin z26.s, p7/M, z26.s, z1.s\n" + "fmin z27.s, p7/M, z27.s, z1.s\n" + "fmin z28.s, p7/M, z28.s, z1.s\n" + "fmin z29.s, p7/M, z29.s, z1.s\n" + "fmin z30.s, p7/M, z30.s, z1.s\n" + "fmin z20.s, p7/M, z20.s, z1.s\n" + "fmin z21.s, p7/M, z21.s, z1.s\n" + "fmin z22.s, p7/M, z22.s, z1.s\n" + "fmin z23.s, p7/M, z23.s, z1.s\n" + "fmin z24.s, p7/M, z24.s, z1.s\n" + "fmin z25.s, p7/M, z25.s, z1.s\n" + "fmax z4.s, p7/M, z4.s, z0.s\n" + "fmax z14.s, p7/M, z14.s, z0.s\n" + "fmax z15.s, p7/M, z15.s, z0.s\n" + "fmax z16.s, p7/M, z16.s, z0.s\n" + "fmax z17.s, p7/M, z17.s, z0.s\n" + "fmax z18.s, p7/M, z18.s, z0.s\n" + "fmax z8.s, p7/M, z8.s, z0.s\n" + "fmax z9.s, p7/M, z9.s, z0.s\n" + "fmax z10.s, p7/M, z10.s, z0.s\n" + "fmax z11.s, p7/M, z11.s, z0.s\n" + "fmax z12.s, p7/M, z12.s, z0.s\n" + "fmax z13.s, p7/M, z13.s, z0.s\n" + "fmax z19.s, p7/M, z19.s, z0.s\n" + "fmax z26.s, p7/M, z26.s, z0.s\n" + "fmax z27.s, p7/M, z27.s, z0.s\n" + "fmax z28.s, p7/M, z28.s, z0.s\n" + "fmax z29.s, p7/M, z29.s, z0.s\n" + "fmax z30.s, p7/M, z30.s, z0.s\n" + "fmax z20.s, p7/M, z20.s, z0.s\n" + "fmax z21.s, p7/M, z21.s, z0.s\n" + "fmax z22.s, p7/M, z22.s, z0.s\n" + "fmax z23.s, p7/M, z23.s, z0.s\n" + "fmax z24.s, p7/M, z24.s, z0.s\n" + "fmax z25.s, p7/M, z25.s, z0.s\n" + "55:" // Height 4: No activation + "st1w { z4.s }, p6, [x12]\n" + "st1w { z14.s }, p5, [x12, #1, MUL VL]\n" + "st1w { z15.s }, p4, [x12, #2, MUL VL]\n" + "st1w { z16.s }, p3, [x12, #3, MUL VL]\n" + "st1w { z17.s }, p2, [x12, #4, MUL VL]\n" + "st1w { z18.s }, p1, [x12, #5, MUL VL]\n" + "addvl x12, x12, #6\n" + "st1w { z8.s }, p6, [x22]\n" + "st1w { z9.s }, p5, [x22, #1, MUL VL]\n" + "st1w { z10.s }, p4, [x22, #2, MUL VL]\n" + "st1w { z11.s }, p3, [x22, #3, MUL VL]\n" + "st1w { z12.s }, p2, [x22, #4, MUL VL]\n" + "st1w { z13.s }, p1, [x22, #5, MUL VL]\n" + "st1w { z19.s }, p6, [x21]\n" + "st1w { z26.s }, p5, [x21, #1, MUL VL]\n" + "st1w { z27.s }, p4, [x21, #2, MUL VL]\n" + "st1w { z28.s }, p3, [x21, #3, MUL VL]\n" + "st1w { z29.s }, p2, [x21, #4, MUL VL]\n" + "st1w { z30.s }, p1, [x21, #5, MUL VL]\n" + "st1w { z20.s }, p6, [x20]\n" + "st1w { z21.s }, p5, [x20, #1, MUL VL]\n" + "st1w { z22.s }, p4, [x20, #2, MUL VL]\n" + "st1w { z23.s }, p3, [x20, #3, MUL VL]\n" + "st1w { z24.s }, p2, [x20, #4, MUL VL]\n" + "st1w { z25.s }, p1, [x20, #5, MUL VL]\n" + "56:" // Height 4: Writeback done + "decw x13, ALL, MUL #6\n" + "cmp x13, XZR\n" + "bgt 44b\n" + "subs %x[M], %x[M], #0x4\n" + "beq 58f\n" + "ldr x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "tbz %x[flags], #3, 57f\n" + "add x20, x20, #0x4\n" + "str x20, [%x[args_ptr], %[offsetof_input_offset]]\n" + "b 1b\n" + "57:" // Update direct input + "mov x19, #0x10\n" + "madd %x[input_ptr], x19, x20, %x[input_ptr]\n" + "b 1b\n" + "58:" // Exit + : [M] "+&r" (M), [input_ptr] "+&r" (input_ptr), [output_ptr] "+&r" (output_ptr) + : [args_ptr] "r" (&ka), [bias] "r" (bias), [flags] "r" (flags), [offset_max] "I" (offsetof(KernelArgs, maxval)), [offset_min] "I" (offsetof(KernelArgs, minval)), [offsetof_B_ptr] "I" (offsetof(KernelArgs, B_ptr)), [offsetof_B_stride] "I" 
(offsetof(KernelArgs, B_stride)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)), [offsetof_input_initial_col] "I" (offsetof(KernelArgs, input_initial_col)), [offsetof_input_offset] "I" (offsetof(KernelArgs, input_offset)), [offsetof_num_strings] "I" (offsetof(KernelArgs, num_strings)), [offsetof_output_offset] "I" (offsetof(KernelArgs, output_offset)), [offsetof_string_lengths] "I" (offsetof(KernelArgs, string_lengths)) + : "cc", "memory", "p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7", "x9", "x10", "x11", "x12", "x13", "x14", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp new file mode 100644 index 0000000000..5792a7152d --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../bfloat.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const bfloat16 *, const bfloat16 *, size_t, \ + float *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffinterleaved_bf16fp32_mmla_8x3VL( ARGLIST ); + +class cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL +{ +public: + typedef bfloat16 operand_type; + typedef float result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return get_vector_length<float>() * 3; + } + static unsigned int stripe_width() + { + return get_vector_length<float>(); + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL2VL_BL64; + } + + static constexpr unsigned int k_unroll() + { + return 4; + } + + + StdTransformsSVE transforms = {}; + StdTransformsSVE transforms_quantized = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same<T, bfloat16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 39.90, 8.55, 4.42 }; + } + } + + + if (std::is_same<T, float>::value) { + switch (ci->get_cpu_model()) { + default: + return { 39.66, 5.18, 4.37 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffinterleaved_bf16fp32_mmla_8x3VL; + cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL(const CPUInfo *) + { + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp new file mode 100644 index 0000000000..1f1e08d3dd --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL/generic.cpp @@ -0,0 +1,319 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
+ */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include +#include "../../bfloat.hpp" + +namespace arm_gemm { + +void sve_ffinterleaved_bf16fp32_mmla_8x3VL( + const bfloat16 *Apanel, + const bfloat16 *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const bfloat16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const bfloat16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/4) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "ptrue p0.b\n" + "1:" // Height loop + "ldr x25, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_N]]\n" + "str x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x23, %x[Apanel]\n" + "2:" // Width loop + "ldr x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cntw x22, ALL, MUL #2\n" + "add x21, x25, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "cmp x24, x22\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov %x[Apanel], x23\n" + "bgt 3f\n" + "decw x22\n" + "cmp x24, x22\n" + "mov x20, x25\n" + "bgt 3f\n" + "mov x21, x25\n" + "3:" // B setup done + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "ld1h { z4.h }, p0/Z, [x25]\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel]]\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "ld1rqh { z1.h }, p0/Z, [%x[Apanel], #16]\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "ld1h { z5.h }, p0/Z, [x25, #1, MUL VL]\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "ld1rqh { z2.h }, p0/Z, [%x[Apanel], #32]\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "addvl x25, x25, #2\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "add %x[Apanel], %x[Apanel], #0x30\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "ld1rqh { z3.h }, p0/Z, [%x[Apanel]]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40b // bfmmla z11.s, z0.h, z5.h\n" + ".inst 0x6464e42e // bfmmla z14.s, z1.h, z4.h\n" + ".inst 0x6465e431 // bfmmla z17.s, z1.h, z5.h\n" + "ld1h { z6.h }, p0/Z, [x21]\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e457 // bfmmla z23.s, z2.h, z5.h\n" + "ld1h { z7.h }, p0/Z, [x21, #1, MUL VL]\n" + ".inst 0x6464e47a // bfmmla z26.s, z3.h, z4.h\n" + ".inst 0x6465e47d // bfmmla z29.s, z3.h, z5.h\n" + "ld1h { z4.h }, p0/Z, [x20]\n" + "ld1h { z5.h }, p0/Z, [x20, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40c // bfmmla z12.s, z0.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6467e432 // bfmmla z18.s, z1.h, z7.h\n" + "sub x19, x19, #0x2\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e458 // bfmmla z24.s, z2.h, z7.h\n" + "cmp x19, #0x2\n" + ".inst 0x6466e47b // bfmmla z27.s, z3.h, z6.h\n" + ".inst 0x6467e47e // bfmmla z30.s, z3.h, z7.h\n" + "ld1h { z6.h }, p0/Z, [x25]\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e40d // bfmmla z13.s, z0.h, z5.h\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel], #16]\n" + ".inst 0x6464e430 // bfmmla z16.s, z1.h, z4.h\n" + ".inst 0x6465e433 // bfmmla z19.s, z1.h, z5.h\n" + "ld1rqh { z1.h }, p0/Z, [%x[Apanel], #32]\n" + ".inst 0x6464e456 // bfmmla z22.s, 
z2.h, z4.h\n" + ".inst 0x6465e459 // bfmmla z25.s, z2.h, z5.h\n" + "ld1h { z7.h }, p0/Z, [x25, #1, MUL VL]\n" + ".inst 0x6464e47c // bfmmla z28.s, z3.h, z4.h\n" + ".inst 0x6465e47f // bfmmla z31.s, z3.h, z5.h\n" + "ld1rqh { z2.h }, p0/Z, [%x[Apanel], #48]\n" + "ld1rqh { z3.h }, p0/Z, [%x[Apanel], #64]\n" + ".inst 0x6466e408 // bfmmla z8.s, z0.h, z6.h\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6467e431 // bfmmla z17.s, z1.h, z7.h\n" + "ld1h { z4.h }, p0/Z, [x21, #2, MUL VL]\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + ".inst 0x6467e457 // bfmmla z23.s, z2.h, z7.h\n" + "ld1h { z5.h }, p0/Z, [x21, #3, MUL VL]\n" + ".inst 0x6466e47a // bfmmla z26.s, z3.h, z6.h\n" + ".inst 0x6467e47d // bfmmla z29.s, z3.h, z7.h\n" + "ld1h { z6.h }, p0/Z, [x20, #2, MUL VL]\n" + "ld1h { z7.h }, p0/Z, [x20, #3, MUL VL]\n" + ".inst 0x6464e409 // bfmmla z9.s, z0.h, z4.h\n" + ".inst 0x6465e40c // bfmmla z12.s, z0.h, z5.h\n" + ".inst 0x6464e42f // bfmmla z15.s, z1.h, z4.h\n" + ".inst 0x6465e432 // bfmmla z18.s, z1.h, z5.h\n" + "addvl x21, x21, #4\n" + ".inst 0x6464e455 // bfmmla z21.s, z2.h, z4.h\n" + ".inst 0x6465e458 // bfmmla z24.s, z2.h, z5.h\n" + "addvl x20, x20, #4\n" + ".inst 0x6464e47b // bfmmla z27.s, z3.h, z4.h\n" + ".inst 0x6465e47e // bfmmla z30.s, z3.h, z5.h\n" + "ld1h { z4.h }, p0/Z, [x25, #2, MUL VL]\n" + ".inst 0x6466e40a // bfmmla z10.s, z0.h, z6.h\n" + ".inst 0x6467e40d // bfmmla z13.s, z0.h, z7.h\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel], #80]\n" + ".inst 0x6466e430 // bfmmla z16.s, z1.h, z6.h\n" + ".inst 0x6467e433 // bfmmla z19.s, z1.h, z7.h\n" + "ld1rqh { z1.h }, p0/Z, [%x[Apanel], #96]\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6467e459 // bfmmla z25.s, z2.h, z7.h\n" + "ld1h { z5.h }, p0/Z, [x25, #3, MUL VL]\n" + ".inst 0x6466e47c // bfmmla z28.s, z3.h, z6.h\n" + ".inst 0x6467e47f // bfmmla z31.s, z3.h, z7.h\n" + "ld1rqh { z2.h }, p0/Z, [%x[Apanel], #112]\n" + "add %x[Apanel], %x[Apanel], #0x80\n" + "addvl x25, x25, #4\n" + "bge 4b\n" + "5:" // main loop skip + "ld1rqh { z3.h }, p0/Z, [%x[Apanel]]\n" + ".inst 0x6464e408 // bfmmla z8.s, z0.h, z4.h\n" + ".inst 0x6465e40b // bfmmla z11.s, z0.h, z5.h\n" + ".inst 0x6464e42e // bfmmla z14.s, z1.h, z4.h\n" + ".inst 0x6465e431 // bfmmla z17.s, z1.h, z5.h\n" + "ld1h { z6.h }, p0/Z, [x21]\n" + ".inst 0x6464e454 // bfmmla z20.s, z2.h, z4.h\n" + ".inst 0x6465e457 // bfmmla z23.s, z2.h, z5.h\n" + "ld1h { z7.h }, p0/Z, [x21, #1, MUL VL]\n" + ".inst 0x6464e47a // bfmmla z26.s, z3.h, z4.h\n" + ".inst 0x6465e47d // bfmmla z29.s, z3.h, z5.h\n" + "ld1h { z4.h }, p0/Z, [x20]\n" + "ld1h { z5.h }, p0/Z, [x20, #1, MUL VL]\n" + ".inst 0x6466e409 // bfmmla z9.s, z0.h, z6.h\n" + ".inst 0x6467e40c // bfmmla z12.s, z0.h, z7.h\n" + ".inst 0x6466e42f // bfmmla z15.s, z1.h, z6.h\n" + ".inst 0x6467e432 // bfmmla z18.s, z1.h, z7.h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + ".inst 0x6466e455 // bfmmla z21.s, z2.h, z6.h\n" + ".inst 0x6467e458 // bfmmla z24.s, z2.h, z7.h\n" + "addvl x21, x21, #2\n" + ".inst 0x6466e47b // bfmmla z27.s, z3.h, z6.h\n" + ".inst 0x6467e47e // bfmmla z30.s, z3.h, z7.h\n" + "addvl x20, x20, #2\n" + ".inst 0x6464e40a // bfmmla z10.s, z0.h, z4.h\n" + ".inst 0x6465e40d // bfmmla z13.s, z0.h, z5.h\n" + ".inst 0x6464e430 // bfmmla z16.s, z1.h, z4.h\n" + ".inst 0x6465e433 // bfmmla z19.s, z1.h, z5.h\n" + ".inst 0x6464e456 // bfmmla z22.s, z2.h, z4.h\n" + ".inst 0x6465e459 // bfmmla z25.s, z2.h, z5.h\n" + ".inst 0x6464e47c // bfmmla z28.s, z3.h, z4.h\n" + 
".inst 0x6465e47f // bfmmla z31.s, z3.h, z5.h\n" + "cbz x19, 6f\n" + "ld1h { z6.h }, p0/Z, [x25]\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel]]\n" + ".inst 0x6466e408 // bfmmla z8.s, z0.h, z6.h\n" + "ld1rqh { z1.h }, p0/Z, [%x[Apanel], #16]\n" + "ld1h { z7.h }, p0/Z, [x25, #1, MUL VL]\n" + ".inst 0x6467e40b // bfmmla z11.s, z0.h, z7.h\n" + "ld1rqh { z2.h }, p0/Z, [%x[Apanel], #32]\n" + "ld1rqh { z3.h }, p0/Z, [%x[Apanel], #48]\n" + ".inst 0x6466e42e // bfmmla z14.s, z1.h, z6.h\n" + ".inst 0x6467e431 // bfmmla z17.s, z1.h, z7.h\n" + ".inst 0x6466e454 // bfmmla z20.s, z2.h, z6.h\n" + "ld1h { z4.h }, p0/Z, [x21]\n" + ".inst 0x6467e457 // bfmmla z23.s, z2.h, z7.h\n" + ".inst 0x6466e47a // bfmmla z26.s, z3.h, z6.h\n" + "ld1h { z5.h }, p0/Z, [x21, #1, MUL VL]\n" + ".inst 0x6467e47d // bfmmla z29.s, z3.h, z7.h\n" + "ld1h { z6.h }, p0/Z, [x20]\n" + "ld1h { z7.h }, p0/Z, [x20, #1, MUL VL]\n" + ".inst 0x6464e409 // bfmmla z9.s, z0.h, z4.h\n" + ".inst 0x6465e40c // bfmmla z12.s, z0.h, z5.h\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + ".inst 0x6464e42f // bfmmla z15.s, z1.h, z4.h\n" + ".inst 0x6465e432 // bfmmla z18.s, z1.h, z5.h\n" + ".inst 0x6464e455 // bfmmla z21.s, z2.h, z4.h\n" + ".inst 0x6465e458 // bfmmla z24.s, z2.h, z5.h\n" + ".inst 0x6464e47b // bfmmla z27.s, z3.h, z4.h\n" + ".inst 0x6465e47e // bfmmla z30.s, z3.h, z5.h\n" + ".inst 0x6466e40a // bfmmla z10.s, z0.h, z6.h\n" + ".inst 0x6467e40d // bfmmla z13.s, z0.h, z7.h\n" + ".inst 0x6466e430 // bfmmla z16.s, z1.h, z6.h\n" + ".inst 0x6467e433 // bfmmla z19.s, z1.h, z7.h\n" + ".inst 0x6466e456 // bfmmla z22.s, z2.h, z6.h\n" + ".inst 0x6467e459 // bfmmla z25.s, z2.h, z7.h\n" + ".inst 0x6466e47c // bfmmla z28.s, z3.h, z6.h\n" + ".inst 0x6467e47f // bfmmla z31.s, z3.h, z7.h\n" + "6:" // multiply loop done + "decw x24, ALL, MUL #3\n" + "uzp1 z4.d, z8.d, z11.d\n" + "uzp2 z8.d, z8.d, z11.d\n" + "uzp1 z11.d, z9.d, z12.d\n" + "uzp2 z9.d, z9.d, z12.d\n" + "st1w { z4.s }, p0, [%x[Cpanel]]\n" + "uzp1 z12.d, z10.d, z13.d\n" + "uzp2 z10.d, z10.d, z13.d\n" + "st1w { z11.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z12.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + "uzp1 z13.d, z14.d, z17.d\n" + "uzp2 z14.d, z14.d, z17.d\n" + "st1w { z8.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "uzp1 z17.d, z15.d, z18.d\n" + "cmp x24, XZR\n" + "st1w { z9.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "uzp2 z15.d, z15.d, z18.d\n" + "uzp1 z18.d, z16.d, z19.d\n" + "st1w { z10.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "uzp2 z16.d, z16.d, z19.d\n" + "uzp1 z19.d, z20.d, z23.d\n" + "st1w { z13.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "uzp2 z20.d, z20.d, z23.d\n" + "uzp1 z23.d, z21.d, z24.d\n" + "st1w { z17.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #16\n" + "uzp2 z21.d, z21.d, z24.d\n" + "st1w { z18.s }, p0, [%x[Cpanel], #-8, MUL VL]\n" + "uzp1 z24.d, z22.d, z25.d\n" + "uzp2 z22.d, z22.d, z25.d\n" + "st1w { z14.s }, p0, [%x[Cpanel], #-7, MUL VL]\n" + "uzp1 z25.d, z26.d, z29.d\n" + "uzp2 z26.d, z26.d, z29.d\n" + "st1w { z15.s }, p0, [%x[Cpanel], #-6, MUL VL]\n" + "uzp1 z29.d, z27.d, z30.d\n" + "uzp2 z27.d, z27.d, z30.d\n" + "st1w { z16.s }, p0, [%x[Cpanel], #-5, MUL VL]\n" + "uzp1 z30.d, z28.d, z31.d\n" + "uzp2 z28.d, z28.d, z31.d\n" + "st1w { z19.s }, p0, [%x[Cpanel], #-4, MUL VL]\n" + "st1w { z23.s }, p0, [%x[Cpanel], #-3, MUL VL]\n" + "st1w { z24.s }, p0, [%x[Cpanel], #-2, MUL VL]\n" + "st1w { z20.s }, p0, [%x[Cpanel], #-1, MUL VL]\n" + "st1w { z21.s }, p0, [%x[Cpanel]]\n" + "st1w { z22.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z25.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + 
"st1w { z29.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1w { z30.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1w { z26.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1w { z27.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1w { z28.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #8\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "p0", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp new file mode 100644 index 0000000000..6d36bf8bbf --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */ +#pragma once +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include "../std_transforms_sve.hpp" +#include "../kernel_weight_format.hpp" +#include "../performance_parameters.hpp" + +#define ARGLIST \ + const __fp16 *, const __fp16 *, size_t, \ + __fp16 *, int, size_t, int + +namespace arm_gemm +{ +// Actual kernel implementations +void sve_ffinterleaved_fp16_mla_8x3VL( ARGLIST ); +void sve_ffinterleaved_fp16_mla_8x3VL_a64fx( ARGLIST ); + +class cls_sve_ffinterleaved_fp16_mla_8x3VL +{ +public: + typedef __fp16 operand_type; + typedef __fp16 result_type; + + typedef void (*kern_type)( ARGLIST ); + + /* Kernel blocking parameters */ + static constexpr unsigned int out_height() + { + return 8; + } + + static unsigned int out_width() + { + return get_vector_length<__fp16>() * 3; + } + static unsigned int stripe_width() + { + return get_vector_length<__fp16>(); + } + + static KernelWeightFormat kernel_weight_format() + { + return KernelWeightFormat::VL1VL_BL16; + } + + static constexpr unsigned int k_unroll() + { + return 1; + } + + + StdTransformsSVE transforms = {}; + StdTransformsSVE transforms_quantized = {}; + template<typename T> + static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci) + { + + if (std::is_same<T, __fp16>::value) { + switch (ci->get_cpu_model()) { + default: + return { 25.53, 7.89, 3.82 }; + } + } + + return { 1.0 }; + } + + // Default to the generic kernel + kern_type kernel=sve_ffinterleaved_fp16_mla_8x3VL; + cls_sve_ffinterleaved_fp16_mla_8x3VL(const CPUInfo *ci) + { + switch(ci->get_cpu_model()) { + default: + break; + case CPUModel::A64FX: + kernel=sve_ffinterleaved_fp16_mla_8x3VL_a64fx; + break; + } + } +}; + +} // namespace arm_gemm + +#undef ARGLIST +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp new file mode 100644 index 0000000000..cd4da2c124 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/a64fx.cpp @@ -0,0 +1,297 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE.
+ */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include + +namespace arm_gemm { + +void sve_ffinterleaved_fp16_mla_8x3VL_a64fx( + const __fp16 *Apanel, + const __fp16 *Bpanel, + size_t B_stride, + __fp16 *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const __fp16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const __fp16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "ptrue p0.b\n" + "1:" // Height loop + "ldr x25, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_N]]\n" + "str x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x23, %x[Apanel]\n" + "2:" // Width loop + "ldr x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cnth x22, ALL, MUL #2\n" + "add x21, x25, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "cmp x24, x22\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov %x[Apanel], x23\n" + "bgt 3f\n" + "dech x22\n" + "cmp x24, x22\n" + "mov x20, x25\n" + "bgt 3f\n" + "mov x21, x25\n" + "3:" // B setup done + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "ld1h { z0.h }, p0/Z, [x25]\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "ld1h { z1.h }, p0/Z, [x21]\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "ld1h { z2.h }, p0/Z, [x20]\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel]]\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #2]\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #4]\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #6]\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "fmla z8.h, p0/M, z0.h, z3.h\n" + "fmla z9.h, p0/M, z1.h, z3.h\n" + "sub x19, x19, #0x2\n" + "fmla z10.h, p0/M, z2.h, z3.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel], #8]\n" + "fmla z11.h, p0/M, z0.h, z4.h\n" + "fmla z12.h, p0/M, z1.h, z4.h\n" + "fmla z13.h, p0/M, z2.h, z4.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #10]\n" + "fmla z14.h, p0/M, z0.h, z5.h\n" + "fmla z15.h, p0/M, z1.h, z5.h\n" + "cmp x19, #0x2\n" + "fmla z16.h, p0/M, z2.h, z5.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #12]\n" + "fmla z17.h, p0/M, z0.h, z6.h\n" + "fmla z18.h, p0/M, z1.h, z6.h\n" + "fmla z19.h, p0/M, z2.h, z6.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #14]\n" + "fmla z20.h, p0/M, z0.h, z3.h\n" + "fmla z21.h, p0/M, z1.h, z3.h\n" + "fmla z22.h, p0/M, z2.h, z3.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel], #16]\n" + "fmla z23.h, p0/M, z0.h, z4.h\n" + "fmla z24.h, p0/M, z1.h, z4.h\n" + "fmla z25.h, p0/M, z2.h, z4.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #18]\n" + "fmla z26.h, p0/M, z0.h, z5.h\n" + "fmla z27.h, p0/M, z1.h, z5.h\n" + "fmla z28.h, p0/M, z2.h, z5.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #20]\n" + "fmla z29.h, p0/M, z0.h, z6.h\n" + "ld1h { z0.h }, p0/Z, [x25, #1, MUL VL]\n" + "fmla z30.h, p0/M, z1.h, z6.h\n" + "fmla z31.h, p0/M, z2.h, z6.h\n" + "ld1h { z1.h }, p0/Z, [x21, #1, MUL VL]\n" + "ld1h { z2.h }, p0/Z, [x20, #1, MUL VL]\n" + "fmla z8.h, p0/M, z0.h, z3.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #22]\n" + "fmla z9.h, p0/M, z1.h, z3.h\n" + "fmla 
z10.h, p0/M, z2.h, z3.h\n" + "fmla z11.h, p0/M, z0.h, z4.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel], #24]\n" + "fmla z12.h, p0/M, z1.h, z4.h\n" + "fmla z13.h, p0/M, z2.h, z4.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #26]\n" + "fmla z14.h, p0/M, z0.h, z5.h\n" + "fmla z15.h, p0/M, z1.h, z5.h\n" + "addvl x25, x25, #2\n" + "fmla z16.h, p0/M, z2.h, z5.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #28]\n" + "fmla z17.h, p0/M, z0.h, z6.h\n" + "fmla z18.h, p0/M, z1.h, z6.h\n" + "fmla z19.h, p0/M, z2.h, z6.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #30]\n" + "addvl x21, x21, #2\n" + "addvl x20, x20, #2\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla z20.h, p0/M, z0.h, z3.h\n" + "fmla z21.h, p0/M, z1.h, z3.h\n" + "fmla z22.h, p0/M, z2.h, z3.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel]]\n" + "fmla z23.h, p0/M, z0.h, z4.h\n" + "fmla z24.h, p0/M, z1.h, z4.h\n" + "fmla z25.h, p0/M, z2.h, z4.h\n" + "fmla z26.h, p0/M, z0.h, z5.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #2]\n" + "fmla z27.h, p0/M, z1.h, z5.h\n" + "fmla z28.h, p0/M, z2.h, z5.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #4]\n" + "fmla z29.h, p0/M, z0.h, z6.h\n" + "ld1h { z0.h }, p0/Z, [x25]\n" + "fmla z30.h, p0/M, z1.h, z6.h\n" + "fmla z31.h, p0/M, z2.h, z6.h\n" + "ld1h { z1.h }, p0/Z, [x21]\n" + "ld1h { z2.h }, p0/Z, [x20]\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #6]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla z8.h, p0/M, z0.h, z3.h\n" + "fmla z9.h, p0/M, z1.h, z3.h\n" + "addvl x25, x25, #1\n" + "fmla z10.h, p0/M, z2.h, z3.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel], #8]\n" + "fmla z11.h, p0/M, z0.h, z4.h\n" + "fmla z12.h, p0/M, z1.h, z4.h\n" + "fmla z13.h, p0/M, z2.h, z4.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #10]\n" + "fmla z14.h, p0/M, z0.h, z5.h\n" + "fmla z15.h, p0/M, z1.h, z5.h\n" + "addvl x21, x21, #1\n" + "fmla z16.h, p0/M, z2.h, z5.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #12]\n" + "fmla z17.h, p0/M, z0.h, z6.h\n" + "fmla z18.h, p0/M, z1.h, z6.h\n" + "fmla z19.h, p0/M, z2.h, z6.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #14]\n" + "fmla z20.h, p0/M, z0.h, z3.h\n" + "fmla z21.h, p0/M, z1.h, z3.h\n" + "addvl x20, x20, #1\n" + "fmla z22.h, p0/M, z2.h, z3.h\n" + "fmla z23.h, p0/M, z0.h, z4.h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla z24.h, p0/M, z1.h, z4.h\n" + "fmla z25.h, p0/M, z2.h, z4.h\n" + "fmla z26.h, p0/M, z0.h, z5.h\n" + "fmla z27.h, p0/M, z1.h, z5.h\n" + "fmla z28.h, p0/M, z2.h, z5.h\n" + "fmla z29.h, p0/M, z0.h, z6.h\n" + "fmla z30.h, p0/M, z1.h, z6.h\n" + "fmla z31.h, p0/M, z2.h, z6.h\n" + "cbz x19, 6f\n" + "ld1h { z0.h }, p0/Z, [x25]\n" + "ld1h { z1.h }, p0/Z, [x21]\n" + "ld1h { z2.h }, p0/Z, [x20]\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel]]\n" + "fmla z8.h, p0/M, z0.h, z3.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #2]\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #4]\n" + "fmla z9.h, p0/M, z1.h, z3.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #6]\n" + "fmla z10.h, p0/M, z2.h, z3.h\n" + "fmla z11.h, p0/M, z0.h, z4.h\n" + "fmla z12.h, p0/M, z1.h, z4.h\n" + "fmla z13.h, p0/M, z2.h, z4.h\n" + "ld1rh { z3.h }, p0/Z, [%x[Apanel], #8]\n" + "fmla z14.h, p0/M, z0.h, z5.h\n" + "fmla z15.h, p0/M, z1.h, z5.h\n" + "ld1rh { z4.h }, p0/Z, [%x[Apanel], #10]\n" + "fmla z16.h, p0/M, z2.h, z5.h\n" + "fmla z17.h, p0/M, z0.h, z6.h\n" + "ld1rh { z5.h }, p0/Z, [%x[Apanel], #12]\n" + "fmla z18.h, p0/M, z1.h, z6.h\n" + "fmla z19.h, p0/M, z2.h, z6.h\n" + "ld1rh { z6.h }, p0/Z, [%x[Apanel], #14]\n" + "fmla z20.h, p0/M, z0.h, z3.h\n" + "fmla z21.h, p0/M, z1.h, z3.h\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla z22.h, p0/M, z2.h, z3.h\n" + 
"fmla z23.h, p0/M, z0.h, z4.h\n" + "fmla z24.h, p0/M, z1.h, z4.h\n" + "fmla z25.h, p0/M, z2.h, z4.h\n" + "fmla z26.h, p0/M, z0.h, z5.h\n" + "fmla z27.h, p0/M, z1.h, z5.h\n" + "fmla z28.h, p0/M, z2.h, z5.h\n" + "fmla z29.h, p0/M, z0.h, z6.h\n" + "fmla z30.h, p0/M, z1.h, z6.h\n" + "fmla z31.h, p0/M, z2.h, z6.h\n" + "6:" // multiply loop done + "dech x24, ALL, MUL #3\n" + "st1h { z8.h }, p0, [%x[Cpanel]]\n" + "cmp x24, XZR\n" + "st1h { z9.h }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1h { z10.h }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1h { z11.h }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1h { z12.h }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1h { z13.h }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1h { z14.h }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1h { z15.h }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #16\n" + "st1h { z16.h }, p0, [%x[Cpanel], #-8, MUL VL]\n" + "st1h { z17.h }, p0, [%x[Cpanel], #-7, MUL VL]\n" + "st1h { z18.h }, p0, [%x[Cpanel], #-6, MUL VL]\n" + "st1h { z19.h }, p0, [%x[Cpanel], #-5, MUL VL]\n" + "st1h { z20.h }, p0, [%x[Cpanel], #-4, MUL VL]\n" + "st1h { z21.h }, p0, [%x[Cpanel], #-3, MUL VL]\n" + "st1h { z22.h }, p0, [%x[Cpanel], #-2, MUL VL]\n" + "st1h { z23.h }, p0, [%x[Cpanel], #-1, MUL VL]\n" + "st1h { z24.h }, p0, [%x[Cpanel]]\n" + "st1h { z25.h }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1h { z26.h }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1h { z27.h }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1h { z28.h }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1h { z29.h }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1h { z30.h }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1h { z31.h }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #8\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "p0", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp new file mode 100644 index 0000000000..431c2a88f5 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp16_mla_8x3VL/generic.cpp @@ -0,0 +1,269 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include + +namespace arm_gemm { + +void sve_ffinterleaved_fp16_mla_8x3VL( + const __fp16 *Apanel, + const __fp16 *Bpanel, + size_t B_stride, + __fp16 *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const __fp16 *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const __fp16 *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "ptrue p0.b\n" + "1:" // Height loop + "ldr x25, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_N]]\n" + "str x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x23, %x[Apanel]\n" + "2:" // Width loop + "ldr x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cnth x22, ALL, MUL #2\n" + "add x21, x25, x19, LSL #1\n" + "add x20, x21, x19, LSL #1\n" + "add x19, x20, x19, LSL #1\n" + "cmp x24, x22\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov %x[Apanel], x23\n" + "bgt 3f\n" + "dech x22\n" + "cmp x24, x22\n" + "mov x20, x25\n" + "bgt 3f\n" + "mov x21, x25\n" + "3:" // B setup done + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel]]\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "ld1h { z2.h }, p0/Z, [x25]\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "ld1h { z3.h }, p0/Z, [x21]\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "ld1h { z4.h }, p0/Z, [x20]\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "fmla z8.h, z2.h, z0.h[0]\n" + "fmla z11.h, z2.h, z0.h[1]\n" + "ld1rqh { z1.h }, p0/Z, [%x[Apanel], #16]\n" + "fmla z14.h, z2.h, z0.h[2]\n" + "fmla z17.h, z2.h, z0.h[3]\n" + "ld1h { z5.h }, p0/Z, [x25, #1, MUL VL]\n" + "fmla z20.h, z2.h, z0.h[4]\n" + "fmla z23.h, z2.h, z0.h[5]\n" + "ld1h { z6.h }, p0/Z, [x21, #1, MUL VL]\n" + "fmla z26.h, z2.h, z0.h[6]\n" + "fmla z29.h, z2.h, z0.h[7]\n" + "ld1h { z7.h }, p0/Z, [x20, #1, MUL VL]\n" + "fmla z9.h, z3.h, z0.h[0]\n" + "fmla z12.h, z3.h, z0.h[1]\n" + "addvl x25, x25, #2\n" + "fmla z15.h, z3.h, z0.h[2]\n" + "fmla z18.h, z3.h, z0.h[3]\n" + "addvl x21, x21, #2\n" + "fmla z21.h, z3.h, z0.h[4]\n" + "fmla z24.h, z3.h, z0.h[5]\n" + "addvl x20, x20, #2\n" + "fmla z27.h, z3.h, z0.h[6]\n" + "fmla z30.h, z3.h, z0.h[7]\n" + "sub x19, x19, #0x2\n" + "fmla z10.h, z4.h, z0.h[0]\n" + "fmla z13.h, z4.h, z0.h[1]\n" + "cmp x19, #0x2\n" + "fmla z16.h, z4.h, z0.h[2]\n" + "fmla z19.h, z4.h, z0.h[3]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla z22.h, z4.h, z0.h[4]\n" + "fmla z25.h, z4.h, z0.h[5]\n" + "ld1h { z2.h }, p0/Z, [x25]\n" + "fmla z28.h, z4.h, z0.h[6]\n" + 
"fmla z31.h, z4.h, z0.h[7]\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel]]\n" + "fmla z8.h, z5.h, z1.h[0]\n" + "fmla z11.h, z5.h, z1.h[1]\n" + "ld1h { z3.h }, p0/Z, [x21]\n" + "fmla z14.h, z5.h, z1.h[2]\n" + "fmla z17.h, z5.h, z1.h[3]\n" + "ld1h { z4.h }, p0/Z, [x20]\n" + "fmla z20.h, z5.h, z1.h[4]\n" + "fmla z23.h, z5.h, z1.h[5]\n" + "fmla z26.h, z5.h, z1.h[6]\n" + "fmla z29.h, z5.h, z1.h[7]\n" + "fmla z9.h, z6.h, z1.h[0]\n" + "fmla z12.h, z6.h, z1.h[1]\n" + "fmla z15.h, z6.h, z1.h[2]\n" + "fmla z18.h, z6.h, z1.h[3]\n" + "fmla z21.h, z6.h, z1.h[4]\n" + "fmla z24.h, z6.h, z1.h[5]\n" + "fmla z27.h, z6.h, z1.h[6]\n" + "fmla z30.h, z6.h, z1.h[7]\n" + "fmla z10.h, z7.h, z1.h[0]\n" + "fmla z13.h, z7.h, z1.h[1]\n" + "fmla z16.h, z7.h, z1.h[2]\n" + "fmla z19.h, z7.h, z1.h[3]\n" + "fmla z22.h, z7.h, z1.h[4]\n" + "fmla z25.h, z7.h, z1.h[5]\n" + "fmla z28.h, z7.h, z1.h[6]\n" + "fmla z31.h, z7.h, z1.h[7]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla z8.h, z2.h, z0.h[0]\n" + "fmla z11.h, z2.h, z0.h[1]\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla z14.h, z2.h, z0.h[2]\n" + "fmla z17.h, z2.h, z0.h[3]\n" + "addvl x25, x25, #1\n" + "fmla z20.h, z2.h, z0.h[4]\n" + "fmla z23.h, z2.h, z0.h[5]\n" + "addvl x21, x21, #1\n" + "fmla z26.h, z2.h, z0.h[6]\n" + "fmla z29.h, z2.h, z0.h[7]\n" + "addvl x20, x20, #1\n" + "fmla z9.h, z3.h, z0.h[0]\n" + "fmla z12.h, z3.h, z0.h[1]\n" + "fmla z15.h, z3.h, z0.h[2]\n" + "fmla z18.h, z3.h, z0.h[3]\n" + "fmla z21.h, z3.h, z0.h[4]\n" + "fmla z24.h, z3.h, z0.h[5]\n" + "fmla z27.h, z3.h, z0.h[6]\n" + "fmla z30.h, z3.h, z0.h[7]\n" + "fmla z10.h, z4.h, z0.h[0]\n" + "fmla z13.h, z4.h, z0.h[1]\n" + "fmla z16.h, z4.h, z0.h[2]\n" + "fmla z19.h, z4.h, z0.h[3]\n" + "fmla z22.h, z4.h, z0.h[4]\n" + "fmla z25.h, z4.h, z0.h[5]\n" + "fmla z28.h, z4.h, z0.h[6]\n" + "fmla z31.h, z4.h, z0.h[7]\n" + "cbz x19, 6f\n" + "ld1rqh { z0.h }, p0/Z, [%x[Apanel]]\n" + "ld1h { z5.h }, p0/Z, [x25]\n" + "fmla z8.h, z5.h, z0.h[0]\n" + "ld1h { z6.h }, p0/Z, [x21]\n" + "ld1h { z7.h }, p0/Z, [x20]\n" + "fmla z11.h, z5.h, z0.h[1]\n" + "fmla z14.h, z5.h, z0.h[2]\n" + "fmla z17.h, z5.h, z0.h[3]\n" + "add %x[Apanel], %x[Apanel], #0x10\n" + "fmla z20.h, z5.h, z0.h[4]\n" + "fmla z23.h, z5.h, z0.h[5]\n" + "fmla z26.h, z5.h, z0.h[6]\n" + "fmla z29.h, z5.h, z0.h[7]\n" + "fmla z9.h, z6.h, z0.h[0]\n" + "fmla z12.h, z6.h, z0.h[1]\n" + "fmla z15.h, z6.h, z0.h[2]\n" + "fmla z18.h, z6.h, z0.h[3]\n" + "fmla z21.h, z6.h, z0.h[4]\n" + "fmla z24.h, z6.h, z0.h[5]\n" + "fmla z27.h, z6.h, z0.h[6]\n" + "fmla z30.h, z6.h, z0.h[7]\n" + "fmla z10.h, z7.h, z0.h[0]\n" + "fmla z13.h, z7.h, z0.h[1]\n" + "fmla z16.h, z7.h, z0.h[2]\n" + "fmla z19.h, z7.h, z0.h[3]\n" + "fmla z22.h, z7.h, z0.h[4]\n" + "fmla z25.h, z7.h, z0.h[5]\n" + "fmla z28.h, z7.h, z0.h[6]\n" + "fmla z31.h, z7.h, z0.h[7]\n" + "6:" // multiply loop done + "dech x24, ALL, MUL #3\n" + "st1h { z8.h }, p0, [%x[Cpanel]]\n" + "cmp x24, XZR\n" + "st1h { z9.h }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1h { z10.h }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1h { z11.h }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1h { z12.h }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1h { z13.h }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1h { z14.h }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1h { z15.h }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #16\n" + "st1h { z16.h }, p0, [%x[Cpanel], #-8, MUL VL]\n" + "st1h { z17.h }, p0, [%x[Cpanel], #-7, MUL VL]\n" + "st1h { z18.h }, p0, [%x[Cpanel], #-6, MUL VL]\n" + "st1h { z19.h }, p0, [%x[Cpanel], #-5, MUL VL]\n" + "st1h { z20.h }, p0, [%x[Cpanel], #-4, MUL 
VL]\n" + "st1h { z21.h }, p0, [%x[Cpanel], #-3, MUL VL]\n" + "st1h { z22.h }, p0, [%x[Cpanel], #-2, MUL VL]\n" + "st1h { z23.h }, p0, [%x[Cpanel], #-1, MUL VL]\n" + "st1h { z24.h }, p0, [%x[Cpanel]]\n" + "st1h { z25.h }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1h { z26.h }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1h { z27.h }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1h { z28.h }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1h { z29.h }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1h { z30.h }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1h { z31.h }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #8\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "p0", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp new file mode 100644 index 0000000000..aa3507ee73 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ */
+#pragma once
+#ifdef ARM_COMPUTE_ENABLE_SVE
+
+#include "../std_transforms_sve.hpp"
+#include "../kernel_weight_format.hpp"
+#include "../performance_parameters.hpp"
+
+#define ARGLIST \
+    const float *, const float *, size_t, \
+    float *, int, size_t, int
+
+namespace arm_gemm
+{
+// Actual kernel implementations
+void sve_ffinterleaved_fp32_mla_8x3VL( ARGLIST );
+void sve_ffinterleaved_fp32_mla_8x3VL_a64fx( ARGLIST );
+
+class cls_sve_ffinterleaved_fp32_mla_8x3VL
+{
+public:
+    typedef float operand_type;
+    typedef float result_type;
+
+    typedef void (*kern_type)( ARGLIST );
+
+    /* Kernel blocking parameters */
+    static constexpr unsigned int out_height()
+    {
+        return 8;
+    }
+
+    static unsigned int out_width()
+    {
+        return get_vector_length<float>() * 3;
+    }
+    static unsigned int stripe_width()
+    {
+        return get_vector_length<float>();
+    }
+
+    static KernelWeightFormat kernel_weight_format()
+    {
+        return KernelWeightFormat::VL1VL_BL32;
+    }
+
+    static constexpr unsigned int k_unroll()
+    {
+        return 1;
+    }
+
+    StdTransformsSVE<operand_type, result_type, 8, 3, 1> transforms = {};
+    StdTransformsSVE<operand_type, result_type, 8, 3, 1> transforms_quantized = {};
+    template<typename T>
+    static inline PerformanceParameters get_performance_parameters(const CPUInfo *ci)
+    {
+        if (std::is_same<T, float>::value) {
+            switch (ci->get_cpu_model()) {
+                default:
+                    return { 13.51, 9.27, 3.98 };
+            }
+        }
+
+        return { 1.0 };
+    }
+
+    // Default to the generic kernel
+    kern_type kernel=sve_ffinterleaved_fp32_mla_8x3VL;
+    cls_sve_ffinterleaved_fp32_mla_8x3VL(const CPUInfo *ci)
+    {
+        switch(ci->get_cpu_model()) {
+            default:
+                break;
+            case CPUModel::A64FX:
+                kernel=sve_ffinterleaved_fp32_mla_8x3VL_a64fx;
+                break;
+        }
+    }
+};
+
+} // namespace arm_gemm
+
+#undef ARGLIST
+#endif // ARM_COMPUTE_ENABLE_SVE
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp
new file mode 100644
index 0000000000..aecf7f94c9
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/a64fx.cpp
@@ -0,0 +1,297 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include + +namespace arm_gemm { + +void sve_ffinterleaved_fp32_mla_8x3VL_a64fx( + const float *Apanel, + const float *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const float *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const float *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "ptrue p0.b\n" + "1:" // Height loop + "ldr x25, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_N]]\n" + "str x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x23, %x[Apanel]\n" + "2:" // Width loop + "ldr x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cntw x22, ALL, MUL #2\n" + "add x21, x25, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "add x19, x20, x19, LSL #2\n" + "cmp x24, x22\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov %x[Apanel], x23\n" + "bgt 3f\n" + "decw x22\n" + "cmp x24, x22\n" + "mov x20, x25\n" + "bgt 3f\n" + "mov x21, x25\n" + "3:" // B setup done + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "ld1w { z0.s }, p0/Z, [x25]\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "ld1w { z1.s }, p0/Z, [x21]\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "ld1w { z2.s }, p0/Z, [x20]\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel]]\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #4]\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #8]\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #12]\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "fmla z8.s, p0/M, z0.s, z3.s\n" + "fmla z9.s, p0/M, z1.s, z3.s\n" + "sub x19, x19, #0x2\n" + "fmla z10.s, p0/M, z2.s, z3.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel], #16]\n" + "fmla z11.s, p0/M, z0.s, z4.s\n" + "fmla z12.s, p0/M, z1.s, z4.s\n" + "fmla z13.s, p0/M, z2.s, z4.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #20]\n" + "fmla z14.s, p0/M, z0.s, z5.s\n" + "fmla z15.s, p0/M, z1.s, z5.s\n" + "cmp x19, #0x2\n" + "fmla z16.s, p0/M, z2.s, z5.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #24]\n" + "fmla z17.s, p0/M, z0.s, z6.s\n" + "fmla z18.s, p0/M, z1.s, z6.s\n" + "fmla z19.s, p0/M, z2.s, z6.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #28]\n" + "fmla z20.s, p0/M, z0.s, z3.s\n" + "fmla z21.s, p0/M, z1.s, z3.s\n" + "fmla z22.s, p0/M, z2.s, z3.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel], #32]\n" + "fmla z23.s, p0/M, z0.s, z4.s\n" + "fmla z24.s, p0/M, z1.s, z4.s\n" + "fmla z25.s, p0/M, z2.s, z4.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #36]\n" + "fmla z26.s, p0/M, z0.s, z5.s\n" + "fmla z27.s, p0/M, z1.s, z5.s\n" + "fmla z28.s, p0/M, z2.s, z5.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #40]\n" + "fmla z29.s, p0/M, z0.s, z6.s\n" + "ld1w { z0.s }, p0/Z, [x25, #1, MUL VL]\n" + "fmla z30.s, p0/M, z1.s, z6.s\n" + "fmla z31.s, p0/M, z2.s, z6.s\n" + "ld1w { z1.s }, p0/Z, [x21, #1, MUL VL]\n" + "ld1w { z2.s }, p0/Z, [x20, #1, MUL VL]\n" + "fmla z8.s, p0/M, z0.s, z3.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #44]\n" + "fmla z9.s, p0/M, z1.s, z3.s\n" + "fmla z10.s, 
p0/M, z2.s, z3.s\n" + "fmla z11.s, p0/M, z0.s, z4.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel], #48]\n" + "fmla z12.s, p0/M, z1.s, z4.s\n" + "fmla z13.s, p0/M, z2.s, z4.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #52]\n" + "fmla z14.s, p0/M, z0.s, z5.s\n" + "fmla z15.s, p0/M, z1.s, z5.s\n" + "addvl x25, x25, #2\n" + "fmla z16.s, p0/M, z2.s, z5.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #56]\n" + "fmla z17.s, p0/M, z0.s, z6.s\n" + "fmla z18.s, p0/M, z1.s, z6.s\n" + "fmla z19.s, p0/M, z2.s, z6.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #60]\n" + "addvl x21, x21, #2\n" + "addvl x20, x20, #2\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "fmla z20.s, p0/M, z0.s, z3.s\n" + "fmla z21.s, p0/M, z1.s, z3.s\n" + "fmla z22.s, p0/M, z2.s, z3.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel]]\n" + "fmla z23.s, p0/M, z0.s, z4.s\n" + "fmla z24.s, p0/M, z1.s, z4.s\n" + "fmla z25.s, p0/M, z2.s, z4.s\n" + "fmla z26.s, p0/M, z0.s, z5.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #4]\n" + "fmla z27.s, p0/M, z1.s, z5.s\n" + "fmla z28.s, p0/M, z2.s, z5.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #8]\n" + "fmla z29.s, p0/M, z0.s, z6.s\n" + "ld1w { z0.s }, p0/Z, [x25]\n" + "fmla z30.s, p0/M, z1.s, z6.s\n" + "fmla z31.s, p0/M, z2.s, z6.s\n" + "ld1w { z1.s }, p0/Z, [x21]\n" + "ld1w { z2.s }, p0/Z, [x20]\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #12]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla z8.s, p0/M, z0.s, z3.s\n" + "fmla z9.s, p0/M, z1.s, z3.s\n" + "addvl x25, x25, #1\n" + "fmla z10.s, p0/M, z2.s, z3.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel], #16]\n" + "fmla z11.s, p0/M, z0.s, z4.s\n" + "fmla z12.s, p0/M, z1.s, z4.s\n" + "fmla z13.s, p0/M, z2.s, z4.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #20]\n" + "fmla z14.s, p0/M, z0.s, z5.s\n" + "fmla z15.s, p0/M, z1.s, z5.s\n" + "addvl x21, x21, #1\n" + "fmla z16.s, p0/M, z2.s, z5.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #24]\n" + "fmla z17.s, p0/M, z0.s, z6.s\n" + "fmla z18.s, p0/M, z1.s, z6.s\n" + "fmla z19.s, p0/M, z2.s, z6.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #28]\n" + "fmla z20.s, p0/M, z0.s, z3.s\n" + "fmla z21.s, p0/M, z1.s, z3.s\n" + "addvl x20, x20, #1\n" + "fmla z22.s, p0/M, z2.s, z3.s\n" + "fmla z23.s, p0/M, z0.s, z4.s\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla z24.s, p0/M, z1.s, z4.s\n" + "fmla z25.s, p0/M, z2.s, z4.s\n" + "fmla z26.s, p0/M, z0.s, z5.s\n" + "fmla z27.s, p0/M, z1.s, z5.s\n" + "fmla z28.s, p0/M, z2.s, z5.s\n" + "fmla z29.s, p0/M, z0.s, z6.s\n" + "fmla z30.s, p0/M, z1.s, z6.s\n" + "fmla z31.s, p0/M, z2.s, z6.s\n" + "cbz x19, 6f\n" + "ld1w { z0.s }, p0/Z, [x25]\n" + "ld1w { z1.s }, p0/Z, [x21]\n" + "ld1w { z2.s }, p0/Z, [x20]\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel]]\n" + "fmla z8.s, p0/M, z0.s, z3.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #4]\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #8]\n" + "fmla z9.s, p0/M, z1.s, z3.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #12]\n" + "fmla z10.s, p0/M, z2.s, z3.s\n" + "fmla z11.s, p0/M, z0.s, z4.s\n" + "fmla z12.s, p0/M, z1.s, z4.s\n" + "fmla z13.s, p0/M, z2.s, z4.s\n" + "ld1rw { z3.s }, p0/Z, [%x[Apanel], #16]\n" + "fmla z14.s, p0/M, z0.s, z5.s\n" + "fmla z15.s, p0/M, z1.s, z5.s\n" + "ld1rw { z4.s }, p0/Z, [%x[Apanel], #20]\n" + "fmla z16.s, p0/M, z2.s, z5.s\n" + "fmla z17.s, p0/M, z0.s, z6.s\n" + "ld1rw { z5.s }, p0/Z, [%x[Apanel], #24]\n" + "fmla z18.s, p0/M, z1.s, z6.s\n" + "fmla z19.s, p0/M, z2.s, z6.s\n" + "ld1rw { z6.s }, p0/Z, [%x[Apanel], #28]\n" + "fmla z20.s, p0/M, z0.s, z3.s\n" + "fmla z21.s, p0/M, z1.s, z3.s\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla z22.s, p0/M, z2.s, z3.s\n" + "fmla 
z23.s, p0/M, z0.s, z4.s\n" + "fmla z24.s, p0/M, z1.s, z4.s\n" + "fmla z25.s, p0/M, z2.s, z4.s\n" + "fmla z26.s, p0/M, z0.s, z5.s\n" + "fmla z27.s, p0/M, z1.s, z5.s\n" + "fmla z28.s, p0/M, z2.s, z5.s\n" + "fmla z29.s, p0/M, z0.s, z6.s\n" + "fmla z30.s, p0/M, z1.s, z6.s\n" + "fmla z31.s, p0/M, z2.s, z6.s\n" + "6:" // multiply loop done + "decw x24, ALL, MUL #3\n" + "st1w { z8.s }, p0, [%x[Cpanel]]\n" + "cmp x24, XZR\n" + "st1w { z9.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z10.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1w { z11.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1w { z12.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1w { z13.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1w { z14.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1w { z15.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #16\n" + "st1w { z16.s }, p0, [%x[Cpanel], #-8, MUL VL]\n" + "st1w { z17.s }, p0, [%x[Cpanel], #-7, MUL VL]\n" + "st1w { z18.s }, p0, [%x[Cpanel], #-6, MUL VL]\n" + "st1w { z19.s }, p0, [%x[Cpanel], #-5, MUL VL]\n" + "st1w { z20.s }, p0, [%x[Cpanel], #-4, MUL VL]\n" + "st1w { z21.s }, p0, [%x[Cpanel], #-3, MUL VL]\n" + "st1w { z22.s }, p0, [%x[Cpanel], #-2, MUL VL]\n" + "st1w { z23.s }, p0, [%x[Cpanel], #-1, MUL VL]\n" + "st1w { z24.s }, p0, [%x[Cpanel]]\n" + "st1w { z25.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z26.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1w { z27.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1w { z28.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1w { z29.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1w { z30.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1w { z31.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #8\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "p0", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp new file mode 100644 index 0000000000..1e9a3f119e --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_ffinterleaved_fp32_mla_8x3VL/generic.cpp @@ -0,0 +1,273 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ +#ifdef ARM_COMPUTE_ENABLE_SVE + +#include + +namespace arm_gemm { + +void sve_ffinterleaved_fp32_mla_8x3VL( + const float *Apanel, + const float *Bpanel, + size_t B_stride, + float *Cpanel, + int ablocks, + size_t N, + int K) { + + struct KernelArgs { + size_t K = {}; + const float *Bpanel = {}; + size_t N = {}; + size_t B_stride = {}; + const float *cur_B_ptr = {}; + } ka; + + ka.K = (K/1) - 1; + ka.Bpanel = Bpanel; + ka.N = N; + ka.B_stride = B_stride; + + __asm__ __volatile__( + "ptrue p0.b\n" + "1:" // Height loop + "ldr x25, [%x[args_ptr], %[offsetof_Bpanel]]\n" + "ldr x24, [%x[args_ptr], %[offsetof_N]]\n" + "str x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov x23, %x[Apanel]\n" + "2:" // Width loop + "ldr x25, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "ldr x19, [%x[args_ptr], %[offsetof_B_stride]]\n" + "cntw x22, ALL, MUL #2\n" + "add x21, x25, x19, LSL #2\n" + "add x20, x21, x19, LSL #2\n" + "add x19, x20, x19, LSL #2\n" + "cmp x24, x22\n" + "str x19, [%x[args_ptr], %[offsetof_cur_B_ptr]]\n" + "mov %x[Apanel], x23\n" + "bgt 3f\n" + "decw x22\n" + "cmp x24, x22\n" + "mov x20, x25\n" + "bgt 3f\n" + "mov x21, x25\n" + "3:" // B setup done + "ldr x19, [%x[args_ptr], %[offsetof_K]]\n" + "cmp x19, #0x2\n" + "mov z8.b, #0x0\n" + "mov z9.b, #0x0\n" + "mov z10.b, #0x0\n" + "ld1rqw { z0.s }, p0/Z, [%x[Apanel]]\n" + "mov z11.b, #0x0\n" + "mov z12.b, #0x0\n" + "ld1rqw { z1.s }, p0/Z, [%x[Apanel], #16]\n" + "mov z13.b, #0x0\n" + "mov z14.b, #0x0\n" + "ld1w { z4.s }, p0/Z, [x25]\n" + "mov z15.b, #0x0\n" + "mov z16.b, #0x0\n" + "ld1w { z5.s }, p0/Z, [x21]\n" + "mov z17.b, #0x0\n" + "mov z18.b, #0x0\n" + "ld1w { z6.s }, p0/Z, [x20]\n" + "mov z19.b, #0x0\n" + "mov z20.b, #0x0\n" + "mov z21.b, #0x0\n" + "mov z22.b, #0x0\n" + "mov z23.b, #0x0\n" + "mov z24.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z28.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z30.b, #0x0\n" + "mov z31.b, #0x0\n" + "blt 5f\n" + "4:" // main loop head + "fmla z8.s, z4.s, z0.s[0]\n" + "fmla z11.s, z4.s, z0.s[1]\n" + "ld1rqw { z2.s }, p0/Z, [%x[Apanel], #32]\n" + "fmla z14.s, z4.s, z0.s[2]\n" + "fmla z17.s, z4.s, z0.s[3]\n" + "ld1rqw { z3.s }, p0/Z, [%x[Apanel], #48]\n" + "fmla z20.s, z4.s, z1.s[0]\n" + "fmla z23.s, z4.s, z1.s[1]\n" + "sub x19, x19, #0x2\n" + "fmla z26.s, z4.s, z1.s[2]\n" + "fmla z29.s, z4.s, z1.s[3]\n" + "ld1w { z4.s }, p0/Z, [x25, #1, MUL VL]\n" + "fmla z9.s, z5.s, z0.s[0]\n" + "fmla z12.s, z5.s, z0.s[1]\n" + "addvl x25, x25, #2\n" + "fmla z15.s, z5.s, z0.s[2]\n" + "fmla z18.s, z5.s, z0.s[3]\n" + "cmp x19, #0x2\n" + "fmla z21.s, z5.s, z1.s[0]\n" + "fmla z24.s, z5.s, z1.s[1]\n" + "add %x[Apanel], %x[Apanel], #0x40\n" + "fmla z27.s, z5.s, z1.s[2]\n" + "fmla z30.s, z5.s, z1.s[3]\n" + "ld1w { z5.s }, p0/Z, [x21, #1, MUL VL]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z13.s, z6.s, z0.s[1]\n" + "addvl x21, x21, #2\n" + "fmla z16.s, z6.s, z0.s[2]\n" + "fmla z19.s, z6.s, z0.s[3]\n" + "ld1rqw { z0.s }, p0/Z, [%x[Apanel]]\n" + "fmla z22.s, z6.s, z1.s[0]\n" + "fmla z25.s, z6.s, z1.s[1]\n" + "fmla 
z28.s, z6.s, z1.s[2]\n" + "fmla z31.s, z6.s, z1.s[3]\n" + "ld1w { z6.s }, p0/Z, [x20, #1, MUL VL]\n" + "addvl x20, x20, #2\n" + "fmla z8.s, z4.s, z2.s[0]\n" + "fmla z11.s, z4.s, z2.s[1]\n" + "fmla z14.s, z4.s, z2.s[2]\n" + "fmla z17.s, z4.s, z2.s[3]\n" + "ld1rqw { z1.s }, p0/Z, [%x[Apanel], #16]\n" + "fmla z20.s, z4.s, z3.s[0]\n" + "fmla z23.s, z4.s, z3.s[1]\n" + "fmla z26.s, z4.s, z3.s[2]\n" + "fmla z29.s, z4.s, z3.s[3]\n" + "ld1w { z4.s }, p0/Z, [x25]\n" + "fmla z9.s, z5.s, z2.s[0]\n" + "fmla z12.s, z5.s, z2.s[1]\n" + "fmla z15.s, z5.s, z2.s[2]\n" + "fmla z18.s, z5.s, z2.s[3]\n" + "fmla z21.s, z5.s, z3.s[0]\n" + "fmla z24.s, z5.s, z3.s[1]\n" + "fmla z27.s, z5.s, z3.s[2]\n" + "fmla z30.s, z5.s, z3.s[3]\n" + "ld1w { z5.s }, p0/Z, [x21]\n" + "fmla z10.s, z6.s, z2.s[0]\n" + "fmla z13.s, z6.s, z2.s[1]\n" + "fmla z16.s, z6.s, z2.s[2]\n" + "fmla z19.s, z6.s, z2.s[3]\n" + "fmla z22.s, z6.s, z3.s[0]\n" + "fmla z25.s, z6.s, z3.s[1]\n" + "fmla z28.s, z6.s, z3.s[2]\n" + "fmla z31.s, z6.s, z3.s[3]\n" + "ld1w { z6.s }, p0/Z, [x20]\n" + "bge 4b\n" + "5:" // main loop skip + "fmla z8.s, z4.s, z0.s[0]\n" + "fmla z11.s, z4.s, z0.s[1]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "fmla z14.s, z4.s, z0.s[2]\n" + "fmla z17.s, z4.s, z0.s[3]\n" + "addvl x25, x25, #1\n" + "fmla z20.s, z4.s, z1.s[0]\n" + "fmla z23.s, z4.s, z1.s[1]\n" + "addvl x21, x21, #1\n" + "fmla z26.s, z4.s, z1.s[2]\n" + "fmla z29.s, z4.s, z1.s[3]\n" + "addvl x20, x20, #1\n" + "fmla z9.s, z5.s, z0.s[0]\n" + "fmla z12.s, z5.s, z0.s[1]\n" + "fmla z15.s, z5.s, z0.s[2]\n" + "fmla z18.s, z5.s, z0.s[3]\n" + "fmla z21.s, z5.s, z1.s[0]\n" + "fmla z24.s, z5.s, z1.s[1]\n" + "fmla z27.s, z5.s, z1.s[2]\n" + "fmla z30.s, z5.s, z1.s[3]\n" + "fmla z10.s, z6.s, z0.s[0]\n" + "fmla z13.s, z6.s, z0.s[1]\n" + "fmla z16.s, z6.s, z0.s[2]\n" + "fmla z19.s, z6.s, z0.s[3]\n" + "fmla z22.s, z6.s, z1.s[0]\n" + "fmla z25.s, z6.s, z1.s[1]\n" + "fmla z28.s, z6.s, z1.s[2]\n" + "fmla z31.s, z6.s, z1.s[3]\n" + "cbz x19, 6f\n" + "ld1rqw { z0.s }, p0/Z, [%x[Apanel]]\n" + "ld1rqw { z1.s }, p0/Z, [%x[Apanel], #16]\n" + "add %x[Apanel], %x[Apanel], #0x20\n" + "ld1w { z7.s }, p0/Z, [x25]\n" + "ld1w { z4.s }, p0/Z, [x21]\n" + "fmla z8.s, z7.s, z0.s[0]\n" + "ld1w { z5.s }, p0/Z, [x20]\n" + "fmla z11.s, z7.s, z0.s[1]\n" + "fmla z14.s, z7.s, z0.s[2]\n" + "fmla z17.s, z7.s, z0.s[3]\n" + "fmla z20.s, z7.s, z1.s[0]\n" + "fmla z23.s, z7.s, z1.s[1]\n" + "fmla z26.s, z7.s, z1.s[2]\n" + "fmla z29.s, z7.s, z1.s[3]\n" + "fmla z9.s, z4.s, z0.s[0]\n" + "fmla z12.s, z4.s, z0.s[1]\n" + "fmla z15.s, z4.s, z0.s[2]\n" + "fmla z18.s, z4.s, z0.s[3]\n" + "fmla z21.s, z4.s, z1.s[0]\n" + "fmla z24.s, z4.s, z1.s[1]\n" + "fmla z27.s, z4.s, z1.s[2]\n" + "fmla z30.s, z4.s, z1.s[3]\n" + "fmla z10.s, z5.s, z0.s[0]\n" + "fmla z13.s, z5.s, z0.s[1]\n" + "fmla z16.s, z5.s, z0.s[2]\n" + "fmla z19.s, z5.s, z0.s[3]\n" + "fmla z22.s, z5.s, z1.s[0]\n" + "fmla z25.s, z5.s, z1.s[1]\n" + "fmla z28.s, z5.s, z1.s[2]\n" + "fmla z31.s, z5.s, z1.s[3]\n" + "6:" // multiply loop done + "decw x24, ALL, MUL #3\n" + "st1w { z8.s }, p0, [%x[Cpanel]]\n" + "cmp x24, XZR\n" + "st1w { z9.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z10.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1w { z11.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1w { z12.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1w { z13.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1w { z14.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1w { z15.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #16\n" + "st1w { z16.s }, p0, [%x[Cpanel], #-8, MUL VL]\n" + "st1w { z17.s }, p0, 
[%x[Cpanel], #-7, MUL VL]\n" + "st1w { z18.s }, p0, [%x[Cpanel], #-6, MUL VL]\n" + "st1w { z19.s }, p0, [%x[Cpanel], #-5, MUL VL]\n" + "st1w { z20.s }, p0, [%x[Cpanel], #-4, MUL VL]\n" + "st1w { z21.s }, p0, [%x[Cpanel], #-3, MUL VL]\n" + "st1w { z22.s }, p0, [%x[Cpanel], #-2, MUL VL]\n" + "st1w { z23.s }, p0, [%x[Cpanel], #-1, MUL VL]\n" + "st1w { z24.s }, p0, [%x[Cpanel]]\n" + "st1w { z25.s }, p0, [%x[Cpanel], #1, MUL VL]\n" + "st1w { z26.s }, p0, [%x[Cpanel], #2, MUL VL]\n" + "st1w { z27.s }, p0, [%x[Cpanel], #3, MUL VL]\n" + "st1w { z28.s }, p0, [%x[Cpanel], #4, MUL VL]\n" + "st1w { z29.s }, p0, [%x[Cpanel], #5, MUL VL]\n" + "st1w { z30.s }, p0, [%x[Cpanel], #6, MUL VL]\n" + "st1w { z31.s }, p0, [%x[Cpanel], #7, MUL VL]\n" + "addvl %x[Cpanel], %x[Cpanel], #8\n" + "bgt 2b\n" + "subs %x[ablocks], %x[ablocks], #0x1\n" + "bne 1b\n" + : [Apanel] "+&r" (Apanel), [Cpanel] "+&r" (Cpanel), [ablocks] "+&r" (ablocks) + : [args_ptr] "r" (&ka), [offsetof_B_stride] "I" (offsetof(KernelArgs, B_stride)), [offsetof_Bpanel] "I" (offsetof(KernelArgs, Bpanel)), [offsetof_K] "I" (offsetof(KernelArgs, K)), [offsetof_N] "I" (offsetof(KernelArgs, N)), [offsetof_cur_B_ptr] "I" (offsetof(KernelArgs, cur_B_ptr)) + : "cc", "memory", "p0", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace arm_gemm +#endif // ARM_COMPUTE_ENABLE_SVE diff --git a/src/core/NEON/kernels/arm_gemm/kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp b/src/core/NEON/kernels/arm_gemm/kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp index fc91dd71ad..1de8c68494 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -80,7 +80,7 @@ class cls_sve_interleaved_bf16fp32_mmla_8x3VL case CPUModel::A510: return { 7.78, 4.01, 2.43 }; case CPUModel::V1: - return { 47.63, 5.11, 6.80 }; + return { 62.50, 5.09, 11.32 }; } } @@ -92,7 +92,7 @@ class cls_sve_interleaved_bf16fp32_mmla_8x3VL case CPUModel::A510: return { 7.75, 2.47, 2.39 }; case CPUModel::V1: - return { 60.83, 2.69, 8.66 }; + return { 47.63, 5.11, 6.80 }; } } diff --git a/src/core/NEON/kernels/arm_gemm/misc.cpp b/src/core/NEON/kernels/arm_gemm/misc.cpp index 229e6b56f9..cf99bbdb46 100644 --- a/src/core/NEON/kernels/arm_gemm/misc.cpp +++ b/src/core/NEON/kernels/arm_gemm/misc.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 Arm Limited. + * Copyright (c) 2017-2018, 2022 Arm Limited. 
 *
 * SPDX-License-Identifier: MIT
 *
@@ -25,6 +25,11 @@
 #ifndef NO_MULTI_THREADING
 #include <mutex>
 #endif
+#include <cstdint>
+
+#include "arm_gemm.hpp"
+#include "kernel_weight_format.hpp"
+#include "utils.hpp"
 
 namespace arm_gemm {
 
@@ -32,4 +37,39 @@ namespace arm_gemm {
 std::mutex report_mutex;
 #endif
 
-} // namespace arm_gemm
\ No newline at end of file
+WeightFormat get_weight_format(const KernelWeightFormat kwf, size_t element_size) {
+    if (kwf==KernelWeightFormat::NON_FIXED) {
+        return WeightFormat::UNSPECIFIED;
+    }
+
+    uint32_t kwf_i = static_cast<uint32_t>(kwf);
+    uint32_t wf_i = 0;
+
+    const auto block_bytes = (kwf_i >> 8) & 0xf;
+    const auto vector_count = (kwf_i >> 12) & 0xf;
+
+    uint32_t vector_bytes;
+
+    // For fast mode BF16 kernels set the appropriate bit and override element size to 2.
+    if (kwf_i & 0x10) {
+        element_size = 2;
+        wf_i |= 0x10;
+    }
+
+    // Get total bytes in vector output
+    if (kwf_i & 0x1) {
+        vector_bytes = vector_count * get_vector_length<uint8_t>();
+    } else {
+        vector_bytes = vector_count * 16;
+    }
+
+    auto input_blocking = block_bytes / element_size;
+    auto output_blocking = vector_bytes / block_bytes;
+
+    wf_i |= (input_blocking << 20);
+    wf_i |= (output_blocking << 8);
+
+    return static_cast<WeightFormat>(wf_i);
+}
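+// Worked example (illustrative numbers only, not a specific KernelWeightFormat
+// entry): a fixed-format kernel reporting block_bytes == 4 and vector_count == 2
+// with the SVE bit (0x1) set gives, on a 256-bit SVE machine,
+// vector_bytes = 2 * 32 = 64; for fp32 (element_size == 4) this decodes to
+// input_blocking = 4 / 4 = 1 and output_blocking = 64 / 4 = 16.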
+
+} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/transform-sve.cpp b/src/core/NEON/kernels/arm_gemm/transform-sve.cpp
index 3f6963d32b..d01a9b0fd0 100644
--- a/src/core/NEON/kernels/arm_gemm/transform-sve.cpp
+++ b/src/core/NEON/kernels/arm_gemm/transform-sve.cpp
@@ -26,7 +26,10 @@
 #include "bfloat.hpp"
 #include "transform.hpp"
 
+#if !defined(_WIN64) && !defined(__OpenBSD__)
 #include <alloca.h>
+#endif /* !defined(_WIN64) && !defined(__OpenBSD__) */
+
 namespace arm_gemm {
 
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index 18e124b83e..d7b5398488 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -24,7 +24,7 @@
 
 #pragma once
 
-#include "arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
 
 #include
 #include
diff --git a/src/core/NEON/kernels/assembly/winograd.hpp b/src/core/NEON/kernels/assembly/winograd.hpp
new file mode 100644
index 0000000000..836402e83d
--- /dev/null
+++ b/src/core/NEON/kernels/assembly/winograd.hpp
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/cpu/kernels/assembly/arm_gemm.hpp"
+#include <memory>
+
+namespace arm_conv
+{
+struct Shape2D
+{
+    unsigned int rows, cols;
+};
+
+struct ConvolutionArgs
+{
+    unsigned int         n_batches;
+    Shape2D              input_shape;
+    unsigned int         n_input_channels;
+    unsigned int         pad_top, pad_left;
+    Shape2D              output_shape;
+    unsigned int         n_output_channels;
+    Shape2D              kernel_shape;
+    arm_gemm::Activation activation;
+
+    ConvolutionArgs(
+        unsigned int n_batches,
+        const Shape2D &input_shape,
+        unsigned int n_input_channels,
+        unsigned int pad_top, unsigned int pad_left,
+        const Shape2D &output_shape,
+        unsigned int n_output_channels,
+        const Shape2D kernel_shape,
+        const arm_gemm::Activation &activation = {})
+        : n_batches(n_batches), input_shape(input_shape), n_input_channels(n_input_channels), pad_top(pad_top), pad_left(pad_left), output_shape(output_shape), n_output_channels(n_output_channels),
+          kernel_shape(kernel_shape), activation(activation)
+    {
+    }
+};
+
+namespace winograd
+{
+/* Constrain the selected Winograd implementation.
+ */
+struct WinogradConfig
+{
+    unsigned int output_rows = 0, output_cols = 0;
+    std::string  input_transform_filter = "";
+    std::string  output_transform_filter = "";
+    std::string  weight_transform_filter = "";
+};
+
+/* Struct describing (suggested) memory layout within the Winograd domain.
+ */
+struct WinogradDomainSpec
+{
+    size_t weight_matrix_size_bytes, input_matrix_size_bytes, output_matrix_size_bytes;
+
+    size_t weight_ld_matrix, weight_ld_row;
+    size_t input_ld_batch, input_ld_matrix, input_ld_row;
+    size_t output_ld_batch, output_ld_matrix, output_ld_row;
+};
+
+class ITransformCommon
+{
+public:
+    virtual ~ITransformCommon() = default;
+
+    // Get the name of the transform
+    virtual const std::string &get_name(void) const = 0;
+};
+
+namespace weight_transform
+{
+class ITransform : public ITransformCommon
+{
+public:
+    ~ITransform() = default;
+
+    virtual unsigned int get_kernel_rows(void) const = 0;
+    virtual unsigned int get_kernel_cols(void) const = 0;
+
+    virtual unsigned int get_transformed_tile_rows(void) const = 0;
+    virtual unsigned int get_transformed_tile_cols(void) const = 0;
+
+    void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
+        void *outptr, const WinogradDomainSpec &wds,
+        unsigned int thread_id, unsigned int n_threads) const
+    {
+        this->execute(
+            args, inptr, ld_in_row, ld_in_col, ld_input_channel,
+            outptr, wds.weight_ld_matrix, wds.weight_ld_row,
+            thread_id, n_threads);
+    }
+
+    virtual void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
+        void *outptr, size_t ld_out_matrix, size_t ld_out_row,
+        unsigned int thread_id, unsigned int n_threads) const = 0;
+};
+
+} // namespace weight_transform
+
+namespace input_transform
+{
+class ITransform : public ITransformCommon
+{
+public:
+    ~ITransform() = default;
+
+    virtual unsigned int get_input_rows(void) const = 0;
+    virtual unsigned int get_input_cols(void) const = 0;
+
+    virtual size_t get_working_space_size(
+        const ConvolutionArgs &args,
+        unsigned int n_threads) const = 0;
+
+    void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
+        void *outptr, const WinogradDomainSpec &wds,
+        void *working_space, unsigned int thread_id, unsigned int n_threads) const
+    {
+        this->execute(
+            args, inptr, ld_in_batch, ld_in_row, ld_in_col,
+            outptr, wds.input_ld_batch,
+            wds.input_ld_matrix, wds.input_ld_row,
+            working_space, thread_id, n_threads);
+    }
+
+    virtual void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
+        void *outptr, size_t ld_out_batch, size_t ld_out_matrix, size_t ld_out_row,
+        void *working_space, unsigned int thread_id, unsigned int n_threads) const = 0;
+};
+
+} // namespace input_transform
+
+namespace output_transform
+{
+class ITransform : public ITransformCommon
+{
+public:
+    ~ITransform() = default;
+
+    virtual unsigned int get_input_rows(void) const = 0;
+    virtual unsigned int get_input_cols(void) const = 0;
+
+    virtual unsigned int get_output_rows(void) const = 0;
+    virtual unsigned int get_output_cols(void) const = 0;
+
+    virtual unsigned int get_kernel_rows(void) const = 0;
+    virtual unsigned int get_kernel_cols(void) const = 0;
+
+    virtual size_t get_working_space_size(
+        const ConvolutionArgs &args,
+        unsigned int n_threads) const = 0;
+
+    void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, const WinogradDomainSpec &wds,
+        const void *bias,
+        void *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
+        void *working_space, unsigned int thread_id, unsigned int n_threads) const
+    {
+        this->execute(
+            args,
+            inptr, wds.output_ld_batch, wds.output_ld_matrix, wds.output_ld_row,
+            bias,
+            outptr, ld_out_batch, ld_out_row, ld_out_col,
+            working_space, thread_id, n_threads);
+    }
+
+    virtual void execute(
+        const ConvolutionArgs &args,
+        const void *inptr, size_t ld_in_batch, size_t ld_in_matrix, size_t ld_in_row,
+        const void *bias,
+        void *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
+        void *working_space, unsigned int thread_id, unsigned int n_threads) const = 0;
+};
+
+} // namespace output_transform
+
+struct WinogradImpl
+{
+    const output_transform::ITransform *output_transform = nullptr;
+    const weight_transform::ITransform *weight_transform = nullptr;
+    const input_transform::ITransform *input_transform = nullptr;
+    std::unique_ptr<arm_gemm::GemmArgs> gemm_args;
+    WinogradDomainSpec winograd_spec;
+};
+
+/* Get pointers to Winograd transforms for the given convolution problem.
+ *
+ * Assigns to the pointers in the `dest` struct and returns true or false to
+ * indicate whether the given problem can be executed or not.
+ */
+template <typename TIn, typename TWeight = TIn, typename TOut = TIn>
+bool get_implementation(
+    WinogradImpl &dest, // Destination for the selected implementation
+    const CPUInfo *,
+    const ConvolutionArgs &,
+    int  max_threads,
+    bool fast_mode,
+    const WinogradConfig *,
+    const arm_gemm::GemmConfig *);
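+// A minimal call sketch (hedged: the variable names below are illustrative and
+// not part of this patch):
+//
+//   WinogradImpl impl;
+//   if (get_implementation<float>(impl, cpu_info, conv_args, n_threads,
+//                                 /* fast_mode = */ false, nullptr, nullptr))
+//   {
+//       // Run impl.weight_transform and impl.input_transform, the GEMM
+//       // described by impl.gemm_args, then impl.output_transform.
+//   }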
+
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/common/padding.cpp b/src/core/NEON/kernels/convolution/common/padding.cpp
index f57706fef6..5960e66968 100644
--- a/src/core/NEON/kernels/convolution/common/padding.cpp
+++ b/src/core/NEON/kernels/convolution/common/padding.cpp
@@ -81,7 +81,7 @@ template void copy_and_pad_tile(
 template void copy_and_pad_tile(
   unsigned int, unsigned int, unsigned int,
-  const float *, unsigned int, unsigned int,
+  float const *, unsigned int, unsigned int,
   float *, unsigned int, unsigned int,
   unsigned int, unsigned int, unsigned int, unsigned int, float
 );
diff --git a/src/core/NEON/kernels/convolution/common/padding.hpp b/src/core/NEON/kernels/convolution/common/padding.hpp
index b6f95872c0..397d902e29 100644
--- a/src/core/NEON/kernels/convolution/common/padding.hpp
+++ b/src/core/NEON/kernels/convolution/common/padding.hpp
@@ -34,20 +34,20 @@ namespace padding
  */
 template <typename T>
 void copy_and_pad_tile(
-  unsigned int tile_rows,
-  unsigned int tile_cols,
-  unsigned int n_channels,
-  const T *inptr,
-  unsigned int in_row_stride,
-  unsigned int in_col_stride,
-  T* outptr,
-  unsigned int out_row_stride,
-  unsigned int out_col_stride,
-  unsigned int pad_top,
-  unsigned int pad_left,
-  unsigned int pad_bottom,
-  unsigned int pad_right,
-  T pad_value=static_cast<T>(0)
+  const unsigned int tile_rows,
+  const unsigned int tile_cols,
+  const unsigned int n_channels,
+  const T * const inptr,
+  const unsigned int in_row_stride,
+  const unsigned int in_col_stride,
+  T* const outptr,
+  const unsigned int out_row_stride,
+  const unsigned int out_col_stride,
+  const unsigned int pad_top,
+  const unsigned int pad_left,
+  const unsigned int pad_bottom,
+  const unsigned int pad_right,
+  const T pad_value=static_cast<T>(0)
 );
 
 /** Copy a tile and remove padding elements in the output.
diff --git a/src/core/NEON/kernels/convolution/winograd/input_transform.hpp b/src/core/NEON/kernels/convolution/winograd/input_transform.hpp
new file mode 100644
index 0000000000..113b7ea928
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/input_transform.hpp
@@ -0,0 +1,384 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "arm_compute/core/Error.h"
+
+#include "src/core/NEON/kernels/assembly/winograd.hpp"
+
+#include "src/core/NEON/kernels/arm_conv/addressing.hpp"
+#include <algorithm>
+#include <cstring>
+#include <functional>
+
+namespace arm_conv {
+namespace winograd {
+namespace input_transform {
+
+namespace {
+
+template <typename T>
+constexpr T iceildiv(const T a, const T b)
+{
+  return (a + b - 1) / b;
+}
+
+}
+
+/* Driver class for the Winograd input transforms.
+ *
+ * This provides a base implementation which handles iteration over the input
+ * tensor; subclasses are responsible for managing working space and executing
+ * the transform on individual tiles.
+ */
+template <typename TIn, typename TOut = TIn>
+class TransformBase : public ITransform
+{
+  const std::string m_name;
+  const unsigned int m_input_rows, m_input_cols;
+
+  protected:
+  virtual size_t get_working_space_per_thread(const ConvolutionArgs &) const
+  {
+    return 0;
+  }
+
+  virtual void initialise_thread_working_space(const ConvolutionArgs &, void *) const
+  {
+    // Nothing to do
+  }
+
+  virtual void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_row, size_t ld_in_col,
+    TOut *outptr, size_t ld_out_matrix,
+    unsigned int pad_top, unsigned int valid_rows,
+    unsigned int pad_left, unsigned int valid_cols,
+    void *working_space
+  ) const = 0;
+
+  void execute_internal(
+    const ConvolutionArgs &args,
+    const TIn *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
+    TOut *outptr, size_t ld_out_batch, size_t ld_out_matrix, size_t ld_out_row,
+    void *working_space, unsigned int thread_id, unsigned int n_threads
+  ) const
+  {
+    // Get the working space for this thread, and initialise it.
+    working_space = reinterpret_cast<char *>(working_space) +
+                    this->get_working_space_per_thread(args) * thread_id;
+    this->initialise_thread_working_space(args, working_space);
+
+    // Get tile traversal parameters
+    const auto tile_stride_rows = std::max(1u, m_input_rows - args.kernel_shape.rows + 1);
+    const auto tile_stride_cols = std::max(1u, m_input_cols - args.kernel_shape.cols + 1);
+    const auto n_tile_rows = iceildiv(
+        args.output_shape.rows, m_input_rows - args.kernel_shape.rows + 1);
+    const auto n_tile_cols = iceildiv(
+        args.output_shape.cols, m_input_cols - args.kernel_shape.cols + 1);
+
+    // Execute over all batches
+    for (unsigned int batch = 0; batch < args.n_batches; batch++)
+    {
+      auto outptr_tile = outptr + thread_id * n_tile_cols * ld_out_row;
+
+      // For a single batch, stripe the rows over the threads.
+      for (auto tile_i = thread_id; tile_i < n_tile_rows; tile_i += n_threads)
+      {
+        // Compute pointers and padding for this row of tiles
+        const auto start_i = tile_i * tile_stride_rows;
+        const auto pad_top = start_i < args.pad_top ? args.pad_top - start_i : 0;
+        const auto inptr_row = inptr + (pad_top ? 0 : start_i - args.pad_top) * ld_in_row;
+        const auto valid_rows = args.input_shape.rows - (pad_top ? 0 : start_i - args.pad_top);
+
+        // Iterate over columns
+        for (auto tile_j = 0u; tile_j < n_tile_cols; tile_j++)
+        {
+          // Compute pointers and padding for this tile, then delegate to
+          // execute the kernel.
+          const auto start_j = tile_j * tile_stride_cols;
+          const auto pad_left = start_j < args.pad_left ? args.pad_left - start_j : 0;
+          const auto inptr_tile = inptr_row + (pad_left ?
0 : start_j - args.pad_left) * ld_in_col;
+          const auto valid_cols = args.input_shape.cols - (pad_left ? 0 : start_j - args.pad_left);
+
+          this->execute_tile(
+            args.n_input_channels,
+            inptr_tile, ld_in_row, ld_in_col,
+            outptr_tile, ld_out_matrix,
+            pad_top, valid_rows, pad_left, valid_cols,
+            working_space
+          );
+          outptr_tile += ld_out_row;
+        }
+
+        outptr_tile += (n_threads - 1) * n_tile_cols * ld_out_row;
+      }
+
+      inptr += ld_in_batch;
+      outptr += ld_out_batch;
+    }
+  }
+
+  public:
+  TransformBase(const std::string &name, unsigned int input_rows, unsigned int input_cols)
+  : m_name(name), m_input_rows(input_rows), m_input_cols(input_cols)
+  {
+  }
+
+  const std::string &get_name(void) const override { return m_name; }
+
+  unsigned int get_input_rows(void) const override final { return m_input_rows; }
+  unsigned int get_input_cols(void) const override final { return m_input_cols; }
+
+  size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const override
+  {
+    return n_threads * this->get_working_space_per_thread(args);
+  }
+
+  void execute(
+    const ConvolutionArgs &args,
+    const void *inptr, size_t ld_in_batch, size_t ld_in_row, size_t ld_in_col,
+    void *outptr, size_t ld_out_batch, size_t ld_out_matrix, size_t ld_out_row,
+    void *working_space, unsigned int thread_id, unsigned int n_threads
+  ) const override
+  {
+    execute_internal(
+      args,
+      reinterpret_cast<const TIn *>(inptr), ld_in_batch, ld_in_row, ld_in_col,
+      reinterpret_cast<TOut *>(outptr), ld_out_batch, ld_out_matrix, ld_out_row,
+      working_space, thread_id, n_threads
+    );
+  }
+};
+
+template <typename TIn, typename TOut = TIn>
+class TransformDirect : public TransformBase<TIn, TOut>
+{
+  using Kernel = std::function<void(
+    unsigned int, const TIn *, size_t, size_t,
+    unsigned int, unsigned int, unsigned int, unsigned int,
+    TOut *, size_t
+  )>;
+  const Kernel m_kernel;
+
+  protected:
+  void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_row, size_t ld_in_col,
+    TOut *outptr, size_t ld_out_matrix,
+    unsigned int pad_top, unsigned int valid_rows,
+    unsigned int pad_left, unsigned int valid_cols,
+    void *working_space
+  ) const override
+  {
+    ARM_COMPUTE_UNUSED(working_space);
+    const auto end_i = this->get_input_rows() - pad_top;
+    const auto pad_bottom = end_i < valid_rows ? 0 : end_i - valid_rows;
+    const auto end_j = this->get_input_cols() - pad_left;
+    const auto pad_right = end_j < valid_cols ? 0 : end_j - valid_cols;
+
+    // Execute the kernel
+    m_kernel(
+      n_channels, inptr, ld_in_row, ld_in_col,
+      pad_top, pad_left, pad_bottom, pad_right,
+      outptr, ld_out_matrix
+    );
+  }
+
+  public:
+  TransformDirect(const std::string &name, unsigned int input_rows, unsigned int input_cols, Kernel kernel)
+  : TransformBase<TIn, TOut>(name, input_rows, input_cols), m_kernel(kernel)
+  {
+  }
+};
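+// The two drivers below avoid passing padding to the kernel at all:
+// TransformIndirect builds a per-tile array of row/column pointers in which
+// padded positions point at a zero-filled channel buffer, while
+// TransformUnpadded copies the valid region into a zero-initialised patch in
+// working space and hands the kernel an apparently unpadded tile.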
+
+template <typename TIn, typename TOut = TIn>
+class TransformIndirect : public TransformBase<TIn, TOut>
+{
+  using Kernel = std::function<void(unsigned int, const TIn *const *, TOut *, size_t)>;
+  const Kernel m_kernel;
+
+  struct Workspace
+  {
+    const TIn **inptrs;
+    const TIn *input_buffer;
+  };
+
+  size_t sizeof_inptr_array(void) const
+  {
+    return sizeof(const TIn **) * this->get_input_rows() * this->get_input_cols();
+  }
+
+  protected:
+  size_t get_working_space_per_thread(const ConvolutionArgs &args) const override
+  {
+    return sizeof(Workspace) + sizeof_inptr_array() + sizeof(TIn) * args.n_input_channels;
+  }
+
+  void initialise_thread_working_space(const ConvolutionArgs &args, void *buffer) const override
+  {
+    Workspace *ws = reinterpret_cast<Workspace *>(buffer);
+    buffer = ws + 1;
+
+    ws->inptrs = reinterpret_cast<const TIn **>(buffer);
+    buffer = reinterpret_cast<char *>(buffer) + sizeof_inptr_array();
+
+    ws->input_buffer = reinterpret_cast<const TIn *>(buffer);
+    memset(buffer, 0, sizeof(TIn) * args.n_input_channels);
+  }
+
+  void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_row, size_t ld_in_col,
+    TOut *outptr, size_t ld_out_matrix,
+    unsigned int pad_top, unsigned int valid_rows,
+    unsigned int pad_left, unsigned int valid_cols,
+    void *working_space
+  ) const override
+  {
+    // Get the working space
+    auto ws = reinterpret_cast<Workspace *>(working_space);
+
+    // Construct the input pointer array based on the given arguments
+    fill_pointer_array(
+      ws->inptrs, this->get_input_rows(), this->get_input_cols(),
+      inptr, ld_in_row, ld_in_col,
+      ws->input_buffer,
+      pad_top, valid_rows,
+      pad_left, valid_cols
+    );
+
+    // Execute the kernel
+    m_kernel(n_channels, ws->inptrs, outptr, ld_out_matrix);
+  }
+
+  public:
+  TransformIndirect(const std::string &name, unsigned int input_rows, unsigned int input_cols, Kernel kernel)
+  : TransformBase<TIn, TOut>(name, input_rows, input_cols), m_kernel(kernel)
+  {
+  }
+};
+
+template <typename TIn, typename TOut = TIn>
+class TransformUnpadded : public TransformBase<TIn, TOut>
+{
+  using Kernel = std::function<void(unsigned int, const TIn *, size_t, size_t, TOut *, size_t)>;
+  const Kernel m_kernel;
+
+  protected:
+  size_t get_working_space_per_thread(const ConvolutionArgs &args) const override
+  {
+    const auto input_points = this->get_input_rows() * this->get_input_cols();
+    return sizeof(TIn) * input_points * args.n_input_channels;
+  }
+
+  void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_row, size_t ld_in_col,
+    TOut *const outptr, const size_t ld_out_matrix,
+    const unsigned int pad_top, const unsigned int valid_rows,
+    const unsigned int pad_left, const unsigned int valid_cols,
+    void *const working_space
+  ) const override
+  {
+    // If there's any padding, then copy the valid portion of the tensor into
+    // the working space and reset the pointer, row and column strides to point
+    // at this copy of the data.
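+    // For example (illustrative values only): a 6x6 input tile with
+    // pad_top == 1, pad_left == 1 and C channels gives patch_ld_col == C and
+    // patch_ld_row == 6*C below, so the first valid input element is copied
+    // to element (1*6 + 1)*C of the zero-filled patch in working space.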
+    if (pad_top || valid_rows < this->get_input_rows() ||
+        pad_left || valid_cols < this->get_input_cols())
+    {
+      const auto patch_ld_col = n_channels;
+      const auto patch_ld_row = patch_ld_col * this->get_input_cols();
+      auto patch = reinterpret_cast<TIn *>(working_space) +
+                   pad_top*patch_ld_row + pad_left*patch_ld_col;
+
+      // Fill the input patch with padding
+      memset(working_space, 0, sizeof(TIn) * this->get_input_rows() * patch_ld_row);
+
+      // Determine the bounds for which to copy
+      const auto last_i = std::min(valid_rows + pad_top, this->get_input_rows());
+      const auto last_j = std::min(valid_cols + pad_left, this->get_input_cols());
+
+      // Copy across the valid portion of the patch
+      for (auto i = pad_top; i < last_i; i++)
+      {
+        auto inptr_col = inptr;
+        inptr += ld_in_row;
+
+        auto patch_col = patch;
+        patch += patch_ld_row;
+
+        for (auto j = pad_left; j < last_j; j++)
+        {
+          // Perform the copy and progress both input and patch pointers
+          memcpy(patch_col, inptr_col, n_channels * sizeof(TIn));
+          inptr_col += ld_in_col;
+          patch_col += patch_ld_col;
+        }
+      }
+
+      // Override the input pointer and strides
+      inptr = reinterpret_cast<const TIn *>(working_space);
+      ld_in_col = patch_ld_col;
+      ld_in_row = patch_ld_row;
+    }
+
+    // Call the kernel
+    m_kernel(n_channels, inptr, ld_in_row, ld_in_col, outptr, ld_out_matrix);
+  }
+
+  public:
+  TransformUnpadded(const std::string &name, unsigned int input_rows, unsigned int input_cols, Kernel kernel)
+  : TransformBase<TIn, TOut>(name, input_rows, input_cols), m_kernel(kernel)
+  {
+  }
+
+  /* Utility method which can be used to get a transposed version of a kernel,
+   * this just calls the kernel with the input row and column strides reversed.
+   */
+  static constexpr Kernel get_transposed_kernel(const Kernel &kernel)
+  {
+    return [kernel] (
+      const unsigned int n_channels,
+      const TIn *const inptr, const size_t ld_in_row, const size_t ld_in_col,
+      TOut *const outptr, const size_t ld_out_matrix
+    ) {
+      kernel(n_channels, inptr, ld_in_col, ld_in_row, outptr, ld_out_matrix);
+    };
+  }
+};
+
+} // namespace input_transform
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp16_6x6.cpp
similarity index 95%
rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp
rename to src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp16_6x6.cpp
index d0ce307988..ad759b225e 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp16_fp16_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp16_6x6.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -21,20 +21,22 @@
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
*/ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#include "arm.hpp" -#include "input.hpp" +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -namespace winograd -{ -template <> -void InputTransform<6, 6, __fp16, __fp16, WinogradRoots::Integers>::transform_tile( - const int n_channels, +#include +#include + +namespace arm_conv { +namespace winograd { +namespace input_transform { + +void a64_fp16_6x6( + const unsigned int n_channels, const __fp16* const input_base, - const int input_row_stride, - const int input_col_stride, + const size_t input_row_stride, + const size_t input_col_stride, __fp16* outptr, - const int matrix_stride + const size_t matrix_stride ) { constexpr int inner_tile_rows = 6; @@ -271,7 +273,8 @@ void InputTransform<6, 6, __fp16, __fp16, WinogradRoots::Integers>::transform_ti } } -template class InputTransform<6, 6, __fp16, __fp16, WinogradRoots::Integers>; - +} // namespace input_transform } // namespace winograd -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC \ No newline at end of file +} // namespace arm_conv + +#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp32_6x6.cpp similarity index 84% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp32_6x6.cpp index 0095e6c96b..6f818c69ff 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_6x6_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/a64_fp32_6x6.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,31 +22,30 @@ * SOFTWARE. 
*/ -#include "arm.hpp" -#include "input.hpp" +#ifdef __aarch64__ -namespace winograd -{ +#include -#ifdef __aarch64__ +namespace arm_conv { +namespace winograd { +namespace input_transform { -template <> -void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile( - int n_channels, - const float* input_base, - const int input_row_stride, - const int input_col_stride, - float* matrix_base, - const int matrix_stride +void a64_fp32_6x6( + unsigned int n_channels, + const float *input_base, + const size_t input_row_stride, + const size_t input_col_stride, + float *matrix_base, + const size_t matrix_stride ) { const float pcoeffs[4] = {1.0f, 2.0f, 4.0f, 5.0f}; __asm__ __volatile__( "ldr q0, [%[pcoeffs]]\n" "add x25, %[inptr0], %[input_row_stride]\n" - "add x9, %[input_col_stride1], %[input_col_stride1]\n" + "add x10, %[input_col_stride1], %[input_col_stride1]\n" "add x16, x25, %[input_row_stride]\n" - "add x19, x9, %[input_col_stride1]\n" + "add x19, x10, %[input_col_stride1]\n" "add x26, x16, %[input_row_stride]\n" "add x20, x19, %[input_col_stride1]\n" "add x17, x26, %[input_row_stride]\n" @@ -65,7 +64,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "blt 2f\n" "1:\n" "ldr q8, [%[inptr0], x20]\n" - "ldr q2, [%[inptr0], x9]\n" + "ldr q2, [%[inptr0], x10]\n" "mov v14.16b, v8.16b\n" "ldr q9, [%[inptr0]]\n" "mov v10.16b, v8.16b\n" @@ -77,7 +76,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v10.4s, v12.4s, v0.s[2]\n" "ldr q5, [x16, x20]\n" "fmls v14.4s, v2.4s, v0.s[3]\n" - "ldr q20, [x16, x9]\n" + "ldr q20, [x16, x10]\n" "fmla v9.4s, v12.4s, v0.s[2]\n" "ldr q3, [x16]\n" "fmls v10.4s, v2.4s, v0.s[2]\n" @@ -89,7 +88,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fadd v10.4s, v10.4s, v4.4s\n" "ldr q17, [x17, x20]\n" "fmls v7.4s, v12.4s, v0.s[1]\n" - "ldr q15, [x17, x9]\n" + "ldr q15, [x17, x10]\n" "fsub v9.4s, v9.4s, v4.4s\n" "ldr q19, [x17]\n" "mov v8.16b, v8.16b\n" @@ -180,7 +179,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "mov v25.16b, v19.16b\n" "ldr q11, [x25, x20]\n" "mov v10.16b, v11.16b\n" - "ldr q23, [x25, x9]\n" + "ldr q23, [x25, x10]\n" "mov v9.16b, v11.16b\n" "ldr q7, [x25]\n" "fmla v10.4s, v7.4s, v0.s[2]\n" @@ -192,7 +191,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v10.4s, v23.4s, v0.s[3]\n" "ldr q30, [x26, x20]\n" "fmls v9.4s, v21.4s, v0.s[2]\n" - "ldr q29, [x26, x9]\n" + "ldr q29, [x26, x10]\n" "fmla v7.4s, v21.4s, v0.s[2]\n" "ldr q22, [x26]\n" "fmls v8.4s, v21.4s, v0.s[1]\n" @@ -360,7 +359,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "add x14, x14, #16\n" "ldr q2, [x27, x20]\n" "mov v4.16b, v2.16b\n" - "ldr q17, [x27, x9]\n" + "ldr q17, [x27, x10]\n" "mov v12.16b, v2.16b\n" "ldr q18, [x27]\n" "fmla v4.4s, v18.4s, v0.s[2]\n" @@ -420,7 +419,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "blt 3f\n" "ldr d8, [%[inptr0], x20]\n" "mov v14.16b, v8.16b\n" - "ldr d2, [%[inptr0], x9]\n" + "ldr d2, [%[inptr0], x10]\n" "mov v10.16b, v8.16b\n" "ldr d9, [%[inptr0]]\n" "fmla v14.4s, v9.4s, v0.s[2]\n" @@ -432,7 +431,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v14.4s, v2.4s, v0.s[3]\n" "ldr d5, [x16, x20]\n" "fmls v10.4s, v12.4s, v0.s[2]\n" - "ldr d20, [x16, x9]\n" + "ldr d20, [x16, x10]\n" "fmla v9.4s, v12.4s, v0.s[2]\n" "ldr d3, [x16]\n" "fmls 
v7.4s, v12.4s, v0.s[1]\n" @@ -444,7 +443,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fsub v7.4s, v7.4s, v2.4s\n" "ldr d17, [x17, x20]\n" "fadd v10.4s, v10.4s, v4.4s\n" - "ldr d15, [x17, x9]\n" + "ldr d15, [x17, x10]\n" "fsub v9.4s, v9.4s, v4.4s\n" "ldr d19, [x17]\n" "fmla v7.4s, v4.4s, v0.s[1]\n" @@ -534,7 +533,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "mov v25.16b, v19.16b\n" "ldr d11, [x25, x20]\n" "mov v10.16b, v11.16b\n" - "ldr d23, [x25, x9]\n" + "ldr d23, [x25, x10]\n" "mov v9.16b, v11.16b\n" "ldr d7, [x25]\n" "fmla v10.4s, v7.4s, v0.s[2]\n" @@ -546,7 +545,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v10.4s, v23.4s, v0.s[3]\n" "ldr d30, [x26, x20]\n" "fmls v9.4s, v21.4s, v0.s[2]\n" - "ldr d29, [x26, x9]\n" + "ldr d29, [x26, x10]\n" "fmla v7.4s, v21.4s, v0.s[2]\n" "ldr d22, [x26]\n" "fmls v8.4s, v21.4s, v0.s[1]\n" @@ -714,7 +713,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "add x14, x14, #8\n" "ldr d2, [x27, x20]\n" "mov v4.16b, v2.16b\n" - "ldr d17, [x27, x9]\n" + "ldr d17, [x27, x10]\n" "mov v12.16b, v2.16b\n" "ldr d18, [x27]\n" "fmla v4.4s, v18.4s, v0.s[2]\n" @@ -771,7 +770,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "cbz %w[n_channels], 4f\n" "ldr s8, [%[inptr0], x20]\n" "mov v14.16b, v8.16b\n" - "ldr s2, [%[inptr0], x9]\n" + "ldr s2, [%[inptr0], x10]\n" "mov v10.16b, v8.16b\n" "ldr s9, [%[inptr0]]\n" "fmla v14.4s, v9.4s, v0.s[2]\n" @@ -783,7 +782,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v14.4s, v2.4s, v0.s[3]\n" "ldr s5, [x16, x20]\n" "fmls v10.4s, v12.4s, v0.s[2]\n" - "ldr s20, [x16, x9]\n" + "ldr s20, [x16, x10]\n" "fmla v9.4s, v12.4s, v0.s[2]\n" "ldr s3, [x16]\n" "fmls v7.4s, v12.4s, v0.s[1]\n" @@ -795,7 +794,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fsub v7.4s, v7.4s, v2.4s\n" "ldr s17, [x17, x20]\n" "fadd v10.4s, v10.4s, v4.4s\n" - "ldr s15, [x17, x9]\n" + "ldr s15, [x17, x10]\n" "fsub v9.4s, v9.4s, v4.4s\n" "ldr s19, [x17]\n" "fmla v7.4s, v4.4s, v0.s[1]\n" @@ -885,7 +884,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "mov v25.16b, v19.16b\n" "ldr s11, [x25, x20]\n" "mov v10.16b, v11.16b\n" - "ldr s23, [x25, x9]\n" + "ldr s23, [x25, x10]\n" "mov v9.16b, v11.16b\n" "ldr s7, [x25]\n" "fmla v10.4s, v7.4s, v0.s[2]\n" @@ -897,7 +896,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "fmls v10.4s, v23.4s, v0.s[3]\n" "ldr s30, [x26, x20]\n" "fmls v9.4s, v21.4s, v0.s[2]\n" - "ldr s29, [x26, x9]\n" + "ldr s29, [x26, x10]\n" "fmla v7.4s, v21.4s, v0.s[2]\n" "ldr s22, [x26]\n" "fmls v8.4s, v21.4s, v0.s[1]\n" @@ -1065,7 +1064,7 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile "add x14, x14, #4\n" "ldr s2, [x27, x20]\n" "mov v4.16b, v2.16b\n" - "ldr s17, [x27, x9]\n" + "ldr s17, [x27, x10]\n" "mov v12.16b, v2.16b\n" "ldr s18, [x27]\n" "fmla v4.4s, v18.4s, v0.s[2]\n" @@ -1129,180 +1128,13 @@ void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile : "cc", "v0", "v1", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v2", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v3", "v30", "v31", "v4", "v5", "v6", "v7", "v8", - "v9", "x11", "x12", "x13", "x14", "x15", "x16", "x17", "x9", "x19", + "v9", "x11", "x12", 
"x13", "x14", "x15", "x16", "x17", "x10", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "memory" ); } -#else // __arm__ not __aarch64__ - -template <> -void InputTransform<6, 6, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, - const float* const input_base, - const int input_row_stride, - const int input_col_stride, - float* outptr, - const int matrix_stride -) -{ - constexpr int inner_tile_rows = 6; - constexpr int inner_tile_cols = 6; - - // Get pointers into the input tile - const float *x_ptrs[inner_tile_rows][inner_tile_cols]; - for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++) - { - // Get a pointer into the row - const float* const row_ptr = input_base + xi*input_row_stride; - - for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++) - { - x_ptrs[i][j] = row_ptr + xj*input_col_stride; - } - } - - // Matrices used/computed in this kernel. - float x[inner_tile_rows][inner_tile_cols]; - float XTx[inner_tile_rows][inner_tile_cols]; - float U[inner_tile_rows][inner_tile_cols]; - for (int i = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++) - { - x[i][j] = XTx[i][j] = 0.0f; - } - } - - // Perform the Winograd input transformation for each channel in the input - // tensor. - int channels_remaining = n_channels; - for (; channels_remaining >= 2; channels_remaining -= 2) - { - // Matrices used/computed in this kernel - float32x2_t x[inner_tile_rows][inner_tile_cols]; - float32x2_t XTx[inner_tile_rows][inner_tile_cols]; - float32x2_t U[inner_tile_rows][inner_tile_cols]; - for (int i = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++) - { - x[i][j] = vdup_n_f32(0.0f); - XTx[i][j] = vdup_n_f32(0.0f); - } - } - - // Read a 6x6 tile in the Winograd domain - for (int i = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++) - { - x[i][j] = vld1_f32(x_ptrs[i][j]); - x_ptrs[i][j] += 2; - } - } - - // Compute XT . x - for (int j = 0; j < inner_tile_cols; j++) - { - // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; - XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); - - // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; - XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f); - - // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; - XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f); - - // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; - XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f); - - // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; - XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f); - - // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; - XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); - } - - // Compute U = XT . x . 
X - for (int i = 0; i < inner_tile_rows; i++) - { - // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; - U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); - - // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; - U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f); - - // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; - U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f); - - // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; - U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f); - - // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; - U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f); - - // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; - U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); - } - - // Store the transformed matrix - for (int i = 0, m = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++, m++) - { - vst1_f32(outptr + m*matrix_stride, U[i][j]); - } - } - outptr += 2; - } - for (; channels_remaining; channels_remaining--) - { - // Load x - for (int i = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++) - { - x[i][j] = *(x_ptrs[i][j]++); - } - } - - // Compute XT . x - for (int j = 0; j < inner_tile_cols; j++) - { - XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; - XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; - XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; - XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; - XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; - XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; - } - - // Compute U = XT . x . X - for (int i = 0; i < inner_tile_rows; i++) - { - U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; - U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; - U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; - U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; - U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; - U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; - } - - // Store the transformed matrix - for (int i = 0, m = 0; i < inner_tile_rows; i++) - { - for (int j = 0; j < inner_tile_cols; j++, m++) - { - *(outptr + m*matrix_stride) = U[i][j]; - } - } - outptr++; - } -} - -#endif - -template class InputTransform<6, 6, float, float, WinogradRoots::Integers>; - +} // namespace input_transform } // namespace winograd +} // namespace arm_conv + +#endif // __aarch64__ diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_1x8_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_1x8.cpp similarity index 91% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_1x8_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_1x8.cpp index 8f6e9e8b40..2d6b333a59 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_1x8_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_1x8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,20 +22,20 @@ * SOFTWARE. 
*/ -#include "arm.hpp" -#include "input.hpp" +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace input_transform { -template <> -void InputTransform<1, 8, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, - const float* const input_base, - const int, // We don't need to stride over rows - const int input_col_stride, - float* outptr, - const int matrix_stride +void arm_fp32_1x8( + const unsigned int n_channels, + const float *const input_base, + size_t, // We don't need to stride over rows + const size_t input_col_stride, + float *outptr, + const size_t matrix_stride ) { constexpr int inner_tile_cols = 8; @@ -59,7 +59,6 @@ void InputTransform<1, 8, float, float, WinogradRoots::Integers>::transform_tile // Perform the Winograd input transformation for each channel in the input // tensor. int channels_remaining = n_channels; -#ifdef _arm_any_ for (; channels_remaining >= 4; channels_remaining -= 4) { float32x4_t x[inner_tile_cols], U[inner_tile_cols]; @@ -124,7 +123,6 @@ void InputTransform<1, 8, float, float, WinogradRoots::Integers>::transform_tile } outptr += 2; } -#endif // _arm_any_ for (; channels_remaining; channels_remaining--) { // Load x @@ -152,7 +150,6 @@ void InputTransform<1, 8, float, float, WinogradRoots::Integers>::transform_tile } } -template class InputTransform<1, 8, float, float, WinogradRoots::Integers>; -template class InputTransform<8, 1, float, float, WinogradRoots::Integers>; - +} // namespace input_transform } // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_4x4.cpp similarity index 92% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_4x4.cpp index 69d3e8feb5..fae0173374 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_4x4.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,20 +22,20 @@ * SOFTWARE. */ -#include "input.hpp" -#include "arm.hpp" +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace input_transform { -template <> -void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, - const float* const input_base, - const int input_row_stride, - const int input_col_stride, - float* outptr, - const int matrix_stride +void arm_fp32_4x4( + const unsigned int n_channels, + const float *input_base, + const size_t input_row_stride, + const size_t input_col_stride, + float *outptr, + const size_t matrix_stride ) { constexpr int inner_tile_rows = 4, inner_tile_cols = 4; @@ -69,7 +69,6 @@ void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile // Perform the Winograd input transformation for each channel in the input // tensor. int channels_remaining = n_channels; -#ifdef __aarch64__ for (; channels_remaining >= 4; channels_remaining -= 4) { // Matrices used/computed in this kernel. 
@@ -138,8 +137,6 @@ void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile } outptr += 4; } -#endif // __aarch64__ -#ifdef __arm_any__ for (; channels_remaining >= 2; channels_remaining -= 2) { // Matrices used/computed in this kernel. @@ -208,7 +205,6 @@ void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile } outptr += 2; } -#endif // __arm_any__ for (; channels_remaining; channels_remaining--) { // Load x @@ -250,6 +246,6 @@ void InputTransform<4, 4, float, float, WinogradRoots::Integers>::transform_tile } } -template class InputTransform<4, 4, float, float, WinogradRoots::Integers>; - -} // namespace +} // namespace input_transform +} // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp new file mode 100644 index 0000000000..4adc45768e --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/arm_fp32_6x6.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef __aarch64__ + +#include +#include + +namespace arm_conv { +namespace winograd { +namespace input_transform { + +void arm_fp32_6x6( + unsigned int n_channels, + const float* const input_base, + const size_t input_row_stride, + const size_t input_col_stride, + float* outptr, + const size_t matrix_stride +) +{ + constexpr int inner_tile_rows = 6; + constexpr int inner_tile_cols = 6; + + // Get pointers into the input tile + const float *x_ptrs[inner_tile_rows][inner_tile_cols]; + for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++) + { + // Get a pointer into the row + const float* const row_ptr = input_base + xi*input_row_stride; + + for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++) + { + x_ptrs[i][j] = row_ptr + xj*input_col_stride; + } + } + + // Matrices used/computed in this kernel. + float x[inner_tile_rows][inner_tile_cols]; + float XTx[inner_tile_rows][inner_tile_cols]; + float U[inner_tile_rows][inner_tile_cols]; + for (int i = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++) + { + x[i][j] = XTx[i][j] = 0.0f; + } + } + + // Perform the Winograd input transformation for each channel in the input + // tensor. 
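+  // Both the NEON and scalar paths below evaluate U = B^T x B, where B^T is
+  // the F(4x4, 3x3) analysis matrix; its rows supply the coefficients that
+  // appear in the XTx[...] expressions which follow:
+  //
+  //          [ 4   0  -5   0   1   0 ]
+  //          [ 0  -4  -4   1   1   0 ]
+  //   B^T =  [ 0   4  -4  -1   1   0 ]
+  //          [ 0  -2  -1   2   1   0 ]
+  //          [ 0   2  -1  -2   1   0 ]
+  //          [ 0   4   0  -5   0   1 ]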
+ int channels_remaining = n_channels; + for (; channels_remaining >= 2; channels_remaining -= 2) + { + // Matrices used/computed in this kernel + float32x2_t x[inner_tile_rows][inner_tile_cols]; + float32x2_t XTx[inner_tile_rows][inner_tile_cols]; + float32x2_t U[inner_tile_rows][inner_tile_cols]; + for (int i = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++) + { + x[i][j] = vdup_n_f32(0.0f); + XTx[i][j] = vdup_n_f32(0.0f); + } + } + + // Read a 6x6 tile in the Winograd domain + for (int i = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++) + { + x[i][j] = vld1_f32(x_ptrs[i][j]); + x_ptrs[i][j] += 2; + } + } + + // Compute XT . x + for (int j = 0; j < inner_tile_cols; j++) + { + // XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[0][j] = vmls_n_f32(vmla_n_f32(x[4][j], x[0][j], 4.0f), x[2][j], 5.0f); + + // XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[1][j] = vmls_n_f32(vadd_f32(x[3][j], x[4][j]), vadd_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[2][j] = vmla_n_f32(vsub_f32(x[4][j], x[3][j]), vsub_f32(x[1][j], x[2][j]), 4.0f); + + // XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[3][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[3][j], x[1][j]), 2.0f); + + // XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[4][j] = vmla_n_f32(vsub_f32(x[4][j], x[2][j]), vsub_f32(x[1][j], x[3][j]), 2.0f); + + // XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + XTx[5][j] = vmls_n_f32(vmla_n_f32(x[5][j], x[1][j], 4.0f), x[3][j], 5.0f); + } + + // Compute U = XT . x . X + for (int i = 0; i < inner_tile_rows; i++) + { + // U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][0] = vmls_n_f32(vmla_n_f32(XTx[i][4], XTx[i][0], 4.0f), XTx[i][2], 5.0f); + + // U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][1] = vmls_n_f32(vadd_f32(XTx[i][3], XTx[i][4]), vadd_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][3]), vsub_f32(XTx[i][1], XTx[i][2]), 4.0f); + + // U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][3], XTx[i][1]), 2.0f); + + // U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = vmla_n_f32(vsub_f32(XTx[i][4], XTx[i][2]), vsub_f32(XTx[i][1], XTx[i][3]), 2.0f); + + // U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + U[i][5] = vmls_n_f32(vmla_n_f32(XTx[i][5], XTx[i][1], 4.0f), XTx[i][3], 5.0f); + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, U[i][j]); + } + } + outptr += 2; + } + for (; channels_remaining; channels_remaining--) + { + // Load x + for (int i = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++) + { + x[i][j] = *(x_ptrs[i][j]++); + } + } + + // Compute XT . x + for (int j = 0; j < inner_tile_cols; j++) + { + XTx[0][j] = 4*x[0][j] + -5*x[2][j] + 1*x[4][j]; + XTx[1][j] = -4*x[1][j] + -4*x[2][j] + 1*x[3][j] + 1*x[4][j]; + XTx[2][j] = 4*x[1][j] + -4*x[2][j] + -1*x[3][j] + 1*x[4][j]; + XTx[3][j] = -2*x[1][j] + -1*x[2][j] + 2*x[3][j] + 1*x[4][j]; + XTx[4][j] = 2*x[1][j] + -1*x[2][j] + -2*x[3][j] + 1*x[4][j]; + XTx[5][j] = 4*x[1][j] + -5*x[3][j] + 1*x[5][j]; + } + + // Compute U = XT . x . 
X + for (int i = 0; i < inner_tile_rows; i++) + { + U[i][0] = 4*XTx[i][0] + -5*XTx[i][2] + 1*XTx[i][4]; + U[i][1] = -4*XTx[i][1] + -4*XTx[i][2] + 1*XTx[i][3] + 1*XTx[i][4]; + U[i][2] = 4*XTx[i][1] + -4*XTx[i][2] + -1*XTx[i][3] + 1*XTx[i][4]; + U[i][3] = -2*XTx[i][1] + -1*XTx[i][2] + 2*XTx[i][3] + 1*XTx[i][4]; + U[i][4] = 2*XTx[i][1] + -1*XTx[i][2] + -2*XTx[i][3] + 1*XTx[i][4]; + U[i][5] = 4*XTx[i][1] + -5*XTx[i][3] + 1*XTx[i][5]; + } + + // Store the transformed matrix + for (int i = 0, m = 0; i < inner_tile_rows; i++) + { + for (int j = 0; j < inner_tile_cols; j++, m++) + { + *(outptr + m*matrix_stride) = U[i][j]; + } + } + outptr++; + } +} + +} // namespace input_transform +} // namespace winograd +} // namespace arm_conv + +#endif // ! __aarch64__ diff --git a/src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp new file mode 100644 index 0000000000..a2f096f489 --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms/sve_fp32_6x6.cpp @@ -0,0 +1,361 @@ +/* + * Copyright (c) 2022 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#if __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) +#include + +namespace arm_conv { +namespace winograd { +namespace input_transform { + +void sve_fp32_6x6( + const unsigned int num_channels, + const float *input, + const size_t input_row_stride, + const size_t input_col_stride, + float *output, + const size_t output_col_stride +) +{ + const float B_values[4] = { 1.0f, 2.0f, 4.0f, 5.0f }; + long long_channels = num_channels; + + // Generated by armasmgen (February 04th, 2021) + __asm__ __volatile__( + "fmov z16.s, #4.0\n" + "ptrue p1.b\n" + "ld1rqw { z2.s }, p1/Z, [%x[B_values]]\n" + "add x16, %x[input_row_0], %x[input_row_stride], LSL #2\n" + "add x15, %x[output_row_0], %x[output_row_stride], LSL #2\n" + "add x14, %x[input_row_0], %x[input_row_stride], LSL #3\n" + "add x13, %x[output_row_0], %x[output_row_stride], LSL #3\n" + "add x12, x14, %x[input_row_stride], LSL #2\n" + "add x11, x13, %x[output_row_stride], LSL #2\n" + "add x10, %x[input_row_0], %x[input_row_stride], LSL #4\n" + "add x9, %x[output_row_0], %x[output_row_stride], LSL #4\n" + "add x28, x10, %x[input_row_stride], LSL #2\n" + "add x27, x9, %x[output_row_stride], LSL #2\n" + "lsl x26, %x[input_col_1_stride], #0x1\n" + "lsl x25, %x[output_col_1_stride], #0x1\n" + "add x24, x26, %x[input_col_1_stride]\n" + "add x23, x25, %x[output_col_1_stride]\n" + "lsl x22, %x[input_col_1_stride], #0x2\n" + "lsl x21, %x[output_col_1_stride], #0x2\n" + "add x20, x22, %x[input_col_1_stride]\n" + "add x19, x21, %x[output_col_1_stride]\n" + "whilelt p0.s, XZR, %x[num_channels]\n" + "beq 2f\n" + "1:" // channel_loop + "ld1w { z31.s }, p0/Z, [%x[input_row_0]]\n" + "decw %x[num_channels]\n" + "ld1w { z28.s }, p0/Z, [%x[input_row_0], %x[input_col_1_stride], LSL #2]\n" + "fmul z13.s, z28.s, z2.s[1]\n" + "ld1w { z27.s }, p0/Z, [%x[input_row_0], x26, LSL #2]\n" + "ld1w { z11.s }, p0/Z, [%x[input_row_0], x24, LSL #2]\n" + "fneg z13.s, p1/M, z13.s\n" + "ld1w { z7.s }, p0/Z, [%x[input_row_0], x22, LSL #2]\n" + "fsub z15.s, z7.s, z27.s\n" + "fmad z31.s, p1/M, z16.s, z7.s\n" + "ld1w { z3.s }, p0/Z, [%x[input_row_0], x20, LSL #2]\n" + "fmla z13.s, z11.s, z2.s[1]\n" + "ld1w { z12.s }, p0/Z, [x14]\n" + "incb %x[input_row_0]\n" + "fmls z31.s, z27.s, z2.s[3]\n" + "ld1w { z14.s }, p0/Z, [x14, %x[input_col_1_stride], LSL #2]\n" + "fsub z25.s, z15.s, z13.s\n" + "fadd z8.s, z13.s, z15.s\n" + "ld1w { z24.s }, p0/Z, [x14, x26, LSL #2]\n" + "fmsb z27.s, p1/M, z16.s, z7.s\n" + "ld1w { z22.s }, p0/Z, [x14, x24, LSL #2]\n" + "fmul z7.s, z28.s, z2.s[2]\n" + "ld1w { z1.s }, p0/Z, [x14, x22, LSL #2]\n" + "fsub z15.s, z1.s, z24.s\n" + "fneg z7.s, p1/M, z7.s\n" + "ld1w { z20.s }, p0/Z, [x14, x20, LSL #2]\n" + "fadd z7.s, z7.s, z11.s\n" + "ld1w { z29.s }, p0/Z, [x10]\n" + "incb x14\n" + "fmad z28.s, p1/M, z16.s, z3.s\n" + "ld1w { z10.s }, p0/Z, [x10, %x[input_col_1_stride], LSL #2]\n" + "fmad z12.s, p1/M, z16.s, z1.s\n" + "ld1w { z18.s }, p0/Z, [x10, x26, LSL #2]\n" + "fmul z13.s, z14.s, z2.s[1]\n" + "ld1w { z19.s }, p0/Z, [x10, x24, LSL #2]\n" + "fadd z17.s, z7.s, z27.s\n" + "ld1w { z9.s }, p0/Z, [x10, x22, LSL #2]\n" + "fsub z27.s, z27.s, z7.s\n" + "fmls z28.s, z11.s, z2.s[3]\n" + "ld1w { z21.s }, p0/Z, [x10, x20, LSL #2]\n" + "incb x10\n" + "fmls z12.s, z24.s, z2.s[3]\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z22.s, z2.s[1]\n" + "fsub z30.s, z15.s, z13.s\n" + "fadd z4.s, z13.s, z15.s\n" + "fmsb z24.s, p1/M, z16.s, z1.s\n" + "fsub z15.s, z9.s, z18.s\n" + "fmul z1.s, z14.s, z2.s[2]\n" + "fmad z14.s, p1/M, z16.s, z20.s\n" + "fmad z29.s, p1/M, z16.s, z9.s\n" + 
"fmul z13.s, z10.s, z2.s[1]\n" + "fneg z1.s, p1/M, z1.s\n" + "fadd z1.s, z1.s, z22.s\n" + "fmls z14.s, z22.s, z2.s[3]\n" + "fmls z29.s, z18.s, z2.s[3]\n" + "fadd z5.s, z1.s, z24.s\n" + "fsub z24.s, z24.s, z1.s\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z19.s, z2.s[1]\n" + "fsub z23.s, z15.s, z13.s\n" + "fadd z11.s, z13.s, z15.s\n" + "fmsb z18.s, p1/M, z16.s, z9.s\n" + "fmul z9.s, z10.s, z2.s[2]\n" + "fmad z10.s, p1/M, z16.s, z21.s\n" + "fmad z31.s, p1/M, z16.s, z29.s\n" + "fmad z8.s, p1/M, z16.s, z11.s\n" + "fneg z9.s, p1/M, z9.s\n" + "fadd z9.s, z9.s, z19.s\n" + "fmls z10.s, z19.s, z2.s[3]\n" + "fmls z31.s, z12.s, z2.s[3]\n" + "st1w { z31.s }, p0, [%x[output_row_0]]\n" + "fadd z26.s, z9.s, z18.s\n" + "fsub z18.s, z18.s, z9.s\n" + "fmls z8.s, z4.s, z2.s[3]\n" + "fmad z25.s, p1/M, z16.s, z23.s\n" + "fmad z28.s, p1/M, z16.s, z10.s\n" + "fmad z17.s, p1/M, z16.s, z26.s\n" + "fmad z27.s, p1/M, z16.s, z18.s\n" + "fmls z25.s, z30.s, z2.s[3]\n" + "fmls z28.s, z14.s, z2.s[3]\n" + "fmls z17.s, z5.s, z2.s[3]\n" + "st1w { z17.s }, p0, [%x[output_row_0], %x[output_col_1_stride], LSL #2]\n" + "fmls z27.s, z24.s, z2.s[3]\n" + "st1w { z27.s }, p0, [%x[output_row_0], x25, LSL #2]\n" + "st1w { z8.s }, p0, [%x[output_row_0], x23, LSL #2]\n" + "st1w { z25.s }, p0, [%x[output_row_0], x21, LSL #2]\n" + "st1w { z28.s }, p0, [%x[output_row_0], x19, LSL #2]\n" + "incb %x[output_row_0]\n" + "ld1w { z19.s }, p0/Z, [x16]\n" + "ld1w { z7.s }, p0/Z, [x16, %x[input_col_1_stride], LSL #2]\n" + "fmul z13.s, z7.s, z2.s[1]\n" + "ld1w { z6.s }, p0/Z, [x16, x26, LSL #2]\n" + "ld1w { z27.s }, p0/Z, [x16, x24, LSL #2]\n" + "fneg z13.s, p1/M, z13.s\n" + "ld1w { z25.s }, p0/Z, [x16, x22, LSL #2]\n" + "fsub z15.s, z25.s, z6.s\n" + "fmad z19.s, p1/M, z16.s, z25.s\n" + "ld1w { z20.s }, p0/Z, [x16, x20, LSL #2]\n" + "fmla z13.s, z27.s, z2.s[1]\n" + "ld1w { z0.s }, p0/Z, [x12]\n" + "incb x16\n" + "fmls z19.s, z6.s, z2.s[3]\n" + "ld1w { z31.s }, p0/Z, [x12, %x[input_col_1_stride], LSL #2]\n" + "fsub z8.s, z15.s, z13.s\n" + "fadd z28.s, z13.s, z15.s\n" + "ld1w { z1.s }, p0/Z, [x12, x26, LSL #2]\n" + "fmsb z6.s, p1/M, z16.s, z25.s\n" + "ld1w { z21.s }, p0/Z, [x12, x24, LSL #2]\n" + "fmul z25.s, z7.s, z2.s[2]\n" + "ld1w { z22.s }, p0/Z, [x12, x22, LSL #2]\n" + "fsub z15.s, z22.s, z1.s\n" + "fneg z25.s, p1/M, z25.s\n" + "ld1w { z17.s }, p0/Z, [x12, x20, LSL #2]\n" + "fadd z25.s, z25.s, z27.s\n" + "incb x12\n" + "fmad z7.s, p1/M, z16.s, z20.s\n" + "fmad z0.s, p1/M, z16.s, z22.s\n" + "fmul z13.s, z31.s, z2.s[1]\n" + "fadd z3.s, z25.s, z6.s\n" + "fsub z6.s, z6.s, z25.s\n" + "fmls z7.s, z27.s, z2.s[3]\n" + "fmls z0.s, z1.s, z2.s[3]\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z21.s, z2.s[1]\n" + "fsub z9.s, z15.s, z13.s\n" + "fadd z27.s, z13.s, z15.s\n" + "fmsb z1.s, p1/M, z16.s, z22.s\n" + "fsub z15.s, z29.s, z12.s\n" + "fmul z22.s, z31.s, z2.s[2]\n" + "fmad z31.s, p1/M, z16.s, z17.s\n" + "fmul z13.s, z19.s, z2.s[1]\n" + "fmsb z12.s, p1/M, z16.s, z29.s\n" + "fneg z22.s, p1/M, z22.s\n" + "fadd z22.s, z22.s, z21.s\n" + "fmls z31.s, z21.s, z2.s[3]\n" + "fneg z13.s, p1/M, z13.s\n" + "fadd z25.s, z22.s, z1.s\n" + "fsub z1.s, z1.s, z22.s\n" + "fmla z13.s, z0.s, z2.s[1]\n" + "fmul z29.s, z19.s, z2.s[2]\n" + "fadd z22.s, z13.s, z15.s\n" + "st1w { z22.s }, p0, [x11]\n" + "fneg z29.s, p1/M, z29.s\n" + "fsub z22.s, z15.s, z13.s\n" + "fadd z29.s, z29.s, z0.s\n" + "st1w { z22.s }, p0, [x9]\n" + "fadd z22.s, z29.s, z12.s\n" + "fsub z15.s, z26.s, z5.s\n" + "fmul z13.s, z3.s, z2.s[1]\n" + "fsub z12.s, z12.s, z29.s\n" + "fmsb z5.s, p1/M, z16.s, 
z26.s\n" + "fmul z26.s, z3.s, z2.s[2]\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z25.s, z2.s[1]\n" + "fneg z26.s, p1/M, z26.s\n" + "fadd z26.s, z26.s, z25.s\n" + "fadd z21.s, z13.s, z15.s\n" + "st1w { z21.s }, p0, [x11, %x[output_col_1_stride], LSL #2]\n" + "fsub z21.s, z15.s, z13.s\n" + "fmul z13.s, z6.s, z2.s[1]\n" + "fneg z13.s, p1/M, z13.s\n" + "st1w { z21.s }, p0, [x9, %x[output_col_1_stride], LSL #2]\n" + "fadd z21.s, z26.s, z5.s\n" + "fsub z15.s, z18.s, z24.s\n" + "fmla z13.s, z1.s, z2.s[1]\n" + "fsub z5.s, z5.s, z26.s\n" + "fmsb z24.s, p1/M, z16.s, z18.s\n" + "fmul z18.s, z6.s, z2.s[2]\n" + "fadd z20.s, z13.s, z15.s\n" + "st1w { z20.s }, p0, [x11, x25, LSL #2]\n" + "fneg z18.s, p1/M, z18.s\n" + "fsub z20.s, z15.s, z13.s\n" + "fadd z18.s, z18.s, z1.s\n" + "st1w { z20.s }, p0, [x9, x25, LSL #2]\n" + "fadd z20.s, z18.s, z24.s\n" + "fsub z15.s, z11.s, z4.s\n" + "fmul z13.s, z28.s, z2.s[1]\n" + "fsub z24.s, z24.s, z18.s\n" + "fmsb z4.s, p1/M, z16.s, z11.s\n" + "fmul z11.s, z28.s, z2.s[2]\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z27.s, z2.s[1]\n" + "fneg z11.s, p1/M, z11.s\n" + "fadd z11.s, z11.s, z27.s\n" + "fadd z26.s, z13.s, z15.s\n" + "st1w { z26.s }, p0, [x11, x23, LSL #2]\n" + "fsub z26.s, z15.s, z13.s\n" + "fmul z13.s, z8.s, z2.s[1]\n" + "fneg z13.s, p1/M, z13.s\n" + "st1w { z26.s }, p0, [x9, x23, LSL #2]\n" + "fadd z26.s, z11.s, z4.s\n" + "fsub z15.s, z23.s, z30.s\n" + "fmla z13.s, z9.s, z2.s[1]\n" + "fsub z4.s, z4.s, z11.s\n" + "fmsb z30.s, p1/M, z16.s, z23.s\n" + "fmul z23.s, z8.s, z2.s[2]\n" + "fadd z18.s, z13.s, z15.s\n" + "st1w { z18.s }, p0, [x11, x21, LSL #2]\n" + "fneg z23.s, p1/M, z23.s\n" + "fsub z18.s, z15.s, z13.s\n" + "fadd z23.s, z23.s, z9.s\n" + "st1w { z18.s }, p0, [x9, x21, LSL #2]\n" + "fadd z18.s, z23.s, z30.s\n" + "fsub z15.s, z10.s, z14.s\n" + "fmul z13.s, z7.s, z2.s[1]\n" + "fsub z30.s, z30.s, z23.s\n" + "fmsb z14.s, p1/M, z16.s, z10.s\n" + "fmul z10.s, z7.s, z2.s[2]\n" + "fneg z13.s, p1/M, z13.s\n" + "fmla z13.s, z31.s, z2.s[1]\n" + "fneg z10.s, p1/M, z10.s\n" + "fadd z10.s, z10.s, z31.s\n" + "fadd z17.s, z13.s, z15.s\n" + "st1w { z17.s }, p0, [x11, x19, LSL #2]\n" + "fsub z17.s, z15.s, z13.s\n" + "incb x11\n" + "st1w { z17.s }, p0, [x9, x19, LSL #2]\n" + "fadd z17.s, z10.s, z14.s\n" + "fsub z14.s, z14.s, z10.s\n" + "st1w { z22.s }, p0, [x15]\n" + "incb x9\n" + "st1w { z12.s }, p0, [x13]\n" + "st1w { z21.s }, p0, [x15, %x[output_col_1_stride], LSL #2]\n" + "st1w { z5.s }, p0, [x13, %x[output_col_1_stride], LSL #2]\n" + "st1w { z20.s }, p0, [x15, x25, LSL #2]\n" + "st1w { z24.s }, p0, [x13, x25, LSL #2]\n" + "st1w { z26.s }, p0, [x15, x23, LSL #2]\n" + "st1w { z4.s }, p0, [x13, x23, LSL #2]\n" + "st1w { z18.s }, p0, [x15, x21, LSL #2]\n" + "st1w { z30.s }, p0, [x13, x21, LSL #2]\n" + "st1w { z17.s }, p0, [x15, x19, LSL #2]\n" + "incb x15\n" + "st1w { z14.s }, p0, [x13, x19, LSL #2]\n" + "incb x13\n" + "ld1w { z23.s }, p0/Z, [x28]\n" + "ld1w { z22.s }, p0/Z, [x28, %x[input_col_1_stride], LSL #2]\n" + "fmul z13.s, z22.s, z2.s[1]\n" + "ld1w { z21.s }, p0/Z, [x28, x26, LSL #2]\n" + "ld1w { z20.s }, p0/Z, [x28, x24, LSL #2]\n" + "fneg z13.s, p1/M, z13.s\n" + "ld1w { z26.s }, p0/Z, [x28, x22, LSL #2]\n" + "fsub z15.s, z26.s, z21.s\n" + "fmad z23.s, p1/M, z16.s, z26.s\n" + "ld1w { z18.s }, p0/Z, [x28, x20, LSL #2]\n" + "fmla z13.s, z20.s, z2.s[1]\n" + "incb x28\n" + "fmls z23.s, z21.s, z2.s[3]\n" + "fsub z17.s, z15.s, z13.s\n" + "fadd z30.s, z13.s, z15.s\n" + "fmsb z21.s, p1/M, z16.s, z26.s\n" + "fmul z26.s, z22.s, z2.s[2]\n" + "fmad z22.s, p1/M, z16.s, 
z18.s\n" + "fmad z19.s, p1/M, z16.s, z23.s\n" + "fmad z28.s, p1/M, z16.s, z30.s\n" + "fneg z26.s, p1/M, z26.s\n" + "fadd z26.s, z26.s, z20.s\n" + "fmls z22.s, z20.s, z2.s[3]\n" + "fmls z19.s, z0.s, z2.s[3]\n" + "st1w { z19.s }, p0, [x27]\n" + "fadd z23.s, z26.s, z21.s\n" + "fsub z21.s, z21.s, z26.s\n" + "fmls z28.s, z27.s, z2.s[3]\n" + "fmad z8.s, p1/M, z16.s, z17.s\n" + "fmad z7.s, p1/M, z16.s, z22.s\n" + "fmad z3.s, p1/M, z16.s, z23.s\n" + "fmad z6.s, p1/M, z16.s, z21.s\n" + "fmls z8.s, z9.s, z2.s[3]\n" + "fmls z7.s, z31.s, z2.s[3]\n" + "fmls z3.s, z25.s, z2.s[3]\n" + "st1w { z3.s }, p0, [x27, %x[output_col_1_stride], LSL #2]\n" + "fmls z6.s, z1.s, z2.s[3]\n" + "st1w { z6.s }, p0, [x27, x25, LSL #2]\n" + "st1w { z28.s }, p0, [x27, x23, LSL #2]\n" + "st1w { z8.s }, p0, [x27, x21, LSL #2]\n" + "st1w { z7.s }, p0, [x27, x19, LSL #2]\n" + "incb x27\n" + "whilelt p0.s, XZR, %x[num_channels]\n" + "bne 1b\n" + "2:" // channel_loop_end + + : [input_row_0] "+&r" (input), [num_channels] "+&r" (long_channels), [output_row_0] "+&r" (output) + : [B_values] "r" (B_values), [input_col_1_stride] "r" ((long) input_col_stride), [input_row_stride] "r" ((long) input_row_stride), [output_col_1_stride] "r" ((long) output_col_stride), [output_row_stride] "r" (6 * (long) output_col_stride) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x14", "x15", "x16", "x19", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); +} + +} // namespace input_transform +} // namespace winograd +} // namespace arm_conv + +#endif // __aarch64__ && defined(ARM_COMPUTE_ENABLE_SVE) diff --git a/src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp new file mode 100644 index 0000000000..35d61fa94d --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/input_transforms_fp16.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+
+#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+
+#include "input_transform.hpp"
+#include "winograd_implementations.hpp"
+
+#include <memory>
+#include <string>
+
+namespace arm_conv {
+namespace winograd {
+namespace input_transform {
+
+void a64_fp16_6x6(unsigned int, const __fp16 *, size_t, size_t, __fp16 *, size_t);
+
+#define IMPL(HEIGHT, WIDTH, FUNC, DRIVER) new Transform ## DRIVER <__fp16, __fp16>(#FUNC, HEIGHT, WIDTH, FUNC)
+
+static const TransformImplementation<__fp16> transforms_fp16[] = {
+  { IMPL(6, 6, a64_fp16_6x6, Unpadded) },
+  { nullptr },
+};
+
+template <>
+const TransformImplementation<__fp16> *implementation_list(void)
+{
+  return transforms_fp16;
+}
+
+} // namespace input_transform
+} // namespace winograd
+} // namespace arm_conv
+
+#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp b/src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp
new file mode 100644
index 0000000000..ec4e954f71
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/input_transforms_fp32.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "input_transform.hpp"
+#include "winograd_implementations.hpp"
+
+#include <memory>
+#include <string>
+
+namespace arm_conv {
+namespace winograd {
+namespace input_transform {
+
+#if defined(__aarch64__)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
+void sve_fp32_6x6(unsigned int, const float *, size_t, size_t, float *, size_t);
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
+void a64_fp32_6x6(unsigned int, const float *, size_t, size_t, float *, size_t);
+#else // defined(__aarch64__)
+void arm_fp32_6x6(unsigned int, const float *, size_t, size_t, float *, size_t);
+#endif // defined(__aarch64__)
+void arm_fp32_4x4(unsigned int, const float *, size_t, size_t, float *, size_t);
+void arm_fp32_1x8(unsigned int, const float *, size_t, size_t, float *, size_t);
+
+#define IMPL(HEIGHT, WIDTH, FUNC, DRIVER) new Transform ## DRIVER <float>(#FUNC, HEIGHT, WIDTH, FUNC)
+
+static const TransformImplementation<float> transforms_fp32[] = {
+#if defined(__aarch64__)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
+  { IMPL(6, 6, sve_fp32_6x6, Unpadded), MethodConstraints::RequiresSVE },
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
+  { IMPL(6, 6, a64_fp32_6x6, Unpadded) },
+#else // defined(__aarch64__)
+  { IMPL(6, 6, arm_fp32_6x6, Unpadded) },
+#endif // defined(__aarch64__)
+  { IMPL(4, 4, arm_fp32_4x4, Unpadded) },
+  { IMPL(1, 8, arm_fp32_1x8, Unpadded) },
+  { new TransformUnpadded<float>("arm_fp32_1x8", 8, 1, TransformUnpadded<float>::get_transposed_kernel(arm_fp32_1x8)) },
+  { nullptr },
+};
+
+template <>
+const TransformImplementation<float> *implementation_list(void)
+{
+  return transforms_fp32;
+}
+
+} // namespace input_transform
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/output_transform.hpp b/src/core/NEON/kernels/convolution/winograd/output_transform.hpp
new file mode 100644
index 0000000000..5148495608
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/output_transform.hpp
@@ -0,0 +1,302 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/assembly/winograd.hpp"
+
+#include "src/core/NEON/kernels/arm_conv/addressing.hpp"
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <limits>
+
+namespace arm_conv {
+namespace winograd {
+namespace output_transform {
+
+/* Driver class for the Winograd output transforms.
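+ *
+ * Conceptually, each underlying kernel computes Y = A^T X A for one
+ * Winograd-domain tile X, adds the bias and clamps the result to the
+ * activation bounds before it is written back.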
+ *
+ * This provides a base implementation which handles iteration over the output
+ * tensor; subclasses are responsible for managing working space and executing
+ * the transform on individual tiles.
+ */
+template <typename TIn, typename TOut=TIn>
+class TransformBase : public ITransform
+{
+  const std::string m_name;
+  const unsigned int m_output_rows, m_output_cols;
+  const unsigned int m_kernel_rows, m_kernel_cols;
+
+  protected:
+  virtual size_t get_working_space_per_thread(const ConvolutionArgs &) const
+  {
+    return 0;
+  }
+
+  virtual void initialise_thread_working_space(const ConvolutionArgs &, void *) const
+  {
+    // Nothing to do
+  }
+
+  virtual void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_matrix,
+    const TIn *bias,
+    TOut *outptr, size_t ld_out_row, size_t ld_out_col,
+    TOut activation_min, TOut activation_max,
+    unsigned int valid_rows, unsigned int valid_cols,
+    void *working_space
+  ) const = 0;
+
+  void execute_internal(
+    const ConvolutionArgs &args,
+    const TIn *inptr, size_t ld_in_batch, size_t ld_in_matrix, size_t ld_in_row,
+    const TIn *bias,
+    TOut *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
+    void *working_space, unsigned int thread_id, unsigned int n_threads
+  ) const
+  {
+    // Get the working space for this thread, and initialise it.
+    working_space = reinterpret_cast<char *>(working_space) +
+                    this->get_working_space_per_thread(args) * thread_id;
+    this->initialise_thread_working_space(args, working_space);
+
+    // Get the activation values
+    auto activation_min = static_cast<TOut>(-std::numeric_limits<float>::infinity());
+    auto activation_max = static_cast<TOut>(+std::numeric_limits<float>::infinity());
+    switch (args.activation.type)
+    {
+      case arm_gemm::Activation::Type::BoundedReLU:
+        activation_max = static_cast<TOut>(args.activation.param1);
+        // Fall through
+      case arm_gemm::Activation::Type::ReLU:
+        activation_min = static_cast<TOut>(0);
+        break;
+      default:
+        break;
+    }
+
+    // Determine the number of tiles in a row; we use this to get the right
+    // offset into the input data.
+    const auto n_tile_cols = (args.output_shape.cols + this->get_output_cols() - 1) / this->get_output_cols();
+
+    // Execute over all batches
+    for (unsigned int batch = 0; batch < args.n_batches; batch++)
+    {
+      auto inptr_row = inptr + thread_id*n_tile_cols*ld_in_row;
+      auto outptr_row = outptr + thread_id*ld_out_row*this->get_output_rows();
+      inptr += ld_in_batch;
+      outptr += ld_out_batch;
+
+      // Stripe rows of tiles over threads.
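+      // For example, with four threads and a 4x4 output tile, thread 0 owns
+      // the rows of tiles starting at output rows 0, 16, 32, ..., thread 1
+      // those starting at rows 4, 20, 36, ...: each thread takes every
+      // n_threads-th row of tiles.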
+      for (auto out_i = thread_id * this->get_output_rows();
+           out_i < args.output_shape.rows;
+           out_i += n_threads * this->get_output_rows())
+      {
+        auto inptr_tile = inptr_row;
+        auto outptr_tile = outptr_row;
+        inptr_row += n_threads * n_tile_cols * ld_in_row;
+        outptr_row += n_threads * this->get_output_rows() * ld_out_row;
+
+        // Iterate over all columns
+        for (auto out_j = 0u; out_j < args.output_shape.cols;
+             out_j += this->get_output_cols())
+        {
+          // Execute the tile
+          this->execute_tile(
+            args.n_output_channels,
+            inptr_tile, ld_in_matrix,
+            bias,
+            outptr_tile, ld_out_row, ld_out_col,
+            activation_min, activation_max,
+            args.output_shape.rows - out_i,  // Number of valid rows remaining
+            args.output_shape.cols - out_j,  // Number of valid columns remaining
+            working_space
+          );
+
+          // Progress the pointers
+          inptr_tile += ld_in_row;
+          outptr_tile += this->get_output_cols() * ld_out_col;
+        }
+      }
+    }
+  }
+
+  public:
+  TransformBase(const std::string &name,
+                unsigned int output_rows, unsigned int output_cols,
+                unsigned int kernel_rows, unsigned int kernel_cols)
+  : m_name(name),
+    m_output_rows(output_rows), m_output_cols(output_cols),
+    m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols)
+  {
+  }
+
+  const std::string &get_name(void) const override { return m_name; }
+
+  unsigned int get_input_rows(void) const override final { return m_kernel_rows + m_output_rows - 1; }
+  unsigned int get_input_cols(void) const override final { return m_kernel_cols + m_output_cols - 1; }
+
+  unsigned int get_output_rows(void) const override final { return m_output_rows; }
+  unsigned int get_output_cols(void) const override final { return m_output_cols; }
+
+  unsigned int get_kernel_rows(void) const override final { return m_kernel_rows; }
+  unsigned int get_kernel_cols(void) const override final { return m_kernel_cols; }
+
+  size_t get_working_space_size(const ConvolutionArgs &args, unsigned int n_threads) const override
+  {
+    return n_threads * this->get_working_space_per_thread(args);
+  }
+
+  void execute(
+    const ConvolutionArgs &args,
+    const void *inptr, size_t ld_in_batch, size_t ld_in_matrix, size_t ld_in_row,
+    const void *bias,
+    void *outptr, size_t ld_out_batch, size_t ld_out_row, size_t ld_out_col,
+    void *working_space, unsigned int thread_id, unsigned int n_threads
+  ) const override
+  {
+    execute_internal(
+      args,
+      reinterpret_cast<const TIn *>(inptr), ld_in_batch, ld_in_matrix, ld_in_row,
+      reinterpret_cast<const TIn *>(bias),
+      reinterpret_cast<TOut *>(outptr), ld_out_batch, ld_out_row, ld_out_col,
+      working_space, thread_id, n_threads
+    );
+  }
+};
+
+template <typename TIn, typename TOut=TIn>
+class TransformUnpadded : public TransformBase<TIn, TOut>
+{
+  using Kernel = std::function<void(
+    unsigned int, const TIn *, size_t, const TIn *,
+    TOut *, size_t, size_t, TOut, TOut)>;
+  const Kernel m_kernel;
+
+  protected:
+  size_t get_working_space_per_thread(const ConvolutionArgs &args) const override
+  {
+    // We create a buffer the size of the output tile
+    const auto n_output_points = this->get_output_rows() * this->get_output_cols();
+    return sizeof(TOut) * n_output_points * args.n_output_channels;
+  }
+
+  void execute_tile(
+    unsigned int n_channels,
+    const TIn *inptr, size_t ld_in_matrix,
+    const TIn *bias,
+    TOut *outptr, size_t ld_out_row, size_t ld_out_col,
+    TOut activation_min, TOut activation_max,
+    unsigned int valid_rows, unsigned int valid_cols,
+    void *working_space
+  ) const override final
+  {
+    // Get copies of the output tensor parameters
+    auto kernel_outptr = outptr;
+    auto kernel_ld_out_row = ld_out_row, kernel_ld_out_col = ld_out_col;
+
+    // If there's padding on either the left or the right, then we execute the
+    // kernel into the output buffer and then perform a copy.
+    if (valid_rows < this->get_output_rows() ||
+        valid_cols < this->get_output_cols())
+    {
+      // Override the kernel output parameters
+      kernel_outptr = reinterpret_cast<TOut *>(working_space);
+      kernel_ld_out_col = n_channels;
+      kernel_ld_out_row = kernel_ld_out_col * this->get_output_cols();
+    }
+
+    // Execute the kernel
+    m_kernel(
+      n_channels,
+      inptr, ld_in_matrix,
+      bias,
+      kernel_outptr, kernel_ld_out_row, kernel_ld_out_col,
+      activation_min, activation_max
+    );
+
+    // If necessary, copy from the working space into the destination tensor.
+    if (valid_rows < this->get_output_rows() ||
+        valid_cols < this->get_output_cols())
+    {
+      const auto last_row = std::min(valid_rows, this->get_output_rows());
+      const auto last_col = std::min(valid_cols, this->get_output_cols());
+
+      for (auto i = 0u; i < last_row; i++)
+      {
+        auto patch_tile = kernel_outptr;
+        auto out_tile = outptr;
+        kernel_outptr += kernel_ld_out_row;
+        outptr += ld_out_row;
+
+        for (auto j = 0u; j < last_col; j++)
+        {
+          memcpy(out_tile, patch_tile, sizeof(TOut) * n_channels);
+          patch_tile += kernel_ld_out_col;
+          out_tile += ld_out_col;
+        }
+      }
+    }
+  }
+
+  public:
+  TransformUnpadded(const std::string &name,
+                    unsigned int output_rows, unsigned int output_cols,
+                    unsigned int kernel_rows, unsigned int kernel_cols,
+                    const Kernel kernel)
+  : TransformBase<TIn, TOut>(name, output_rows, output_cols, kernel_rows, kernel_cols),
+    m_kernel(kernel)
+  {
+  }
+
+  /* Utility method to get a transposed variant of a kernel; this transposed
+   * version simply calls the original kernel with the output row and column
+   * strides swapped.
+   */
+  static constexpr Kernel get_transposed_kernel(const Kernel &kernel)
+  {
+    return [kernel] (
+      const unsigned int n_channels,
+      const TIn *const inptr, const size_t ld_in_matrix,
+      const TIn *const bias,
+      TOut *const outptr, const size_t ld_out_row, const size_t ld_out_col,
+      const TOut activation_min, const TOut activation_max
+    ) {
+      kernel(n_channels, inptr, ld_in_matrix, bias,
+             outptr, ld_out_col, ld_out_row,
+             activation_min, activation_max);
+    };
+  }
+};
+
+} // namespace output_transform
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/a64_fp16_4x4_3x3.cpp
similarity index 94%
rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp
rename to src/core/NEON/kernels/convolution/winograd/output_transforms/a64_fp16_4x4_3x3.cpp
index 3c071bdac6..8a2837a125 100644
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp16_fp16_integers.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/a64_fp16_4x4_3x3.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020 Arm Limited.
+ * Copyright (c) 2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -22,25 +22,29 @@
  * SOFTWARE.
*/ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#include "arm.hpp" -#include "output.hpp" -namespace winograd -{ +#include +#include +#include + +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void winograd::OutputTransform<3, 3, 6, 6, __fp16, __fp16, winograd::WinogradRoots::Integers>::transform_tile( - const int n_channels, +void a64_fp16_4x4_3x3( + unsigned int n_channels, const __fp16* inptr, - const int matrix_stride, + const size_t matrix_stride, const __fp16* bptr, __fp16* const output, - const int output_row_stride, - const int output_col_stride, + const size_t output_row_stride, + const size_t output_col_stride, const __fp16 output_min, const __fp16 output_max ) { + constexpr int output_tile_rows = 4, output_tile_cols = 4; + // Construct a map to the output cells __fp16 *outptrs[output_tile_rows][output_tile_cols]; for (int i = 0; i < output_tile_rows; i++) @@ -249,7 +253,8 @@ void winograd::OutputTransform<3, 3, 6, 6, __fp16, __fp16, winograd::WinogradRoo } } -template class OutputTransform<3, 3, 6, 6, __fp16, __fp16, winograd::WinogradRoots::Integers>; +} // namespace output_transform +} // namespace winograd +} // namespace arm_conv -} // namespace winograd #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x2_1x7.cpp similarity index 70% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x2_1x7.cpp index 8e257909a3..1fb1189aa5 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2_7_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x2_1x7.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,42 +22,36 @@ * SOFTWARE. 
*/ -#include "arm.hpp" -#include "output.hpp" +#include +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_1x2_1x7( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int, // No need to stride across rows - const int output_col_stride, + float *outptr, + size_t, // No need to stride across rows + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_cols]; - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[j] = output + j*output_col_stride; - } + constexpr auto inner_tile_cols = 8u, output_tile_cols = 2u; // For each channel of the output - int channels_remaining = n_channels; -#ifdef __arm_any__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[inner_tile_cols], f[output_tile_cols], b = vdupq_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1q_f32(inptr + j*matrix_stride); } @@ -72,21 +66,21 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1q_f32(bptr); bptr += 4; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)), vdupq_n_f32(output_max)); - vst1q_f32(outptrs[j], y); - outptrs[j] += 4; + vst1q_f32(outptr + j*output_col_stride, y); } + outptr += 4; } - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[inner_tile_cols], f[output_tile_cols], b = vdup_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1_f32(inptr + j*matrix_stride); } @@ -101,26 +95,24 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1_f32(bptr); bptr += 2; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)), vdup_n_f32(output_max)); - vst1_f32(outptrs[j], y); - outptrs[j] += 2; + vst1_f32(outptr + j*output_col_stride, y); } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + if (n_channels) { // Matrices used and computed during this transform float F[inner_tile_cols], f[output_tile_cols], b = 0.0f; // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = *(inptr + j*matrix_stride); } - inptr++; f[0] = F[0]*1 + F[1]*1 + F[2]*1 + F[3]*1 + F[4]*1 + F[5]*1 + F[6]*1; f[1] = F[1]*-1 + F[5]*-3 + F[3]*-2 + F[4]*2 + F[6]*3 + F[2]*1 + F[7]*1; @@ -130,14 +122,13 @@ void OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::transfo { b = *(bptr++); } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { - *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min); + *(outptr + 
j*output_col_stride) = std::max(std::min(f[j] + b, output_max), output_min); } } } -template class OutputTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>; -template class OutputTransform<7, 1, 8, 1, float, float, WinogradRoots::Integers>; - +} // namespace output_transform } // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x4_1x5.cpp similarity index 75% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x4_1x5.cpp index c35037e143..40fef1188b 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4_5_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x4_1x5.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,42 +22,36 @@ * SOFTWARE. */ -#include "output.hpp" -#include "arm.hpp" +#include +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_1x4_1x5( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int, // No need to stride across rows - const int output_col_stride, + float *outptr, + size_t, // No need to stride across rows + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_cols]; - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[j] = output + j*output_col_stride; - } + constexpr auto inner_tile_cols = 8u, output_tile_cols = 4u; // For each channel of the output - int channels_remaining = n_channels; -#ifdef __arm_any__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[inner_tile_cols], f[output_tile_cols], b = vdupq_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1q_f32(inptr + j*matrix_stride); } @@ -74,22 +68,22 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1q_f32(bptr); bptr += 4; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmaxq_f32(vminq_f32(vaddq_f32(f[j], b), vdupq_n_f32(output_max)), vdupq_n_f32(output_min)); - vst1q_f32(outptrs[j], y); - outptrs[j] += 4; + vst1q_f32(outptr + j*output_col_stride, y); } + outptr += 4; } - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[inner_tile_cols], f[output_tile_cols], b = vdup_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1_f32(inptr + j*matrix_stride); } @@ -106,23 +100,22 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1_f32(bptr); 
bptr += 2; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmax_f32(vmin_f32(vadd_f32(f[j], b), vdup_n_f32(output_max)), vdup_n_f32(output_min)); - vst1_f32(outptrs[j], y); - outptrs[j] += 2; + vst1_f32(outptr + j*output_col_stride, y); } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + for (; n_channels; n_channels--) { // Matrices used and computed during this transform float F[inner_tile_cols], f[output_tile_cols], b = 0.0f; // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = *(inptr + j*matrix_stride); } @@ -138,15 +131,15 @@ void OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::transfo { b = *(bptr++); } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = std::max(std::min(f[j] + b, output_max), output_min); - *(outptrs[j]++) = y; + *(outptr + j*output_col_stride) = y; } + outptr++; } } -template class OutputTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>; -template class OutputTransform<5, 1, 8, 1, float, float, WinogradRoots::Integers>; - +} // namespace output_transform } // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x6_1x3.cpp similarity index 77% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x6_1x3.cpp index 528cd8c691..8203b579cb 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_6_3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_1x6_1x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,42 +22,37 @@ * SOFTWARE. 
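[Note on the math these kernels implement] Every renamed floating-point output transform in this patch follows the same scheme: read an inner tile F from the Winograd domain at matrix_stride intervals, apply the output transform Z on both sides (the code's "f = ZT F Z" comments), then add the bias and clamp. As a sketch, with the Z coefficients varying per tile and kernel size, and the 1xN variants applying Z on one side only:

\[
  f = Z^{\mathsf T} F Z, \qquad
  y_{ij} = \max\!\left(\min\!\left(f_{ij} + b,\ \mathrm{output\_max}\right),\ \mathrm{output\_min}\right)
\]

For the 2x2-output, 3x3-kernel transform later in this patch (arm_fp32_2x2_3x3.cpp) the rows of Z^T can be read directly off its scalar tail:

\[
  Z^{\mathsf T} = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix}
\]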
*/ -#include "output.hpp" -#include "arm.hpp" +#include +#include -namespace winograd -{ +#include + +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_1x6_1x3( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int, // No need to stride across rows - const int output_col_stride, + float *outptr, + size_t, // No need to stride across rows + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_cols]; - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[j] = output + j*output_col_stride; - } + constexpr unsigned int inner_tile_cols = 8, output_tile_cols = 6; // For each channel of the output - int channels_remaining = n_channels; -#ifdef __arm_any__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[inner_tile_cols], f[output_tile_cols], b = vdupq_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1q_f32(inptr + j*matrix_stride); } @@ -76,21 +71,21 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1q_f32(bptr); bptr += 4; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vminq_f32(vmaxq_f32(f[j] + b, vdupq_n_f32(output_min)), vdupq_n_f32(output_max)); - vst1q_f32(outptrs[j], y); - outptrs[j] += 4; + vst1q_f32(outptr + j*output_col_stride, y); } + outptr += 4; } - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[inner_tile_cols], f[output_tile_cols], b = vdup_n_f32(0.0f); // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = vld1_f32(inptr + j*matrix_stride); } @@ -109,22 +104,21 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo b = vld1_f32(bptr); bptr += 2; } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmin_f32(vmax_f32(f[j] + b, vdup_n_f32(output_min)), vdup_n_f32(output_max)); - vst1_f32(outptrs[j], y); - outptrs[j] += 2; + vst1_f32(outptr + j*output_col_stride, y); } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + for (; n_channels; n_channels--) { // Matrices used and computed during this transform float F[inner_tile_cols], f[output_tile_cols], b = 0.0f; // Read a 1x8 tile in the Winograd domain - for (int j = 0; j < inner_tile_cols; j++) + for (auto j = 0u; j < inner_tile_cols; j++) { F[j] = *(inptr + j*matrix_stride); } @@ -142,14 +136,14 @@ void OutputTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::transfo { b = *(bptr++); } - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { - *(outptrs[j]++) = std::max(std::min(f[j] + b, output_max), output_min); + *(outptr + j*output_col_stride) = std::max(std::min(f[j] + b, output_max), output_min); } + outptr++; } } -template class OutputTransform<1, 3, 1, 8, 
float, float, WinogradRoots::Integers>; -template class OutputTransform<3, 1, 8, 1, float, float, WinogradRoots::Integers>; - -} // namespace +} // namespace output_transform +} // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_3x3.cpp similarity index 70% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_3x3.cpp index 8b0b4707f9..c13a826b4c 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_3x3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_3x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,47 +22,38 @@ * SOFTWARE. */ -#include "arm.hpp" -#include "output.hpp" +#include +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_2x2_3x3( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int output_row_stride, - const int output_col_stride, + float *outptr, + const size_t output_row_stride, + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_rows][output_tile_cols]; - for (int i = 0; i < output_tile_rows; i++) - { - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; - } - } + constexpr auto output_tile_rows = 2u, output_tile_cols = 2u; // For each channel of the output - int channels_remaining = n_channels; -#ifdef __aarch64__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[4][4], FZ[4][2], f[2][2], b; // Read a 4x4 tile in the Winograd domain - for (int i = 0, m = 0; i < 4; i++) + for (auto i = 0u, m = 0u; i < 4; i++) { - for (int j = 0; j < 4; j++, m++) + for (auto j = 0u; j < 4; j++, m++) { F[i][j] = vld1q_f32(inptr + m*matrix_stride); } @@ -70,7 +61,7 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo inptr += 4; // Compute the matrix F Z - for (int i = 0; i < 4; i++) + for (auto i = 0u; i < 4; i++) { // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; FZ[i][0] = vaddq_f32(vaddq_f32(F[i][0], F[i][1]), F[i][2]); @@ -80,7 +71,7 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; f[0][j] = vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); @@ -101,29 +92,27 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo } // Write out the output tile - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = 
vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), vdupq_n_f32(output_min)); - vst1q_f32(outptrs[i][j], y); - outptrs[i][j] += 4; + vst1q_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 4; } -#endif // __aarch64__ -#ifdef __arm_any__ - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[4][4], FZ[4][2], f[2][2], b; // Read a 4x4 tile in the Winograd domain - for (int i = 0, m = 0; i < 4; i++) + for (auto i = 0u, m = 0u; i < 4; i++) { - for (int j = 0; j < 4; j++, m++) + for (auto j = 0u; j < 4; j++, m++) { F[i][j] = vld1_f32(inptr + m*matrix_stride); } @@ -131,7 +120,7 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo inptr += 2; // Compute the matrix F Z - for (int i = 0; i < 4; i++) + for (auto i = 0u; i < 4; i++) { // FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; FZ[i][0] = vadd_f32(vadd_f32(F[i][0], F[i][1]), F[i][2]); @@ -141,7 +130,7 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { // f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; f[0][j] = vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), FZ[2][j]); @@ -162,28 +151,27 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo } // Write out the output tile - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), vdup_n_f32(output_min)); - vst1_f32(outptrs[i][j], y); - outptrs[i][j] += 2; + vst1_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + for (; n_channels; n_channels--) { // Matrices used and computed during this transform float F[4][4], FZ[4][2], f[2][2], b; // Read a 4x4 tile in the Winograd domain - for (int i = 0, m = 0; i < 4; i++) + for (auto i = 0u, m = 0u; i < 4; i++) { - for (int j = 0; j < 4; j++, m++) + for (auto j = 0u; j < 4; j++, m++) { F[i][j] = *(inptr + m*matrix_stride); } @@ -191,14 +179,14 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo inptr++; // Compute the matrix F Z - for (int i = 0; i < 4; i++) + for (auto i = 0u; i < 4; i++) { FZ[i][0] = F[i][0] + F[i][1] + F[i][2]; FZ[i][1] = F[i][1] - F[i][2] - F[i][3]; } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { f[0][j] = FZ[0][j] + FZ[1][j] + FZ[2][j]; f[1][j] = FZ[1][j] - FZ[2][j] - FZ[3][j]; @@ -215,17 +203,18 @@ void OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::transfo } // Write out the output tile - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); - *(outptrs[i][j]++) = y; + *(outptr + i*output_row_stride + j*output_col_stride) = y; } } + outptr++; } } -template class OutputTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>; - -} // namespace +} // namespace output_transform +} // namespace winograd +} // namespace arm_conv diff --git 
a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_5x5.cpp similarity index 73% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_5x5.cpp index 3996be1c52..256d049032 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_2x2_5x5_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_2x2_5x5.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,47 +22,38 @@ * SOFTWARE. */ -#include "output.hpp" -#include "arm.hpp" +#include +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_2x2_5x5( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int output_row_stride, - const int output_col_stride, + float *outptr, + const size_t output_row_stride, + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_rows][output_tile_cols]; - for (int i = 0; i < output_tile_rows; i++) - { - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; - } - } + constexpr auto output_tile_rows = 2u, output_tile_cols = 2u; // For each channel of the output - int channels_remaining = n_channels; -#ifdef __aarch64__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[6][6], FZ[6][2], f[2][2], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = vld1q_f32(inptr + m*matrix_stride); } @@ -70,7 +61,7 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo inptr += 4; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]); @@ -80,7 +71,7 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); @@ -99,29 +90,27 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { b = vdupq_n_f32(0.0f); } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), vdupq_n_f32(output_min)); - vst1q_f32(outptrs[i][j], 
y); - outptrs[i][j] += 4; + vst1q_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 4; } -#endif // __aarch64__ -#ifdef __arm_any__ - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[6][6], FZ[6][2], f[2][2], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = vld1_f32(inptr + m*matrix_stride); } @@ -129,7 +118,7 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo inptr += 2; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]); @@ -139,7 +128,7 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); @@ -158,43 +147,41 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { b = vdup_n_f32(0.0f); } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), vdup_n_f32(output_min)); - vst1_f32(outptrs[i][j], y); - outptrs[i][j] += 2; + vst1_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + if (n_channels) { // Matrices used and computed during this transform float F[6][6], FZ[6][2], f[2][2], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = *(inptr + m*matrix_stride); } } - inptr++; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4] + 1*F[i][5]; } // Compute the output tile f = ZT F Z - for (int j = 0; j < 2; j++) + for (auto j = 0u; j < 2; j++) { f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j] + 1*FZ[5][j]; @@ -209,17 +196,17 @@ void OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::transfo { b = 0.0f; } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); - *(outptrs[i][j]++) = y; + *(outptr + i*output_row_stride + j*output_col_stride) = y; } } } } -template class OutputTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>; - -} // namespace +} // namespace output_transform +} // namespace winograd +} // namespace arm_conv diff --git 
a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_4x4_3x3.cpp similarity index 78% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_4x4_3x3.cpp index 1eb9b537d2..c35da54eb6 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output_4x4_3x3_fp32_fp32_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms/arm_fp32_4x4_3x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2022 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,48 +22,38 @@ * SOFTWARE. */ -#include "arm.hpp" -#include "output.hpp" +#include +#include +#include -namespace winograd -{ +namespace arm_conv { +namespace winograd { +namespace output_transform { -template <> -void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>::transform_tile( - const int n_channels, +void arm_fp32_4x4_3x3( + unsigned int n_channels, const float* inptr, - const int matrix_stride, + const size_t matrix_stride, const float* bptr, - float* const output, - const int output_row_stride, - const int output_col_stride, + float *outptr, + const size_t output_row_stride, + const size_t output_col_stride, const float output_min, const float output_max ) { - // Construct a map to the output cells - float *outptrs[output_tile_rows][output_tile_cols]; - for (int i = 0; i < output_tile_rows; i++) - { - for (int j = 0; j < output_tile_cols; j++) - { - outptrs[i][j] = output + i*output_row_stride + j*output_col_stride; - } - } + constexpr auto output_tile_rows = 4u, output_tile_cols = 4u; // For each channel of the output - int channels_remaining = n_channels; - -#ifdef __aarch64__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed during this transform float32x4_t F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = vld1q_f32(inptr + m*matrix_stride); } @@ -71,7 +61,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots inptr += 4; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][0] = vaddq_f32(vaddq_f32(vaddq_f32(F[i][0], F[i][1]), vaddq_f32(F[i][2], F[i][3])), F[i][4]); @@ -87,7 +77,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots } // Compute the output tile f = ZT F Z - for (int j = 0; j < 4; j++) + for (auto j = 0u; j < 4; j++) { // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[0][j] = vaddq_f32(vaddq_f32(vaddq_f32(FZ[0][j], FZ[1][j]), vaddq_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); @@ -112,29 +102,27 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots { b = vdupq_n_f32(0.0f); } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmaxq_f32(vminq_f32(vaddq_f32(f[i][j], b), vdupq_n_f32(output_max)), vdupq_n_f32(output_min)); - 
vst1q_f32(outptrs[i][j], y); - outptrs[i][j] += 4; + vst1q_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 4; } -#endif // __aarch64__ -#ifdef __arm_any__ - for (; channels_remaining >= 2; channels_remaining -= 2) + for (; n_channels >= 2; n_channels -= 2) { // Matrices used and computed during this transform float32x2_t F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = vld1_f32(inptr + m*matrix_stride); } @@ -142,7 +130,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots inptr += 2; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { // FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][0] = vadd_f32(vadd_f32(vadd_f32(F[i][0], F[i][1]), vadd_f32(F[i][2], F[i][3])), F[i][4]); @@ -158,7 +146,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots } // Compute the output tile f = ZT F Z - for (int j = 0; j < 4; j++) + for (auto j = 0u; j < 4; j++) { // f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[0][j] = vadd_f32(vadd_f32(vadd_f32(FZ[0][j], FZ[1][j]), vadd_f32(FZ[2][j], FZ[3][j])), FZ[4][j]); @@ -183,28 +171,27 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots { b = vdup_n_f32(0.0f); } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = vmax_f32(vmin_f32(vadd_f32(f[i][j], b), vdup_n_f32(output_max)), vdup_n_f32(output_min)); - vst1_f32(outptrs[i][j], y); - outptrs[i][j] += 2; + vst1_f32(outptr + i*output_row_stride + j*output_col_stride, y); } } + outptr += 2; } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + for (; n_channels; n_channels--) { // Matrices used and computed during this transform float F[6][6], FZ[6][4], f[4][4], b; // Read a 6x6 tile in the Winograd domain - for (int i = 0, m = 0; i < 6; i++) + for (auto i = 0u, m = 0u; i < 6; i++) { - for (int j = 0; j < 6; j++, m++) + for (auto j = 0u; j < 6; j++, m++) { F[i][j] = *(inptr + m*matrix_stride); } @@ -212,7 +199,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots inptr++; // Compute the matrix F Z - for (int i = 0; i < 6; i++) + for (auto i = 0u; i < 6; i++) { FZ[i][0] = 1*F[i][0] + 1*F[i][1] + 1*F[i][2] + 1*F[i][3] + 1*F[i][4]; FZ[i][1] = 1*F[i][1] + -1*F[i][2] + 2*F[i][3] + -2*F[i][4]; @@ -221,7 +208,7 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots } // Compute the output tile f = ZT F Z - for (int j = 0; j < 4; j++) + for (auto j = 0u; j < 4; j++) { f[0][j] = 1*FZ[0][j] + 1*FZ[1][j] + 1*FZ[2][j] + 1*FZ[3][j] + 1*FZ[4][j]; f[1][j] = 1*FZ[1][j] + -1*FZ[2][j] + 2*FZ[3][j] + -2*FZ[4][j]; @@ -238,17 +225,18 @@ void winograd::OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots { b = 0.0f; } - for (int i = 0; i < output_tile_rows; i++) + for (auto i = 0u; i < output_tile_rows; i++) { - for (int j = 0; j < output_tile_cols; j++) + for (auto j = 0u; j < output_tile_cols; j++) { const auto y = std::max(std::min(f[i][j] + b, output_max), output_min); - *(outptrs[i][j]++) = y; + *(outptr + i*output_row_stride + j*output_col_stride) = y; } } + outptr++; } } -template class 
OutputTransform<3, 3, 6, 6, float, float, winograd::WinogradRoots::Integers>; - +} // namespace output_transform } // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp new file mode 100644 index 0000000000..c39b1dc083 --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms_fp16.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + +#include "output_transform.hpp" +#include "winograd_implementations.hpp" + +namespace arm_conv { +namespace winograd { +namespace output_transform { + +void a64_fp16_4x4_3x3(unsigned int, const __fp16 *, size_t, const __fp16 *, __fp16 *, size_t, size_t, __fp16, __fp16); + +#define IMPL(OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, FUNC, DRIVER) \ + new Transform ## DRIVER <__fp16, __fp16>(#FUNC, OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, FUNC) + + +static const TransformImplementation<__fp16> transforms_fp16[] = { + { IMPL(4, 4, 3, 3, a64_fp16_4x4_3x3, Unpadded) }, + { nullptr } +}; + +template <> +const TransformImplementation<__fp16> *implementation_list(void) +{ + return transforms_fp16; +} + +} // namespace output_transform +} // namespace winograd +} // namespace arm_conv + +#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) \ No newline at end of file diff --git a/src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp b/src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp new file mode 100644 index 0000000000..73abe8b945 --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/output_transforms_fp32.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "output_transform.hpp"
+#include "winograd_implementations.hpp"
+
+namespace arm_conv {
+namespace winograd {
+namespace output_transform {
+
+void arm_fp32_4x4_3x3(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+void arm_fp32_2x2_3x3(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+void arm_fp32_2x2_5x5(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+void arm_fp32_1x6_1x3(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+void arm_fp32_1x4_1x5(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+void arm_fp32_1x2_1x7(unsigned int, const float *, size_t, const float *, float *, size_t, size_t, float, float);
+
+#define IMPL(OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, FUNC, DRIVER) \
+  new Transform ## DRIVER <float, float>(#FUNC, OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, FUNC)
+
+#define IMPL_T(OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, FUNC, DRIVER) \
+  new Transform ## DRIVER <float, float>(#FUNC, OUT_HEIGHT, OUT_WIDTH, KERN_HEIGHT, KERN_WIDTH, Transform ## DRIVER <float, float>::get_transposed_kernel(FUNC))
+
+static const TransformImplementation<float> transforms_fp32[] = {
+#if defined(__aarch64__)
+#endif // defined(__aarch64__)
+  { IMPL(4, 4, 3, 3, arm_fp32_4x4_3x3, Unpadded), MethodConstraints::LargerShape },
+  { IMPL(2, 2, 3, 3, arm_fp32_2x2_3x3, Unpadded) },
+  { IMPL(2, 2, 5, 5, arm_fp32_2x2_5x5, Unpadded) },
+  { IMPL(1, 6, 1, 3, arm_fp32_1x6_1x3, Unpadded) },
+  { IMPL_T(6, 1, 3, 1, arm_fp32_1x6_1x3, Unpadded) },
+  { IMPL(1, 4, 1, 5, arm_fp32_1x4_1x5, Unpadded) },
+  { IMPL_T(4, 1, 5, 1, arm_fp32_1x4_1x5, Unpadded) },
+  { IMPL(1, 2, 1, 7, arm_fp32_1x2_1x7, Unpadded) },
+  { IMPL_T(2, 1, 7, 1, arm_fp32_1x2_1x7, Unpadded) },
+  { nullptr }
+};
+
+template <>
+const TransformImplementation<float> *implementation_list(void)
+{
+  return transforms_fp32;
+}
+
+} // namespace output_transform
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/padding.cpp b/src/core/NEON/kernels/convolution/winograd/padding.cpp
index 1d44c384d9..aca8448658 100644
--- a/src/core/NEON/kernels/convolution/winograd/padding.cpp
+++ b/src/core/NEON/kernels/convolution/winograd/padding.cpp
@@ -28,23 +28,22 @@
 namespace padding
 {
-
 template <typename T>
 void copy_and_pad_tile(
-  const unsigned int tile_rows,
-  const unsigned int tile_cols,
-  const unsigned int n_channels,
-  const T* const inptr,
-  const unsigned int in_row_stride,
-  const unsigned int in_col_stride,
-  T* const outptr,
-  const unsigned int out_row_stride,
-  const unsigned int out_col_stride,
-  const unsigned int pad_top,
-  const unsigned int pad_left,
-  const unsigned int pad_bottom,
-  const unsigned int pad_right,
-  const T pad_value
+  unsigned int tile_rows,
+  unsigned int tile_cols,
+  unsigned int n_channels,
+  const T *inptr,
+  unsigned int in_row_stride,
+  unsigned int in_col_stride,
+  T* outptr,
+  unsigned int out_row_stride,
+  unsigned int out_col_stride,
+  unsigned int pad_top,
+  unsigned int pad_left,
+  unsigned int pad_bottom,
+  unsigned int pad_right,
+  T pad_value
 )
 {
   for (unsigned int out_i = 0; out_i < tile_rows; out_i++)
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transform.hpp b/src/core/NEON/kernels/convolution/winograd/weight_transform.hpp
new file mode 100644
index 0000000000..db0f53df1b
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transform.hpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/assembly/winograd.hpp"
+#include <algorithm>
+#include <functional>
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+/* Driver class for the Winograd weight transforms.
+ */
+template <typename TIn, typename TOut>
+class Transform : public ITransform
+{
+  using Kernel = std::function<void(unsigned int, const TIn *, size_t, size_t, TOut *, size_t)>;
+
+  const std::string m_name;
+  const unsigned int m_kernel_rows, m_kernel_cols;
+  const unsigned int m_transformed_tile_rows, m_transformed_tile_cols;
+  const Kernel m_kernel;
+
+  void execute_internal(
+    const ConvolutionArgs &args,
+    const TIn *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
+    TOut *outptr, size_t ld_out_matrix, size_t ld_out_row,
+    unsigned int thread_id, unsigned int n_threads
+  ) const
+  {
+    // Stripe groups of input channels over threads; this should reduce false
+    // sharing of the output matrix.
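+    // Each thread starts at its own block of channels and then skips over the
+    // blocks owned by the other threads: with two threads, for example, thread 0
+    // transforms channels [0, 16), [32, 48), ... while thread 1 transforms
+    // channels [16, 32), [48, 64), ...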
+    constexpr auto n_input_channels_per_thread = 16u;
+
+    // Get the initial offset for the input and output pointers
+    const auto offset = thread_id * n_input_channels_per_thread;
+    inptr += offset * ld_input_channel;
+    outptr += offset * ld_out_row;
+
+    for (auto start_ic = thread_id * n_input_channels_per_thread;
+         start_ic < args.n_input_channels;
+         start_ic += n_threads * n_input_channels_per_thread)
+    {
+      // Now iterate over the input channels assigned to this thread.
+      const auto end_ic = std::min(args.n_input_channels,
+                                   start_ic + n_input_channels_per_thread);
+      for (auto ic = start_ic; ic < end_ic; ic++)
+      {
+        m_kernel(args.n_output_channels, inptr, ld_in_row, ld_in_col,
+                 outptr, ld_out_matrix);
+        inptr += ld_input_channel;
+        outptr += ld_out_row;
+      }
+
+      // Progress the pointers to account for the work not performed by
+      // this thread.
+      const auto skip = (n_threads - 1) * n_input_channels_per_thread;
+      inptr += skip * ld_input_channel;
+      outptr += skip * ld_out_row;
+    }
+  }
+
+  public:
+  Transform(
+    const std::string &name,
+    unsigned int kernel_rows, unsigned int kernel_cols,
+    unsigned int transformed_tile_rows, unsigned int transformed_tile_cols,
+    const Kernel kernel
+  )
+  : m_name(name),
+    m_kernel_rows(kernel_rows), m_kernel_cols(kernel_cols),
+    m_transformed_tile_rows(transformed_tile_rows), m_transformed_tile_cols(transformed_tile_cols),
+    m_kernel(kernel)
+  {
+  }
+
+  const std::string &get_name(void) const override { return m_name; }
+
+  unsigned int get_kernel_rows(void) const override { return m_kernel_rows; }
+  unsigned int get_kernel_cols(void) const override { return m_kernel_cols; }
+
+  unsigned int get_transformed_tile_rows(void) const override { return m_transformed_tile_rows; }
+  unsigned int get_transformed_tile_cols(void) const override { return m_transformed_tile_cols; }
+
+  void execute(
+    const ConvolutionArgs &args,
+    const void *inptr, size_t ld_in_row, size_t ld_in_col, size_t ld_input_channel,
+    void *outptr, size_t ld_out_matrix, size_t ld_out_row,
+    unsigned int thread_id, unsigned int n_threads
+  ) const override
+  {
+    execute_internal(
+      args,
+      reinterpret_cast<const TIn *>(inptr), ld_in_row, ld_in_col, ld_input_channel,
+      reinterpret_cast<TOut *>(outptr), ld_out_matrix, ld_out_row,
+      thread_id, n_threads
+    );
+  }
+
+  /* Utility method to get a transposed variant of a kernel; this transposed
+   * version simply calls the original kernel with the input row and column
+   * strides swapped.
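+   * For example, a kernel written for a 1xN tile can be reused for the
+   * transposed Nx1 tile; the fp32 transform lists in this patch build their
+   * Nx1 variants this way via their IMPL_T macros.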
+ */ + static constexpr Kernel get_transposed_kernel(const Kernel &kernel) + { + return [kernel] ( + const unsigned int n_channels, + const TIn *const inptr, const size_t ld_in_row, const size_t ld_in_col, + TOut *const outptr, const size_t ld_out + ) { + kernel(n_channels, inptr, ld_in_col, ld_in_row, outptr, ld_out); + }; + } +}; + +} // namespace weight_transform +} // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/a64_fp16_4x4_3x3.cpp similarity index 67% rename from src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp rename to src/core/NEON/kernels/convolution/winograd/weight_transforms/a64_fp16_4x4_3x3.cpp index 3101865027..0d9a65890e 100644 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp16_fp16_integers.cpp +++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/a64_fp16_4x4_3x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,45 +21,26 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -#include "arm.hpp" -#include "kernel.hpp" - -namespace winograd -{ - -template <> -void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execute( - const int n_output_channels, - const int n_input_channels, - const __fp16* const input, // NOTE: Data in HWIO order - __fp16* const output, - const int matrix_stride, - const int matrix_row_stride +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + +#include +#include + +namespace arm_conv { +namespace winograd { +namespace weight_transform { + +void a64_fp16_4x4_3x3( + unsigned int n_channels, + const __fp16* inptr, // NOTE: Data in HWIO order + const size_t ld_weight_row, + const size_t ld_weight_col, + __fp16* outptr, + const size_t matrix_stride ) { - // Get pointers to each cell of the weight tensor - const auto weight_col_stride = n_input_channels * n_output_channels; - const auto weight_row_stride = 3 * weight_col_stride; - const __fp16 *inptrs[3][3]; - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; - } - } - - // For each input channel - for (int ic = 0; ic < n_input_channels; ic++) - { - __fp16 *outptr = output + ic * matrix_row_stride; - - // For each output channel - int channels_remaining = n_output_channels; #ifdef __aarch64__ - for (; channels_remaining >= 8; channels_remaining -= 8) + for (; n_channels >= 8; n_channels -= 8) { // Matrices used and computed in this kernel float16x8_t w[3][3], Ww[6][3], V[6][6]; @@ -69,8 +50,7 @@ void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execu { for (int j = 0; j < 3; j++) { - w[i][j] = vld1q_f16(inptrs[i][j]); - inptrs[i][j] += 8; + w[i][j] = vld1q_f16(inptr + i*ld_weight_row + j*ld_weight_col); } } @@ -128,11 +108,12 @@ void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execu vst1q_f16(outptr + m*matrix_stride, V[i][j]); } } + inptr += 8; outptr += 8; } #endif // __aarch64__ #ifdef __arm_any__ - for (; channels_remaining >= 4; channels_remaining -= 4) + for (; n_channels >= 4; n_channels -= 4) { // Matrices used and computed in this kernel float16x4_t w[3][3], Ww[6][3], 
V[6][6]; @@ -142,8 +123,7 @@ void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execu { for (int j = 0; j < 3; j++) { - w[i][j] = vld1_f16(inptrs[i][j]); - inptrs[i][j] += 4; + w[i][j] = vld1_f16(inptr + i*ld_weight_row + j*ld_weight_col); } } @@ -201,59 +181,62 @@ void WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>::execu vst1_f16(outptr + m*matrix_stride, V[i][j]); } } + inptr += 4; outptr += 4; } #endif // __arm_any__ - for (; channels_remaining; channels_remaining--) + for (; n_channels; n_channels--) + { + // Matrices used and computed in this kernel + __fp16 w[3][3], Ww[6][3], V[6][6]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) { - // Matrices used and computed in this kernel - __fp16 w[3][3], Ww[6][3], V[6][6]; - - // Read weights - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - w[i][j] = *(inptrs[i][j]++); - } - } - - // Compute the matrix W w - for (int j = 0; j < 3; j++) - { - Ww[0][j] = 6*w[0][j]; - Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; - Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; - Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; - Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; - Ww[5][j] = 24*w[2][j]; - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - V[i][0] = ( 6*Ww[i][0]) / 576.0; - V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0; - V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0; - V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0; - V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0; - V[i][5] = (24*Ww[i][2]) / 576.0; - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - *(outptr + m*matrix_stride) = V[i][j]; - } - } - outptr++; + w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col); } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = 6*w[0][j]; + Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; + Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; + Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; + Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; + Ww[5][j] = 24*w[2][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + V[i][0] = ( 6*Ww[i][0]) / 576.0; + V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0; + V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0; + V[i][5] = (24*Ww[i][2]) / 576.0; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + + inptr++; + outptr++; } } -template class WeightTransform<3, 3, 6, 6, __fp16, __fp16, WinogradRoots::Integers>; +} // namespace weight_transform +} // namespace winograd +} // namespace arm_conv -} // namespace -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp new file mode 100644 index 0000000000..e55bcb632f --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_3x3.cpp @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2022 ARM Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include + +namespace arm_conv { +namespace winograd { +namespace weight_transform { + +void arm_fp32_2x2_3x3( + unsigned int n_channels, + const float *inptr, size_t ld_weight_row, size_t ld_weight_col, + float *outptr, size_t matrix_stride +) +{ + constexpr auto inner_tile_i = 4u; + constexpr auto inner_tile_j = 4u; + +#ifdef __aarch64__ + // For each output channel + for (; n_channels >= 4u; n_channels -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (auto i = 0u; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (auto i = 0u, m = 0u; i < inner_tile_i; i++) + { + for (auto j = 0u; j < inner_tile_j; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, V[i][j]); + } + } + + inptr += 4; + outptr += 4; + } +#endif // __aarch64__ + for (; n_channels >= 2u; n_channels -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + + // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f); + + // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), 
w[2][j]), 0.5f); + + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (auto i = 0u; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + + // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f); + + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (auto i = 0u, m = 0u; i < inner_tile_i; i++) + { + for (auto j = 0u; j < inner_tile_j; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + + inptr += 2; + outptr += 2; + } + for (; n_channels; n_channels--) + { + // Matrices used and computed in this kernel + float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j]; + + // Read weights + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col); + } + } + + // Compute the matrix W w + for (int j = 0; j < 3; j++) + { + Ww[0][j] = w[0][j]; + Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]); + Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]); + Ww[3][j] = w[2][j]; + } + + // Compute V = W w WT + for (auto i = 0u; i < inner_tile_i; i++) + { + V[i][0] = Ww[i][0]; + V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]); + V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]); + V[i][3] = Ww[i][2]; + } + + // Store the transformed weights + for (auto i = 0u, m = 0u; i < inner_tile_i; i++) + { + for (auto j = 0u; j < inner_tile_j; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + + inptr++; + outptr++; + } +} + +} // namespace weight_transform +} // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp new file mode 100644 index 0000000000..9cdf15a4af --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_2x2_5x5.cpp @@ -0,0 +1,381 @@ +/* + * Copyright (c) 2022 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
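[Note] The weight transform above (arm_fp32_2x2_3x3.cpp) computes V = G w G^T one channel at a time; reading the coefficients off its scalar tail gives

\[
  G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & \tfrac{1}{2} \\ \tfrac{1}{2} & -\tfrac{1}{2} & \tfrac{1}{2} \\ 0 & 0 & 1 \end{pmatrix}
\]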
+ */ + +#include +#include + +namespace arm_conv { +namespace winograd { +namespace weight_transform { + +void arm_fp32_2x2_5x5( + unsigned int n_channels, + const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col, + float *outptr, const size_t matrix_stride +) +{ +#ifdef __aarch64__ + // For each output channel + for (; n_channels >= 4; n_channels -= 4) + { + // Matrices used and computed in this kernel + float32x4_t w[5][5], Ww[6][5], V[6][6]; + + // Read weights + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col); + } + } + + // Compute the matrix W w + for (int j = 0; j < 5; j++) + { + // Ww[0][j] = w[0][j]/4.0f; + Ww[0][j] = vmulq_n_f32(w[0][j], 1.0f/4.0f); + + // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; + Ww[1][j] = vmulq_n_f32( + vaddq_f32( + vaddq_f32( + vaddq_f32(w[1][j], w[0][j]), + vaddq_f32(w[3][j], w[2][j]) + ), + w[4][j] + ), + -1.0f/6.0f + ); + + // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; + // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f; + Ww[2][j] = vmulq_n_f32( + vsubq_f32( + vaddq_f32( + vsubq_f32(w[1][j], w[0][j]), + vsubq_f32(w[3][j], w[2][j]) + ), + w[4][j] + ), + 1.0f/6.0f + ); + + // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; + Ww[3][j] = vmulq_n_f32( + vmlaq_n_f32( + vaddq_f32( + vaddq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)), + vaddq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) + ), + w[4][j], 2.0f + ), + 1.0f/3.0f + ); + + // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; + Ww[4][j] = vmulq_n_f32( + vmlaq_n_f32( + vaddq_f32( + vsubq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)), + vsubq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) + ), + w[4][j], 2.0f + ), + 1.0f/3.0f + ); + + // Ww[5][j] = w[4][j]; + Ww[5][j] = w[4][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + // V[i][0] = Ww[i][0]/4.0f; + V[i][0] = vmulq_n_f32(Ww[i][0], 1.0f/4.0f); + + // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; + V[i][1] = vmulq_n_f32( + vaddq_f32( + vaddq_f32( + vaddq_f32(Ww[i][1], Ww[i][0]), + vaddq_f32(Ww[i][3], Ww[i][2]) + ), + Ww[i][4] + ), + -1.0f/6.0f + ); + + // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; + // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f; + V[i][2] = vmulq_n_f32( + vsubq_f32( + vaddq_f32( + vsubq_f32(Ww[i][1], Ww[i][0]), + vsubq_f32(Ww[i][3], Ww[i][2]) + ), + Ww[i][4] + ), + 1.0f/6.0f + ); + + // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][3] = vmulq_n_f32( + vmlaq_n_f32( + vaddq_f32( + vaddq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)), + vaddq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) + ), + Ww[i][4], 2.0f + ), + 1.0f/3.0f + ); + + // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][4] = vmulq_n_f32( + vmlaq_n_f32( + vaddq_f32( + vsubq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)), + vsubq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) + ), + Ww[i][4], 2.0f + ), + 1.0f/3.0f + ); + + // V[i][5] = Ww[i][4]; + V[i][5] = Ww[i][4]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1q_f32(outptr + m*matrix_stride, 
V[i][j]); + } + } + + inptr += 4; + outptr += 4; + } +#endif // __aarch64__ + for (; n_channels >= 2; n_channels -= 2) + { + // Matrices used and computed in this kernel + float32x2_t w[5][5], Ww[6][5], V[6][6]; + + // Read weights + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col); + } + } + + // Compute the matrix W w + for (int j = 0; j < 5; j++) + { + // Ww[0][j] = w[0][j]/4.0f; + Ww[0][j] = vmul_n_f32(w[0][j], 1.0f/4.0f); + + // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; + Ww[1][j] = vmul_n_f32( + vadd_f32( + vadd_f32( + vadd_f32(w[1][j], w[0][j]), + vadd_f32(w[3][j], w[2][j]) + ), + w[4][j] + ), + -1.0f/6.0f + ); + + // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; + // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f; + Ww[2][j] = vmul_n_f32( + vsub_f32( + vadd_f32( + vsub_f32(w[1][j], w[0][j]), + vsub_f32(w[3][j], w[2][j]) + ), + w[4][j] + ), + 1.0f/6.0f + ); + + // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; + Ww[3][j] = vmul_n_f32( + vmla_n_f32( + vadd_f32( + vadd_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)), + vadd_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) + ), + w[4][j], 2.0f + ), + 1.0f/3.0f + ); + + // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; + Ww[4][j] = vmul_n_f32( + vmla_n_f32( + vadd_f32( + vsub_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)), + vsub_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) + ), + w[4][j], 2.0f + ), + 1.0f/3.0f + ); + + // Ww[5][j] = w[4][j]; + Ww[5][j] = w[4][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + // V[i][0] = Ww[i][0]/4.0f; + V[i][0] = vmul_n_f32(Ww[i][0], 1.0f/4.0f); + + // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; + V[i][1] = vmul_n_f32( + vadd_f32( + vadd_f32( + vadd_f32(Ww[i][1], Ww[i][0]), + vadd_f32(Ww[i][3], Ww[i][2]) + ), + Ww[i][4] + ), + -1.0f/6.0f + ); + + // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; + // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f; + V[i][2] = vmul_n_f32( + vsub_f32( + vadd_f32( + vsub_f32(Ww[i][1], Ww[i][0]), + vsub_f32(Ww[i][3], Ww[i][2]) + ), + Ww[i][4] + ), + 1.0f/6.0f + ); + + // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][3] = vmul_n_f32( + vmla_n_f32( + vadd_f32( + vadd_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)), + vadd_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) + ), + Ww[i][4], 2.0f + ), + 1.0f/3.0f + ); + + // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][4] = vmul_n_f32( + vmla_n_f32( + vadd_f32( + vsub_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)), + vsub_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) + ), + Ww[i][4], 2.0f + ), + 1.0f/3.0f + ); + + // V[i][5] = Ww[i][4]; + V[i][5] = Ww[i][4]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + vst1_f32(outptr + m*matrix_stride, V[i][j]); + } + } + + inptr += 2; + outptr += 2; + } + for (; n_channels; n_channels--) + { + // Matrices used and computed in this kernel + float w[5][5], Ww[6][5], V[6][6]; + + // Read weights + for (int i = 0; i < 5; i++) + { + for (int j = 0; j < 5; j++) + { + w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col); 
+ } + } + + // Compute the matrix W w + for (int j = 0; j < 5; j++) + { + Ww[0][j] = w[0][j]/4.0f; + Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; + Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; + Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; + Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; + Ww[5][j] = w[4][j]; + } + + // Compute V = W w WT + for (int i = 0; i < 6; i++) + { + V[i][0] = Ww[i][0]/4.0f; + V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; + V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; + V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; + V[i][5] = Ww[i][4]; + } + + // Store the transformed weights + for (int i = 0, m = 0; i < 6; i++) + { + for (int j = 0; j < 6; j++, m++) + { + *(outptr + m*matrix_stride) = V[i][j]; + } + } + + inptr++; + outptr++; + } +} + +} // namespace weight_transform +} // namespace winograd +} // namespace arm_conv diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp new file mode 100644 index 0000000000..53cfa3d1d4 --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/arm_fp32_4x4_3x3.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2022 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+
+#include <arm_neon.h>
+#include <cstddef>
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+void arm_fp32_4x4_3x3(
+  unsigned int n_channels,
+  const float *inptr, const size_t ld_weight_row, const size_t ld_weight_col,
+  float *outptr, const size_t matrix_stride
+)
+{
+#ifdef __aarch64__
+  for (; n_channels >= 4; n_channels -= 4)
+  {
+    // Matrices used and computed in this kernel
+    float32x4_t w[3][3], Ww[6][3], V[6][6];
+
+    // Read weights
+    for (int i = 0; i < 3; i++)
+    {
+      for (int j = 0; j < 3; j++)
+      {
+        w[i][j] = vld1q_f32(inptr + i*ld_weight_row + j*ld_weight_col);
+      }
+    }
+
+    // Compute the matrix W w
+    for (int j = 0; j < 3; j++)
+    {
+      // Ww[0][j] = 6*w[0][j];
+      Ww[0][j] = vmulq_n_f32(w[0][j], 6.0);
+
+      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+      Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
+      Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
+      Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
+      Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+      // Ww[5][j] = 24*w[2][j];
+      Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f);
+    }
+
+    // Compute V = W w WT
+    for (int i = 0; i < 6; i++)
+    {
+      const float recip576 = 1.0f / 576.0f;
+
+      // V[i][0] = 6*Ww[i][0];
+      V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576);
+
+      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+      V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
+      V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
+      V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
+      V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+      // V[i][5] = 24*Ww[i][2];
+      V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576);
+    }
+
+    // Store the transformed weights
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        vst1q_f32(outptr + m*matrix_stride, V[i][j]);
+      }
+    }
+
+    inptr += 4;
+    outptr += 4;
+  }
+#endif // __aarch64__
+  for (; n_channels >= 2; n_channels -= 2)
+  {
+    // Matrices used and computed in this kernel
+    float32x2_t w[3][3], Ww[6][3], V[6][6];
+
+    // Read weights
+    for (int i = 0; i < 3; i++)
+    {
+      for (int j = 0; j < 3; j++)
+      {
+        w[i][j] = vld1_f32(inptr + i*ld_weight_row + j*ld_weight_col);
+      }
+    }
+
+    // Compute the matrix W w
+    for (int j = 0; j < 3; j++)
+    {
+      // Ww[0][j] = 6*w[0][j];
+      Ww[0][j] = vmul_n_f32(w[0][j], 6.0);
+
+      // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+      Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0);
+
+      // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
+      Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0);
+
+      // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
+      Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+      // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
+      Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f);
+
+      // Ww[5][j] = 24*w[2][j];
+      Ww[5][j] = vmul_n_f32(w[2][j], 24.0f);
+    }
+    // Compute V = W w WT
+    for (int i = 0; i < 6; i++)
+    {
+      const float recip576 = 1.0f / 576.0f;
+
+      // V[i][0] = 6*Ww[i][0];
+      V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576);
+
+      // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2];
+      V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576);
+
+      // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2];
+      V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576);
+
+      // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2];
+      V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+      // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2];
+      V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576);
+
+      // V[i][5] = 24*Ww[i][2];
+      V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576);
+    }
+
+    // Store the transformed weights
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        vst1_f32(outptr + m*matrix_stride, V[i][j]);
+      }
+    }
+
+    inptr += 2;
+    outptr += 2;
+  }
+  for (; n_channels; n_channels--)
+  {
+    // Matrices used and computed in this kernel
+    float w[3][3], Ww[6][3], V[6][6];
+
+    // Read weights
+    for (int i = 0; i < 3; i++)
+    {
+      for (int j = 0; j < 3; j++)
+      {
+        w[i][j] = *(inptr + i*ld_weight_row + j*ld_weight_col);
+      }
+    }
+
+    // Compute the matrix W w
+    for (int j = 0; j < 3; j++)
+    {
+      Ww[0][j] = 6*w[0][j];
+      Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j];
+      Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j];
+      Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j];
+      Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j];
+      Ww[5][j] = 24*w[2][j];
+    }
+
+    // Compute V = W w WT
+    for (int i = 0; i < 6; i++)
+    {
+      V[i][0] = ( 6*Ww[i][0]) / 576.0;
+      V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+      V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0;
+      V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0;
+      V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0;
+      V[i][5] = (24*Ww[i][2]) / 576.0;
+    }
+
+    // Store the transformed weights
+    for (int i = 0, m = 0; i < 6; i++)
+    {
+      for (int j = 0; j < 6; j++, m++)
+      {
+        *(outptr + m*matrix_stride) = V[i][j];
+      }
+    }
+
+    inptr++;
+    outptr++;
+  }
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
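(Annotation, not part of the patch.) The integer rows used in this kernel are 24 times the G matrix commonly quoted for F(4x4, 3x3), e.g. in Lavin and Gray's fast-convolution paper, and the single multiply by 1/576 = 1/(24 * 24) folds the normalisation of both applications of G back in. A standalone sanity check of that claim, in plain C++:

#include <cassert>
#include <cmath>

int main()
{
  // Canonical F(4x4, 3x3) weight-transform matrix G (fractional form).
  const float G[6][3] = {
    {  1.0f/4,   0.0f,      0.0f   },
    { -1.0f/6,  -1.0f/6,   -1.0f/6 },
    { -1.0f/6,   1.0f/6,   -1.0f/6 },
    {  1.0f/24,  1.0f/12,   1.0f/6 },
    {  1.0f/24, -1.0f/12,   1.0f/6 },
    {  0.0f,     0.0f,      1.0f   },
  };
  // Integer rows as used by the kernel above.
  const float scaled[6][3] = {
    { 6, 0, 0 }, { -4, -4, -4 }, { -4, 4, -4 },
    { 1, 2, 4 }, { 1, -2, 4 }, { 0, 0, 24 },
  };
  // Each integer row divided by 24 recovers the corresponding row of G.
  for (int i = 0; i < 6; i++)
  {
    for (int j = 0; j < 3; j++)
    {
      assert(std::fabs(scaled[i][j] / 24.0f - G[i][j]) < 1e-6f);
    }
  }
  return 0;
}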
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp
new file mode 100644
index 0000000000..834f982f37
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x2_1x7.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2022 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <cstddef>
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+void cpp_fp32_1x2_1x7(
+  unsigned int n_channels,
+  const float* inptr, size_t, size_t ld_weight_col,
+  float *outptr, size_t matrix_stride
+)
+{
+  for (; n_channels; n_channels--)
+  {
+    // Matrices used and computed in this kernel
+    float w[7], V[8];
+
+    // Read weights
+    for (int j = 0; j < 7; j++)
+    {
+      w[j] = *(inptr + j*ld_weight_col);
+    }
+
+    // Compute V = w WT
+    V[0] = (w[0]*-1) / 36.0f;
+    V[1] = (w[1]*-1 + w[3]*-1 + w[5]*-1 + w[0]*1 + w[2]*1 + w[4]*1 + w[6]*1) / 48.0f;
+    V[2] = (w[0]*1 + w[1]*1 + w[2]*1 + w[3]*1 + w[4]*1 + w[5]*1 + w[6]*1) / 48.0f;
+    V[3] = (w[0]*-1 + w[6]*-64 + w[4]*-16 + w[2]*-4 + w[1]*2 + w[3]*8 + w[5]*32) / 120.0f;
+    V[4] = (w[0]*-1 + w[6]*-64 + w[5]*-32 + w[4]*-16 + w[3]*-8 + w[2]*-4 + w[1]*-2) / 120.0f;
+    V[5] = (w[5]*-243 + w[3]*-27 + w[1]*-3 + w[2]*9 + w[4]*81 + w[6]*729 + w[0]*1) / 720.0f;
+    V[6] = (w[1]*3 + w[2]*9 + w[3]*27 + w[4]*81 + w[5]*243 + w[6]*729 + w[0]*1) / 720.0f;
+    V[7] = (w[6]*1) / 1.0f;
+
+    // Store the transformed weights
+    for (int j = 0; j < 8; j++)
+    {
+      *(outptr + j*matrix_stride) = V[j];
+    }
+
+    inptr++;
+    outptr++;
+  }
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
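(Annotation, not part of the patch.) Each row of this transform evaluates the kernel, read as a degree-6 polynomial, at one sample point (0, 1, -1, 2, -2, -3, 3 and the point at infinity), up to a per-row sign and divisor; V[6], for instance, is w(3)/720. A self-contained restatement of that observation:

#include <cassert>
#include <cmath>

// Evaluate the kernel w, read as a degree-6 polynomial, at x.
static float eval_poly(const float w[7], float x)
{
  float acc = 0.0f, xk = 1.0f;
  for (int k = 0; k < 7; k++)
  {
    acc += w[k] * xk;
    xk *= x;
  }
  return acc;
}

int main()
{
  const float w[7] = { 1, 2, 3, 4, 5, 6, 7 };
  // V[6] exactly as computed in the kernel above.
  const float v6 = (w[1]*3 + w[2]*9 + w[3]*27 + w[4]*81 + w[5]*243 + w[6]*729 + w[0]*1) / 720.0f;
  // It agrees with the polynomial evaluated at x = 3, divided by 720.
  assert(std::fabs(v6 - eval_poly(w, 3.0f) / 720.0f) < 1e-4f);
  return 0;
}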
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp
new file mode 100644
index 0000000000..585fb2516b
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x4_1x5.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2022 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <cstddef>
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+void cpp_fp32_1x4_1x5(
+  unsigned int n_channels,
+  const float *inptr,
+  size_t, // ld_weight_row
+  size_t ld_weight_col,
+  float *outptr,
+  size_t matrix_stride
+)
+{
+  constexpr auto kernel_cols = 5u, inner_tile_cols = 8u;
+
+  // For each channel
+  for (; n_channels; n_channels--)
+  {
+    // Matrices used and computed in this kernel
+    float w[kernel_cols], V[inner_tile_cols];
+
+    // Read weights
+    for (auto j = 0u; j < kernel_cols; j++)
+    {
+      w[j] = *(inptr + j * ld_weight_col);
+    }
+
+    // Compute V = w WT
+    V[0] = (w[0]*-1) / 36;
+    V[1] = (w[1]*-1 + w[3]*-1 + w[0]*1 + w[2]*1 + w[4]*1) / 48;
+    V[2] = (w[0]*1 + w[1]*1 + w[2]*1 + w[3]*1 + w[4]*1) / 48;
+    V[3] = (w[0]*-1 + w[4]*-16 + w[2]*-4 + w[1]*2 + w[3]*8) / 120;
+    V[4] = (w[0]*-1 + w[4]*-16 + w[3]*-8 + w[2]*-4 + w[1]*-2) / 120;
+    V[5] = (w[3]*-27 + w[1]*-3 + w[2]*9 + w[4]*81 + w[0]*1) / 720;
+    V[6] = (w[1]*3 + w[2]*9 + w[3]*27 + w[4]*81 + w[0]*1) / 720;
+    V[7] = (w[4]*1) / 1;
+
+    // Store the transformed weights
+    for (auto j = 0u; j < inner_tile_cols; j++)
+    {
+      *(outptr + j*matrix_stride) = V[j];
+    }
+
+    inptr++;
+    outptr++;
+  }
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
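(Annotation, not part of the patch.) A minimal single-channel driver for the function above; the strides here describe a toy layout with contiguous weights and contiguous transformed values, not the layout the library uses internally:

#include <cstddef>
#include <cstdio>

namespace arm_conv { namespace winograd { namespace weight_transform {
// Forward declaration matching the definition above; link against it.
void cpp_fp32_1x4_1x5(unsigned int, const float *, size_t, size_t, float *, size_t);
}}}

int main()
{
  const float weights[5] = { 0.1f, 0.2f, 0.3f, 0.2f, 0.1f };
  float transformed[8];
  // One channel; columns contiguous (ld_weight_col = 1); the eight output
  // matrices one float apart (matrix_stride = 1). ld_weight_row is unused
  // by this 1-D kernel, so 0 is passed.
  arm_conv::winograd::weight_transform::cpp_fp32_1x4_1x5(
      1, weights, 0, 1, transformed, 1);
  for (int j = 0; j < 8; j++)
  {
    std::printf("V[%d] = %f\n", j, transformed[j]);
  }
  return 0;
}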
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp
new file mode 100644
index 0000000000..63754e529c
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms/cpp_fp32_1x6_1x3.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2022 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include <cstddef>
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+void cpp_fp32_1x6_1x3(
+  unsigned int n_channels,
+  const float *inptr, size_t, size_t ld_weight_col,
+  float *outptr, size_t matrix_stride
+)
+{
+  for (; n_channels; n_channels--)
+  {
+    // Matrices used and computed in this kernel
+    float w[3], V[8];
+
+    // Read weights
+    for (int j = 0; j < 3; j++)
+    {
+      w[j] = *(inptr + j * ld_weight_col);
+    }
+
+    // Compute V = w WT
+    V[0] = (w[0]*-1) / 36.0f;
+    V[1] = (w[1]*-1 + w[0]*1 + w[2]*1) / 48.0f;
+    V[2] = (w[0]*1 + w[1]*1 + w[2]*1) / 48.0f;
+    V[3] = (w[0]*-1 + w[2]*-4 + w[1]*2) / 120.0f;
+    V[4] = (w[0]*-1 + w[2]*-4 + w[1]*-2) / 120.0f;
+    V[5] = (w[1]*-3 + w[2]*9 + w[0]*1) / 720.0f;
+    V[6] = (w[1]*3 + w[2]*9 + w[0]*1) / 720.0f;
+    V[7] = (w[2]*1) / 1;
+
+    // Store the transformed weights
+    for (int j = 0; j < 8; j++)
+    {
+      *(outptr + j*matrix_stride) = V[j];
+    }
+
+    inptr++;
+    outptr++;
+  }
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
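(Annotation, not part of the patch.) All three scalar 1-D transforms above (1x6_1x3, 1x4_1x5 and 1x2_1x7) target the same 8-wide inner tile, because the inner tile width is the output width plus the kernel width minus one. A compile-time restatement:

// Inner tile width of a 1-D Winograd transform: output + kernel - 1.
constexpr unsigned inner_tile(unsigned output_cols, unsigned kernel_cols)
{
  return output_cols + kernel_cols - 1;
}
static_assert(inner_tile(6, 3) == 8, "1x6_1x3 maps to an 8-wide tile");
static_assert(inner_tile(4, 5) == 8, "1x4_1x5 maps to an 8-wide tile");
static_assert(inner_tile(2, 7) == 8, "1x2_1x7 maps to an 8-wide tile");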
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp
new file mode 100644
index 0000000000..6c8bbe07cf
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp16.cpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+
+#include "winograd_implementations.hpp"
+#include "weight_transform.hpp"
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+void a64_fp16_4x4_3x3(unsigned int, const __fp16 *, size_t, size_t, __fp16 *, size_t);
+
+#define IMPL(KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, KERN) \
+  new Transform<__fp16>(#KERN, KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, KERN)
+
+static const TransformImplementation<__fp16> transforms_fp16[] = {
+  { IMPL(3, 3, 6, 6, a64_fp16_4x4_3x3) },
+  { nullptr }
+};
+
+template <>
+const TransformImplementation<__fp16> *implementation_list(void)
+{
+  return transforms_fp16;
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
+
+#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
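(Annotation, not part of the patch.) The registry just defined is a nullptr-terminated table which the selection code in winograd_implementations.hpp (below) walks until it reaches the sentinel. A generic sketch of that pattern, with stand-in types rather than the library's own:

#include <cstdio>

struct Impl
{
  const char *name; // stands in for the owned transform pointer
};

static const Impl table[] = {
  { "a64_fp16_4x4_3x3" },
  { nullptr }, // sentinel terminates the walk
};

int main()
{
  // Walk until the sentinel entry, as the selection loops do with
  // impl->transform.get() != nullptr.
  for (const Impl *impl = table; impl->name != nullptr; impl++)
  {
    std::printf("candidate transform: %s\n", impl->name);
  }
  return 0;
}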
diff --git a/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp b/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp
new file mode 100644
index 0000000000..63f5fc786c
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/weight_transforms_fp32.cpp
@@ -0,0 +1,74 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "winograd_implementations.hpp"
+#include "weight_transform.hpp"
+
+namespace arm_conv {
+namespace winograd {
+namespace weight_transform {
+
+#if defined(__aarch64__)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
+#endif // defined(__aarch64__)
+void arm_fp32_4x4_3x3(unsigned int, const float *, size_t, size_t, float *, size_t);
+void arm_fp32_2x2_3x3(unsigned int, const float *, size_t, size_t, float *, size_t);
+void arm_fp32_2x2_5x5(unsigned int, const float *, size_t, size_t, float *, size_t);
+void cpp_fp32_1x6_1x3(unsigned int, const float *, size_t, size_t, float *, size_t);
+void cpp_fp32_1x4_1x5(unsigned int, const float *, size_t, size_t, float *, size_t);
+void cpp_fp32_1x2_1x7(unsigned int, const float *, size_t, size_t, float *, size_t);
+
+#define IMPL(KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, KERN) \
+  new Transform<float>(#KERN, KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, KERN)
+
+#define IMPL_T(KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, KERN) \
+  new Transform<float>(#KERN, KERN_ROWS, KERN_COLS, TRANS_ROWS, TRANS_COLS, Transform<float>::get_transposed_kernel(KERN))
+
+static const TransformImplementation<float> transforms_fp32[] = {
+#if defined(__aarch64__)
+#if defined(ARM_COMPUTE_ENABLE_SVE)
+#endif // defined(ARM_COMPUTE_ENABLE_SVE)
+#endif // defined(__aarch64__)
+  { IMPL(3, 3, 6, 6, arm_fp32_4x4_3x3) },
+  { IMPL(3, 3, 4, 4, arm_fp32_2x2_3x3) },
+  { IMPL(5, 5, 6, 6, arm_fp32_2x2_5x5) },
+  { IMPL(1, 3, 1, 8, cpp_fp32_1x6_1x3) },
+  { IMPL_T(3, 1, 8, 1, cpp_fp32_1x6_1x3) },
+  { IMPL(1, 5, 1, 8, cpp_fp32_1x4_1x5) },
+  { IMPL_T(5, 1, 8, 1, cpp_fp32_1x4_1x5) },
+  { IMPL(1, 7, 1, 8, cpp_fp32_1x2_1x7) },
+  { IMPL_T(7, 1, 8, 1, cpp_fp32_1x2_1x7) },
+  { nullptr }
+};
+
+template <>
+const TransformImplementation<float> *implementation_list(void)
+{
+  return transforms_fp32;
+}
+
+} // namespace weight_transform
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd.cpp b/src/core/NEON/kernels/convolution/winograd/winograd.cpp
deleted file mode 100644
index d556112853..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd.cpp
+++ /dev/null
@@ -1,182 +0,0 @@
-/*
- * Copyright (c) 2017-2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */ - -#include -#include "utils.hpp" -#include "winograd.hpp" - -using namespace winograd; -using array2 = std::pair; - -#define MEMBERFN(RTYPE) \ - template \ - template \ - RTYPE WinogradGEMM::Convolution - -/** Get the output shape of a convolution. */ -MEMBERFN(array2) -::get_output_shape(const std::pair input_shape, - const bool padding_same) { - const unsigned int n_rows = - padding_same ? input_shape.first : input_shape.first - (kernel_rows - 1); - const unsigned int n_cols = padding_same - ? input_shape.second - : input_shape.second - (kernel_cols - 1); - return {n_rows, n_cols}; -} - -/** Get the memory required to store the kernel transformed into the - * Winograd domain. - */ -MEMBERFN(size_t) -::get_kernel_storage_size(const unsigned int n_input_channels, - const unsigned int n_output_channels) { - return N_GEMMS * get_kernel_matrix_size(n_input_channels, n_output_channels); -} - -MEMBERFN(size_t) -::get_input_storage_size(const unsigned int n_batches, - const unsigned int n_rows, const unsigned int n_cols, - const unsigned int n_channels, - const bool same_padding) { - return N_GEMMS * get_input_matrix_size(n_batches, n_rows, n_cols, n_channels, - same_padding); -} - -MEMBERFN(size_t) -::get_output_storage_size(const unsigned int n_batches, - const unsigned int n_rows, const unsigned int n_cols, - const unsigned int n_channels) { - return N_GEMMS * - get_output_matrix_size(n_batches, n_rows, n_cols, n_channels); -} - -/** Get the memory required to apply a Winograd operator to some input. - */ -MEMBERFN(size_t) -::get_working_space_size(const unsigned int n_batches, - const unsigned int n_rows, const unsigned int n_cols, - const unsigned int n_input_channels, - const unsigned int n_output_channels, - const bool padding_same) { - const auto output_shape = get_output_shape({n_rows, n_cols}, padding_same); - - // Get the memory required to store the matrices - const size_t matrix_sizes = - N_GEMMS * - (get_input_matrix_size(n_batches, n_rows, n_cols, n_input_channels, - padding_same) + - get_output_matrix_size(n_batches, output_shape.first, - output_shape.second, n_output_channels)); - return matrix_sizes; -} - -/* Get the memory required by a single "input" matrix. - */ -MEMBERFN(size_t) -::get_input_matrix_size(const unsigned int n_batches, const unsigned int n_rows, - const unsigned int n_cols, - const unsigned int n_channels, - const bool same_padding) { - return get_input_matrix_stride(n_batches, n_rows, n_cols, n_channels, - same_padding) * - sizeof(TGEMMIn); -} - -MEMBERFN(int) -::get_input_matrix_stride(const unsigned int n_batches, const unsigned int n_rows, - const unsigned int n_cols, - const unsigned int n_channels, - const bool same_padding) { - const auto output_shape = get_output_shape({n_rows, n_cols}, same_padding); - const unsigned int tile_rows = iceildiv(output_shape.first, output_tile_rows); - const unsigned int tile_cols = - iceildiv(output_shape.second, output_tile_cols); - const unsigned int M = - roundup(n_batches * tile_rows * tile_cols, M_BLOCK); - const unsigned int K = n_channels; - - return M * K; -} - -/* Get the memory required by a single "output" matrix. 
- */ -MEMBERFN(size_t) -::get_output_matrix_size(const unsigned int n_batches, - const unsigned int n_rows, const unsigned int n_cols, - const unsigned int n_channels) { - return get_output_matrix_stride(n_batches, n_rows, n_cols, n_channels) * - sizeof(TGEMMOut); -} - -MEMBERFN(int) -::get_output_matrix_stride(const unsigned int n_batches, - const unsigned int n_rows, const unsigned int n_cols, - const unsigned int n_channels) { - // Compute shape for the GEMM - const int tile_rows = iceildiv(n_rows, output_tile_rows); - const int tile_cols = iceildiv(n_cols, output_tile_cols); - const int M = roundup(tile_rows * tile_cols, M_BLOCK); - const int N = roundup(n_channels, N_BLOCK); - - return n_batches * M * N; -} - - -/* Get the memory required by a single "kernel" matrix. - */ -MEMBERFN(size_t) -::get_kernel_matrix_size(const unsigned int n_input_channels, - const unsigned int n_output_channels) { - return sizeof(TGEMMIn) * - get_kernel_matrix_stride(n_input_channels, n_output_channels); -} - -MEMBERFN(int) -::get_kernel_matrix_stride(const unsigned int n_input_channels, - const unsigned int n_output_channels) { - return n_input_channels * roundup(n_output_channels, N_BLOCK); -} - -// Instantiate required implementations -template class WinogradGEMM<2, 2, 3, 3, WinogradRoots::Integers>::Convolution; -template class WinogradGEMM<4, 4, 3, 3, WinogradRoots::Integers>::Convolution; - -template class WinogradGEMM<1, 6, 1, 3, WinogradRoots::Integers>::Convolution; -template class WinogradGEMM<6, 1, 3, 1, WinogradRoots::Integers>::Convolution; - -template class WinogradGEMM<2, 2, 5, 5, WinogradRoots::Integers>::Convolution; - -template class WinogradGEMM<1, 4, 1, 5, WinogradRoots::Integers>::Convolution; -template class WinogradGEMM<4, 1, 5, 1, WinogradRoots::Integers>::Convolution; - -template class WinogradGEMM<1, 2, 1, 7, WinogradRoots::Integers>::Convolution; -template class WinogradGEMM<2, 1, 7, 1, WinogradRoots::Integers>::Convolution; - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template class WinogradGEMM<4, 4, 3, 3, WinogradRoots::Integers>::Convolution<__fp16, __fp16, __fp16, __fp16>; -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/core/NEON/kernels/convolution/winograd/winograd.hpp b/src/core/NEON/kernels/convolution/winograd/winograd.hpp deleted file mode 100644 index ac82e7b7b9..0000000000 --- a/src/core/NEON/kernels/convolution/winograd/winograd.hpp +++ /dev/null @@ -1,621 +0,0 @@ -/* - * Copyright (c) 2017-2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#pragma once - -#include "arm_gemm.hpp" - -#include -#include - -namespace winograd -{ - -class ITransform -{ - public: - virtual ~ITransform() = default; - - /** - * Get the working space required to perform the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param nthreads The greatest number of threads that will be used to execute the transform. - * @return Size of working space required in bytes. - */ - virtual size_t get_working_space_size(unsigned int nthreads=1) const = 0; - - /** - * Set the working space to be used by the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param Pointer to the working space. - */ - virtual void set_working_space(void *buffer) = 0; - - /** - * Get the window of work a given operator can perform. - */ - virtual unsigned int get_window() const = 0; - - /** - * Perform work upon a window of the transform. - */ - virtual void run(unsigned int start, unsigned int stop, unsigned int threadid=0) = 0; -}; - -class IInputTransform : public ITransform -{ - public: - virtual ~IInputTransform() = default; - - /** - * Set the pointer to the (NHWC-ordered) tensor to be transformed. - */ - virtual void set_input_tensor(const void *input) = 0; - - /** - * Set the pointer to the (NHWC-ordered) tensor to be transformed. - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_input_tensor(const void *input, int col_stride) = 0; - - /** - * Set the pointer to the (NHWC-ordered) tensor to be transformed. - * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_input_tensor(const void *input, int row_stride, int col_stride) = 0; - - /** - * Set the pointer to the (NHWC-ordered) tensor to be transformed. - * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes). - * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) = 0; - - /** - * Set pointers to the matrices written by the transform. - * @param matrices Pointer to the start of the first matrix representing the transformed input. - * @param inter_matrix_stride Stride (in elements) between matrices. - * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. - */ - virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; -}; - -class IOutputTransform : public ITransform -{ - public: - virtual ~IOutputTransform() = default; - - /** - * Set pointers to the matrices written by the transform. - * @param matrices Pointer to the start of the first matrix representing the input to the transform. 
- * @param inter_matrix_stride Stride (in elements) between matrices. - * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. - */ - virtual void set_input_matrices(const void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; - - /** - * Set pointer to the bias tensor (can be ignored or called with nullptr for no bias. - */ - virtual void set_bias(const void *bias=nullptr) = 0; - - /** - * Set pointer to the output tensor produced by the transform. - */ - virtual void set_output_tensor(void *output) = 0; - - /** - * Set pointer to the output tensor produced by the transform. - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_output_tensor(void *output, int col_stride) = 0; - - /** - * Set pointer to the output tensor produced by the transform. - * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_output_tensor(void *output, int row_stride, int col_stride) = 0; - - /** - * Set pointer to the output tensor produced by the transform. - * @param batch_stride Stride between batches of the tensor, measured in elements (not bytes). - * @param row_stride Stride between rows of the tensor, measured in elements (not bytes). - * @param col_stride Stride between columns of the tensor, measured in elements (not bytes). - */ - virtual void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) = 0; -}; - -class IWeightTransform : public ITransform -{ - public: - virtual ~IWeightTransform() = default; - - /** Set pointer to the weight tensor read by the transform. */ - virtual void set_weight_tensor(const void *weights) = 0; - - /** - * Set pointers to the matrices written by the transform. - * @param matrices Pointer to the start of the first matrix representing the transformed input. - * @param inter_matrix_stride Stride (in elements) between matrices. - * @param matrix_row_stride Stride (in elements) between the rows within a single matrix. - */ - virtual void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) = 0; -}; - -enum class WinogradRoots -{ - Integers, -}; - -template -class InputTransform : public IInputTransform -{ - public: - /** Create an InputTransform operator fixed on a given problem and set of - * pointers. - */ - InputTransform( - int kernel_rows, /**< Number of rows in the kernel */ - int kernel_cols, /**< Number of columns in the kernel */ - int n_batches, /**< Number of batches in input tensor. */ - int n_rows, /**< Number of rows in input tensor. */ - int n_cols, /**< Number of columns in input tensor. */ - int n_channels, /**< Number of channels in input tensor. */ - int padding_top, /**< Padding to apply to the top of the image. */ - int padding_left, /**< Padding to apply to the left of the image. */ - int padding_bottom, /**< Padding to apply to the bottom of the image. */ - int padding_right /**< Padding to apply to the right of the image. */ - ); - - InputTransform(InputTransform&) = delete; - InputTransform operator=(InputTransform&) = delete; - - /** Set pointers to the input tensor read by the transform. 
*/ - void set_input_tensor(const void *input) override; - void set_input_tensor(const void *input, int col_stride) override; - void set_input_tensor(const void *input, int row_stride, int col_stride) override; - void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override; - - /** Set pointers to the matrices written by the transform. */ - void set_output_matrices(void *matrices, int iter_matrix_stride, int matrix_row_stride) override; - - /** Get the working space required to perform the transformation. */ - size_t get_working_space_size(unsigned int nthreads=1) const override; - void set_working_space(void *buffer) override; - - /** Get the window of work a given operator can perform. */ - unsigned int get_window() const override; - static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window - - /** Perform work upon a window of the input. */ - void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; - - protected: - const int _n_batches, _n_rows, _n_cols, _n_channels; - - private: - void transform_unpadded_tile( - unsigned int threadid, - int n_channels, - TOut *outptr, - const TIn *inptr - ); - - void transform_padded_tile( - unsigned int threadid, - int n_channels, - TOut *outptr, - const TIn *inptr, - int padding_top, - int padding_left, - int padding_bottom, - int padding_right - ); - - /* Tile implementation */ - static void transform_tile( - int n_channels, /** @param[in] Number of channels in the tensor. */ - const TIn* inptr_base, /** @param[in] Pointer to the base of the input tile. */ - int input_row_stride, /** @param[in] Stride between rows of the input tensor. */ - int input_col_stride, /** @param[in] Stride between columns of the input tensor. */ - TOut* mptr_base, /** @param[out] Base pointer to transformed input matrices. */ - int matrix_stride /** @param[in] Stride between matrices in the input space. */ - ); - - /** Get the working space for a thread. */ - void * get_working_space(unsigned int threadid) const; - - const TIn* _inptr; - TOut* _outptr; - - const int _overlap_rows, _overlap_cols; - const int _padding_top, _padding_left, _padding_bottom, _padding_right; - const int _tiles_M, _tiles_N; - int _matrix_stride, _matrix_row_stride, _matrix_batch_stride; - int _in_col_stride, _in_row_stride, _in_batch_stride; - - const int _working_space_col_stride, _working_space_row_stride; - TIn *_working_space; -}; - -template -class InputTransform : - public InputTransform<1, InnerTileRows, TIn, TOut, Roots> -{ - using Base = InputTransform<1, InnerTileRows, TIn, TOut, Roots>; - - public: - InputTransform( - int kernel_rows, /**< Number of rows in the kernel. */ - int kernel_cols, /**< Number of columns in the kernel. */ - int n_batches, /**< Number of batches in input tensor. */ - int n_rows, /**< Number of rows in input tensor. */ - int n_cols, /**< Number of columns in input tensor. */ - int n_channels, /**< Number of channels in input tensor. */ - int padding_top, /**< Padding to apply to the top of the image. */ - int padding_left, /**< Padding to apply to the left of the image. */ - int padding_bottom, /**< Padding to apply to the bottom of the image. */ - int padding_right /**< Padding to apply to the right of the image. */ - ); - - /** Set pointers to the input tensor read by the transform. 
*/ - void set_input_tensor(const void *input) override; - void set_input_tensor(const void *input, int col_stride) override; - void set_input_tensor(const void *input, int row_stride, int col_stride) override; - void set_input_tensor(const void *input, int batch_stride, int row_stride, int col_stride) override; -}; - -template < - int KernelRows, int KernelCols, - int InnerTileRows, int InnerTileCols, - typename TIn, typename TOut, - WinogradRoots Roots -> -class OutputTransform : public IOutputTransform -{ - public: - OutputTransform( - int n_batches, /**< Number of batches in output tensor. */ - int n_rows, /**< Number of rows in output tensor. */ - int n_cols, /**< Number of columns in output tensor. */ - int n_channels, /**< Number of channels in output tensor. */ - const arm_gemm::Activation &activation - ); - - OutputTransform(OutputTransform&) = delete; - OutputTransform operator=(OutputTransform&) = delete; - - /** Set pointers to the matrices read by the transform. */ - void set_input_matrices(const void *matrices, int iter_matrix_stride, int matrix_row_stride) override; - - /** Set pointer to the bias tensor (can be ignored or called with nullptr for no bias */ - void set_bias(const void *bias=nullptr) override; - - /** Set pointers to the output tensor written by the transform. */ - void set_output_tensor(void *output) override; - void set_output_tensor(void *output, int col_stride) override; - void set_output_tensor(void *output, int row_stride, int col_stride) override; - void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override; - - /** Get the working space required to perform the transformation. */ - size_t get_working_space_size(unsigned int nthreads=1) const override; - void set_working_space(void *buffer) override; - - /** Get the window of work a given operator can perform. */ - unsigned int get_window() const override; - static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window - - /** Perform work upon a window of the input. */ - void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; - - protected: - static constexpr int inner_tile_rows = InnerTileRows; - static constexpr int inner_tile_cols = InnerTileCols; - static constexpr int output_tile_rows = InnerTileRows - KernelRows + 1; - static constexpr int output_tile_cols = InnerTileCols - KernelCols + 1; - - const int _n_batches, _n_rows, _n_cols, _n_channels; - const TOut _output_min, _output_max; - - private: - void transform_uncropped_tile( - unsigned int threadid, - int n_channels, - TOut *outptr, - const TIn *inptr, - const TOut *biases - ); - - void transform_cropped_tile( - unsigned int threadid, - int n_channels, - TOut *outptr, - const TIn *inptr, - const TOut *biases, - int pad_bottom, - int pad_right - ); - - /** Implementation of the tile transformation method. */ - static void transform_tile( - int n_channels, - const TIn* matrix_base, - int matrix_stride, - const TOut* biases, - TOut* output, - int output_row_stride, - int output_col_stride, - TOut output_min, - TOut output_max - ); - - /** Get the working space for a thread. 
*/ - void * get_working_space(unsigned int threadid) const; - - const TIn* _matrix_base; - const TOut* _biases; - int _matrix_stride, _matrix_row_stride, _matrix_batch_stride; - TOut* _outptr; - const int _tiles_M, _tiles_N; - int _out_col_stride, _out_row_stride, _out_batch_stride; - - const int _working_space_col_stride, _working_space_row_stride; - TOut *_working_space; -}; - -template < - int KernelRows, - int InnerTileRows, - typename TIn, typename TOut, - WinogradRoots Roots -> -class OutputTransform : - public OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots> -{ - using Base = OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>; - - public: - OutputTransform( - int n_batches, /**< Number of batches in output tensor. */ - int n_rows, /**< Number of rows in output tensor. */ - int n_cols, /**< Number of columns in output tensor. */ - int n_channels, /**< Number of channels in output tensor. */ - const arm_gemm::Activation &activation - ); - - /** Set pointers to the output tensor written by the transform. */ - void set_output_tensor(void *output) override; - void set_output_tensor(void *output, int col_stride) override; - void set_output_tensor(void *output, int row_stride, int col_stride) override; - void set_output_tensor(void *output, int batch_stride, int row_stride, int col_stride) override; -}; - -template < - int KernelRows, int KernelCols, - int InnerTileRows, int InnerTileCols, - typename TIn, typename TOut, - WinogradRoots Roots -> -class WeightTransform : public IWeightTransform -{ - public: - WeightTransform( - int n_output_channels, /**< Number of output channels in the kernel. */ - int n_input_channels /**< Number of input channels in the kernel. */ - ); - - WeightTransform(WeightTransform&) = delete; - WeightTransform operator=(WeightTransform&) = delete; - - /** Set pointer to the weight tensor read by the transform. */ - void set_weight_tensor(const void *weights) override; - - /** Set pointer to the matrices written by the transform. */ - void set_output_matrices(void *matrices, int inter_matrix_stride, int matrix_row_stride) override; - - /** Get the working space required to perform the transformation. */ - size_t get_working_space_size(unsigned int nthreads=1) const override; - void set_working_space(void *buffer) override; - - /** Get the window of work a given operator can perform. */ - unsigned int get_window() const override; - static constexpr unsigned int WINDOW_BLOCK = 16; // Base size of window - - /** Perform work upon a window of the input. */ - void run(unsigned int start, unsigned int stop, unsigned int threadid=0) override; - - protected: - static const int kernel_rows = KernelRows; - static const int kernel_cols = KernelCols; - static const int inner_tile_rows = InnerTileRows; - static const int inner_tile_cols = InnerTileCols; - - private: - /** Apply the transform to a tensor. 
*/ - static void execute( - int n_output_channels, - int n_input_channels, - const TIn* input, - TOut* output, - int matrix_stride, - int matrix_row_stride - ); - - const int _n_output_channels, _n_input_channels; - TOut *_matrices; - int _matrix_stride, _matrix_row_stride; - const TIn *_weights; -}; - -template -class WeightTransform : - public WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots> -{ - public: - using WeightTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::WeightTransform; -}; - -template -class WinogradGEMM -{ - public: - // Information about the specific Winograd instance - static constexpr int output_tile_rows = OutputTileRows; - static constexpr int output_tile_cols = OutputTileCols; - static constexpr int kernel_rows = KernelRows; - static constexpr int kernel_cols = KernelCols; - static constexpr int inner_tile_rows = output_tile_rows + kernel_rows - 1; - static constexpr int inner_tile_cols = output_tile_cols + kernel_cols - 1; - static constexpr int N_GEMMS = inner_tile_rows * inner_tile_cols; - - /** Transform weights from the spatial to the Winograd domain. */ - template - using WeightsTransform = WeightTransform< - KernelRows, KernelCols, inner_tile_rows, inner_tile_cols, - TIn, TOut, Roots - >; - - /** Transform input feature maps from the spatial to the Winograd domain. - */ - template - using InputTransform = InputTransform< - inner_tile_rows, inner_tile_cols, TIn, TOut, Roots - >; - - /** Transform output feature maps from the Winograd to the spatial domain. - */ - template - using OutputTransform = OutputTransform< - KernelRows, KernelCols, inner_tile_rows, inner_tile_cols, - TIn, TOut, Roots - >; - - /** Perform a convolution. - */ - template - class Convolution - { - public: - // Information about the typed Winograd instance - typedef TOut OutputType; - typedef TOutGEMM GemmOutputType; - typedef TInGEMM GemmInputType; - typedef TIn InputType; - - /** Get the output shape of a convolution. */ - static std::pair get_output_shape( - const std::pair input_shape, - bool padding_same); - - /** Get the memory required to store the kernel transformed into the - * Winograd domain. - */ - static size_t get_kernel_storage_size(unsigned int n_input_channels, - unsigned int n_output_channels); - - /** Get the memory required to store the input tensor transformed into - * the Winograd domain. - */ - static size_t get_input_storage_size( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of input rows - unsigned int n_cols, // Number of input columns - unsigned int n_channels, // Number of input channels - bool padding_same); - - /** Get the memory required to store the output tensor in the Winograd - * domain. - */ - static size_t get_output_storage_size( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of output rows - unsigned int n_cols, // Number of output columns - unsigned int n_channels // Number of output channels - ); - - /** Get the memory required to apply a Winograd operator to some input. - */ - static size_t get_working_space_size( - unsigned int n_batches, - unsigned int n_rows, // Number of input rows - unsigned int n_cols, // Number of input columns - unsigned int n_input_channels, // Number of input channels - unsigned int n_output_channels, // Number of output channels - bool padding_same); - - /* Get the memory required by a single "input" matrix. 
- */ - static size_t get_input_matrix_size( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of input rows - unsigned int n_cols, // Number of input columns - unsigned int n_channels, // Number of input channels - bool padding_same); - - static int get_input_matrix_stride( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of input rows - unsigned int n_cols, // Number of input columns - unsigned int n_channels, // Number of input channels - bool padding_same); - - /* Get the memory required by a single "output" matrix. - */ - static size_t get_output_matrix_size( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of output rows - unsigned int n_cols, // Number of output columns - unsigned int n_channels // Number of output channels - ); - - static int get_output_matrix_stride( - unsigned int n_batches, // Number of batches - unsigned int n_rows, // Number of output rows - unsigned int n_cols, // Number of output columns - unsigned int n_channels // Number of output channels - ); - - /* Get the memory required by a single "kernel" matrix. - */ - static size_t get_kernel_matrix_size(unsigned int n_input_channels, - unsigned int n_output_channels); - static int get_kernel_matrix_stride(unsigned int n_input_channels, - unsigned int n_output_channels); - - static constexpr int M_BLOCK = 4; /** Size of block used by GEMM. */ - static constexpr int N_BLOCK = 16; /** Size of block used by GEMM. */ - }; -}; - -} // namespace winograd diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp new file mode 100644 index 0000000000..e1ad9e458d --- /dev/null +++ b/src/core/NEON/kernels/convolution/winograd/winograd_fp16.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+
+#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
+
+#include "winograd_implementations.hpp"
+
+namespace arm_conv {
+namespace winograd {
+
+template bool get_implementation<__fp16>(
+  WinogradImpl &,
+  const CPUInfo *,
+  const ConvolutionArgs &,
+  int max_threads,
+  bool fast_mode,
+  const WinogradConfig *,
+  const arm_gemm::GemmConfig *
+);
+
+} // namespace winograd
+} // namespace arm_conv
+
+#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp
new file mode 100644
index 0000000000..b92de1dde7
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_fp32.cpp
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "winograd_implementations.hpp"
+
+namespace arm_conv {
+namespace winograd {
+
+template bool get_implementation<float>(
+  WinogradImpl &,
+  const CPUInfo *,
+  const ConvolutionArgs &,
+  int max_threads,
+  bool fast_mode,
+  const WinogradConfig *,
+  const arm_gemm::GemmConfig *
+);
+
+} // namespace winograd
+} // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_implementations.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_implementations.hpp
new file mode 100644
index 0000000000..510f69baaa
--- /dev/null
+++ b/src/core/NEON/kernels/convolution/winograd/winograd_implementations.hpp
@@ -0,0 +1,341 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#pragma once
+
+#include "src/core/NEON/kernels/assembly/winograd.hpp"
+#include <cstring>
+#include <memory>
+
+namespace arm_conv {
+namespace winograd {
+
+enum class MethodConstraints
+{
+  None,
+  RequiresSVE = 0x1,
+  RequiresSVE2 = 0x2,
+  RequiresSME = 0x4,
+  RequiresSME2 = 0x8,
+  LargerShape = 0x10, // Input tensor shape is larger than the output transform tile shape.
+};
+
+constexpr inline bool operator!(const MethodConstraints &c)
+{
+  return c == MethodConstraints::None;
+}
+
+constexpr inline MethodConstraints operator|(const MethodConstraints &a, const MethodConstraints &b)
+{
+  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) | static_cast<unsigned int>(b));
+}
+
+constexpr inline MethodConstraints operator&(const MethodConstraints &a, const MethodConstraints &b)
+{
+  return static_cast<MethodConstraints>(static_cast<unsigned int>(a) & static_cast<unsigned int>(b));
+}
+
+inline bool constraints_met(const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &, const WinogradConfig *)
+{
+  return (
+    (!(c & MethodConstraints::RequiresSVE) || (ci->has_sve())) &&
+    (!(c & MethodConstraints::RequiresSVE2) || (ci->has_sve2())) &&
+    (!(c & MethodConstraints::RequiresSME) || (ci->has_sme())) &&
+    (!(c & MethodConstraints::RequiresSME2) || (ci->has_sme2()))
+    // Add further constraints here
+  );
+}
+
+inline bool output_transform_constraints_met(const output_transform::ITransform *transform, const MethodConstraints &c, const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg)
+{
+  return (
+    constraints_met(c, ci, conv_args, cfg) &&
+    (!(c & MethodConstraints::LargerShape) || (conv_args.input_shape.rows > transform->get_output_rows() && conv_args.input_shape.cols > transform->get_output_cols()))
+  );
+}
+
+namespace weight_transform {
+
+template <typename T>
+struct TransformImplementation
+{
+  std::unique_ptr<const ITransform> transform;
+  MethodConstraints constraints;
+
+  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
+  : transform(transform), constraints(constraints)
+  {
+  }
+};
+
+template <typename T>
+const TransformImplementation<T> *implementation_list(void);
+
+} // namespace weight_transform
+
+namespace input_transform
+{
+
+template <typename T>
+struct TransformImplementation
+{
+  std::unique_ptr<const ITransform> transform;
+  MethodConstraints constraints;
+
+  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
+  : transform(transform), constraints(constraints)
+  {
+  }
+};
+
+template <typename T>
+const TransformImplementation<T> *implementation_list(void);
+
+} // namespace input_transform
+
+namespace output_transform
+{
+
+template <typename T>
+struct TransformImplementation
+{
+  std::unique_ptr<const ITransform> transform;
+  MethodConstraints constraints;
+
+  TransformImplementation(const ITransform *transform, const MethodConstraints &constraints = MethodConstraints::None)
+  : transform(transform), constraints(constraints)
+  {
+  }
+};
+
+template <typename T>
+const TransformImplementation<T> *implementation_list(void);
+
+} // namespace output_transform
+
+namespace {
+
+template <typename T>
+constexpr T iceildiv(T num, T den)
+{
+  return (num + den - 1) / den;
+}
+
+template <typename T>
+constexpr T iroundup(T num, T den)
+{
+  return den * iceildiv(num, den);
+}
+
+}
+  const WinogradConfig *cfg
+)
+{
+  // Get target inner tile size
+  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
+  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);
+
+  std::vector<const weight_transform::ITransform *> weight_transforms;
+  for (auto impl = weight_transform::implementation_list<TIn, TOut>();
+       impl->transform.get() != nullptr; impl++)
+  {
+    // If this transform supports the requested kernel size, then add it to the
+    // list of weight transforms.
+    if (
+      constraints_met(impl->constraints, ci, conv_args, cfg) &&
+      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
+      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
+      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_transformed_tile_rows()) &&
+      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_transformed_tile_cols()) &&
+      (cfg->weight_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->weight_transform_filter.c_str()))
+    )
+    {
+      weight_transforms.push_back(impl->transform.get());
+    }
+  }
+
+  return weight_transforms;
+}
+
+template <typename TIn, typename TOut = TIn>
+inline std::vector<const input_transform::ITransform *> get_input_transforms(
+  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
+)
+{
+  // Get target inner tile size
+  const auto target_inner_tile_rows = cfg->output_rows == 0 ? 0 : (conv_args.kernel_shape.rows + cfg->output_rows - 1);
+  const auto target_inner_tile_cols = cfg->output_cols == 0 ? 0 : (conv_args.kernel_shape.cols + cfg->output_cols - 1);
+
+  std::vector<const input_transform::ITransform *> input_transforms;
+  for (auto impl = input_transform::implementation_list<TIn, TOut>();
+       impl->transform.get() != nullptr; impl++)
+  {
+    if(
+      constraints_met(impl->constraints, ci, conv_args, cfg) &&
+      (target_inner_tile_rows == 0 || target_inner_tile_rows == impl->transform->get_input_rows()) &&
+      (target_inner_tile_cols == 0 || target_inner_tile_cols == impl->transform->get_input_cols()) &&
+      (cfg->input_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->input_transform_filter.c_str()))
+    )
+    {
+      input_transforms.push_back(impl->transform.get());
+    }
+  }
+
+  return input_transforms;
+}
+
+template <typename TIn, typename TOut = TIn>
+inline std::vector<const output_transform::ITransform *> get_output_transforms(
+  const CPUInfo *ci, const ConvolutionArgs &conv_args, const WinogradConfig *cfg
+)
+{
+  std::vector<const output_transform::ITransform *> output_transforms;
+  for (auto impl = output_transform::implementation_list<TIn, TOut>();
+       impl->transform.get() != nullptr; impl++)
+  {
+    if(
+      output_transform_constraints_met(impl->transform.get(), impl->constraints, ci, conv_args, cfg) &&
+      impl->transform->get_kernel_rows() == conv_args.kernel_shape.rows &&
+      impl->transform->get_kernel_cols() == conv_args.kernel_shape.cols &&
+      (cfg->output_rows == 0 || cfg->output_rows == impl->transform->get_output_rows()) &&
+      (cfg->output_cols == 0 || cfg->output_cols == impl->transform->get_output_cols()) &&
+      (cfg->output_transform_filter == "" || std::strstr(impl->transform->get_name().c_str(), cfg->output_transform_filter.c_str()))
+    )
+    {
+      output_transforms.push_back(impl->transform.get());
+    }
+  }
+
+  return output_transforms;
+}
+
+template <typename TIn, typename TWeight = TIn, typename TOut = TIn, typename TWinogradIn = TIn, typename TWinogradOut = TOut>
+bool get_implementation(
+  WinogradImpl &dest,  // Destination for the selected implementation
+  const CPUInfo *ci,
+  const ConvolutionArgs &conv_args,
+  int max_threads,
+  bool fast_mode,
+  const WinogradConfig *cfg,
+  const arm_gemm::GemmConfig *gemm_cfg
+)
+{
+  // Get vectors of valid weight, input and output transforms; then select the
+  // combination which produces the biggest output tile.
+  const auto weight_transforms = get_weight_transforms<TWeight, TWinogradIn>(ci, conv_args, cfg);
+  const auto input_transforms = get_input_transforms<TIn, TWinogradIn>(ci, conv_args, cfg);
+  const auto output_transforms = get_output_transforms<TWinogradOut, TOut>(ci, conv_args, cfg);
+
+  // Now attempt to select a complete set of Winograd transformations which can
+  // solve the problem. Work backwards from the output transform to find
+  // matching input implementations.
+  bool success = false;
+  for (auto output_transform = output_transforms.cbegin();
+       !success && output_transform != output_transforms.cend();
+       output_transform++)
+  {
+    // Look for matching weight transforms; if we find one, then look for
+    // matching input transforms.
+    for (auto weight_transform = weight_transforms.cbegin();
+         !success && weight_transform != weight_transforms.cend();
+         weight_transform++)
+    {
+      // If this weight transform is compatible, then look for a matching input
+      // transform
+      if ((*output_transform)->get_input_rows() == (*weight_transform)->get_transformed_tile_rows() &&
+          (*output_transform)->get_input_cols() == (*weight_transform)->get_transformed_tile_cols())
+      {
+        for (auto input_transform = input_transforms.cbegin();
+             !success && input_transform != input_transforms.cend();
+             input_transform++)
+        {
+          // If the input transform is suitable, then set the configuration and
+          // indicate success.
+          if ((*input_transform)->get_input_rows() == (*output_transform)->get_input_rows() &&
+              (*input_transform)->get_input_cols() == (*output_transform)->get_input_cols())
+          {
+            dest.output_transform = *output_transform;
+            dest.input_transform = *input_transform;
+            dest.weight_transform = *weight_transform;
+            success = true;
+          }
+        }
+      }
+    }
+  }
+
+  if (!success)
+  {
+    return false;
+  }
+
+  // If we're able to construct the Winograd elements, then specify the GEMM
+  // arguments required to perform the multiply-accumulate step of the
+  // convolution.
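(Aside: to make the GEMM specification below concrete, here is a minimal, self-contained sketch of the tile arithmetic. All sizes are invented for illustration, and the lambda simply mirrors the `iceildiv` helper defined earlier in this header; none of this code is part of the patch itself.)

    #include <cstdio>

    // Hypothetical sizes, chosen only to illustrate the arithmetic in
    // get_implementation(): an F(4x4, 3x3) Winograd kernel (6x6 inner tile)
    // over a 56x56 output with 64 input and 96 output channels.
    int main()
    {
        const auto iceildiv = [](int num, int den) { return (num + den - 1) / den; };

        const int out_rows = 56, out_cols = 56;   // conv_args.output_shape (assumed)
        const int tile_rows = 4, tile_cols = 4;   // output transform's tile size
        const int inner_rows = 6, inner_cols = 6; // the output transform's input (inner tile) size

        const int n_output_patches = iceildiv(out_rows, tile_rows) * iceildiv(out_cols, tile_cols);
        const int n_multis = inner_rows * inner_cols; // one GEMM per inner-tile element

        // These become M, N, K and the multi count of the GemmArgs built below.
        std::printf("M = %d patches, N = 96, K = 64, %d GEMMs per batch\n",
                    n_output_patches, n_multis); // M = 196, 36 GEMMs
        return 0;
    }

The point is that one spatial convolution becomes `n_multis` independent GEMMs of identical shape, one per element of the inner tile.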
+
+  const auto n_output_row_tiles = iceildiv(conv_args.output_shape.rows, dest.output_transform->get_output_rows());
+  const auto n_output_col_tiles = iceildiv(conv_args.output_shape.cols, dest.output_transform->get_output_cols());
+  const auto n_output_patches = n_output_row_tiles * n_output_col_tiles;
+
+  const int n_multis = dest.input_transform->get_input_rows() *
+                       dest.input_transform->get_input_cols();
+
+  dest.gemm_args.reset(new arm_gemm::GemmArgs(
+    ci,
+    n_output_patches,             // M
+    conv_args.n_output_channels,  // N
+    conv_args.n_input_channels,   // K
+    1,                            // K-sections
+    conv_args.n_batches,          // # Batches
+    n_multis,
+    false,                        // Indirect input
+    {},                           // No activation
+    max_threads,
+    fast_mode,
+    gemm_cfg
+  ));
+
+  // Also provide hints for the Winograd memory layout
+  auto &ws = dest.winograd_spec;
+  ws.weight_ld_row = iroundup(conv_args.n_output_channels, 4u);
+  ws.weight_ld_matrix = conv_args.n_input_channels * ws.weight_ld_row;
+  ws.weight_matrix_size_bytes = n_multis * ws.weight_ld_matrix * sizeof(TWinogradIn);
+
+  ws.input_ld_row = iroundup(conv_args.n_input_channels, 4u);
+  ws.input_ld_matrix = iroundup(n_output_patches, 4u) * ws.input_ld_row;
+  ws.input_ld_batch = n_multis * ws.input_ld_matrix;
+  ws.input_matrix_size_bytes = conv_args.n_batches * ws.input_ld_batch * sizeof(TWinogradIn);
+
+  ws.output_ld_row = ws.weight_ld_row;
+  ws.output_ld_matrix = n_output_patches * ws.output_ld_row;
+  ws.output_ld_batch = n_multis * ws.output_ld_matrix;
+  ws.output_matrix_size_bytes = conv_args.n_batches * ws.output_ld_batch * sizeof(TWinogradOut);
+
+  return true;
+}
+
+}  // namespace winograd
+}  // namespace arm_conv
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
deleted file mode 100644
index 52ff7b3798..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp
+++ /dev/null
@@ -1,207 +0,0 @@
-/*
- * Copyright (c) 2017-2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "arm_gemm_local.hpp"
-#include "arm_gemm.hpp"
-#include "winograd.hpp"
-
-namespace winograd
-{
-
-
-class IWinogradConvolutionLayer
-{
-  public:
-    virtual ~IWinogradConvolutionLayer() = default;
-
-    virtual unsigned int weight_transform_get_window(void) const = 0;
-    virtual void weight_transform_run(unsigned int start, unsigned int stop) = 0;
-
-    virtual IInputTransform& input_transform(void) = 0;    // Expose the input transform
-    virtual IOutputTransform& output_transform(void) = 0;  // Expose the output transform
-    virtual arm_gemm::IGemmCommon *gemm(void) = 0;         // Expose the underlying GEMM
-};
-
-/** Example of how to construct an ACL-like interface.
- *
- * Use `get_weight_storage_size`, `get_input_storage_size` and
- * `get_output_storage_size` to allocate memory for the convolution engine.
- * Then create a `WinogradConvolutionLayer`.
- *
- * Initialise the weights using `weights_transform.run(...)`.
- *
- * For each inference:
- *   1. Transform the inputs to the Winograd domain using `input_transform.run(...)`
- *   2. Perform a number of GEMMs using `gemms.run(...)`
- *   3. Transform the output to the spatial domain using `output_transform.run(...)`
- */
-template <int OutputTileRows, int OutputTileCols, int KernelRows, int KernelCols,
-          typename TIn, typename TInGEMM, typename TOutGEMM, typename TOut,
-          WinogradRoots Roots>
-class WinogradConvolutionLayer : public IWinogradConvolutionLayer
-{
-  public:
-    using WinogradBase = winograd::WinogradGEMM<OutputTileRows, OutputTileCols, KernelRows, KernelCols, Roots>;
-    using WeightsTransform = typename WinogradBase::template WeightsTransform<TIn, TInGEMM>;
-    using InputTransform = typename WinogradBase::template InputTransform<TIn, TInGEMM>;
-    using WinogradConv = typename WinogradBase::template Convolution<TOut, TIn, TInGEMM, TOutGEMM>;
-    using OutputTransform = typename WinogradBase::template OutputTransform<TOutGEMM, TOut>;
-
-  private:
-    static constexpr int InnerTileRows = OutputTileRows + KernelRows - 1;
-    static constexpr int InnerTileCols = OutputTileCols + KernelCols - 1;
-    static constexpr int N_GEMMS = InnerTileRows * InnerTileCols;
-
-    const int _n_output_rows, _n_output_cols;
-    const int _kernel_matrix_stride, _kernel_matrix_row_stride;
-    const int _input_matrix_stride, _input_matrix_row_stride;
-    const int _output_matrix_stride, _output_matrix_row_stride;
-    const int _tile_rows, _tile_cols;
-    const int _m, _k, _n;
-
-    WeightsTransform weights_transform;  /** Operator to transform weights to Winograd domain. */
-    InputTransform _input_transform;     /** Operator to transform input to Winograd domain. */
-    const arm_gemm::GemmArgs gemm_args;
-    arm_gemm::UniqueGemmCommon<TInGEMM, TOutGEMM> gemms;  /** Operator to perform multiple GEMMs. */
-    OutputTransform _output_transform;   /** Operator to transform output from Winograd domain. */
-
-  public:
-
-    /** Determine how much memory (in units of TIn) to allocate for the
-     * transformed weights.
-     */
-    static unsigned int get_weight_storage_size(
-      const int n_output_channels,  /** Number of output feature maps. */
-      const int n_input_channels    /** Number of input feature maps. */
-    );
-
-    static unsigned int get_weight_stride(
-      const int n_output_channels,  /** Number of output feature maps. */
-      const int n_input_channels    /** Number of input feature maps. */
-    );
-
-    static unsigned int get_weight_multi_stride(
-      const int n_output_channels,  /** Number of output feature maps. */
-      const int n_input_channels    /** Number of input feature maps. */
-    );
-
-    /** Determine how much memory (in units of TIn) to allocate for the
-     * transformed input.
-     */
-    static unsigned int get_input_storage_size(
-      const int n_batches,     /** Number of batches in the input tensor. */
-      const int n_channels,    /** Number of feature maps in the input tensor. */
-      const int n_rows,        /** Number of rows in each feature map. */
-      const int n_cols,        /** Number of columns in each feature map. */
-      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    /** Get the row stride for the A matrix in the Winograd domain. */
-    static unsigned int get_input_stride(
-      const int n_batches,     /** Number of batches in the input tensor. */
-      const int n_channels,    /** Number of feature maps in the input tensor. */
-      const int n_rows,        /** Number of rows in each feature map. */
-      const int n_cols,        /** Number of columns in each feature map. */
-      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    /** Get the stride between A matrices in the Winograd domain. */
-    static unsigned int get_input_multi_stride(
-      const int n_batches,     /** Number of batches in the input tensor. */
-      const int n_channels,    /** Number of feature maps in the input tensor. */
-      const int n_rows,        /** Number of rows in each feature map. */
-      const int n_cols,        /** Number of columns in each feature map. */
-      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    /** Determine how much memory (in units of TOut) to allocate for the
-     * (Winograd domain) output.
-     */
-    static unsigned int get_output_storage_size(
-      const int n_batches,          /** Number of batches in the output tensor. */
-      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
-      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
-      const int n_output_channels,  /** Number of feature maps in the output tensor. */
-      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    static unsigned int get_output_stride(
-      const int n_batches,          /** Number of batches in the output tensor. */
-      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
-      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
-      const int n_output_channels,  /** Number of feature maps in the output tensor. */
-      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    static unsigned int get_output_multi_stride(
-      const int n_batches,          /** Number of batches in the output tensor. */
-      const int n_rows,             /** Number of rows in each feature map of the input tensor. */
-      const int n_cols,             /** Number of columns in each feature map of the input tensor. */
-      const int n_output_channels,  /** Number of feature maps in the output tensor. */
-      const bool same_padding       /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    /** Get the shape (rows, cols) of a feature map of the output tensor. */
-    static std::pair<int, int> get_output_feature_map_shape(
-      const int n_input_rows,  /** Number of rows in the input feature map. */
-      const int n_input_cols,  /** Number of columns in the input feature map. */
-      const bool same_padding  /** Use "SAME" padding, otherwise use "VALID". */
-    );
-
-    /** Create a new Winograd convolution layer.
-     */
-    WinogradConvolutionLayer(
-      const CPUInfo &cpuinfo,      /** Describes CPU properties. */
-      const int n_threads,         /** Maximum number of threads used to execute the convolution. */
-      const int n_batches,         /** Number of batches in the input and output tensors. */
-      const int n_input_channels,  /** Number of feature maps in a batch of the input tensor. */
-      const int n_input_rows,      /** Number of rows in a feature map of the input tensor. */
-      const int n_input_cols,      /** Number of columns in a feature map of the input tensor. */
-      const int n_output_channels,  /** Number of feature maps in the output tensor. */
-      const bool same_padding,      /** Use "SAME" padding, otherwise use "VALID". */
-      const arm_gemm::Activation &activation,
-      const TIn* const weights,        /** Pointer to weight tensor in spatial domain. Must be ordered as "Height x Rows x Input Feature Maps x Output Feature Maps. */
-      TInGEMM* const weights_storage,  /** Pointer to storage for weight tensor in the Winograd domain. Must be at least the size returned by `get_weight_storage_size`. */
-      const TIn* const input,          /** Pointer to NHWC ordered input tensor, in the spatial domain. */
-      TInGEMM* const winograd_input,   /** Pointer to working space for the input tensor in the Winograd domain. Must be at least the size returned by `get_input_storage_size`. */
-      const TOut* const biases,        /** Pointer to biases vector. Pass nullptr if no bias is provided. */
-      TOut* const output,              /** Pointer to NHWC ordered output tensor, in the spatial domain. */
-      TOutGEMM* const winograd_output, /** Pointer to working space for the output tensor in the Winograd domain. Must be at least the size returned by `get_output_storage_size`. */
-      const bool pretranspose_B=true,  /** Hint that the B matrix can be pretransposed. */
-      arm_gemm::GemmConfig *gemm_cfg=nullptr  /** Pointer to GEMM configuration. */
-    );
-
-    /* Utility methods for interacting with the layer. */
-    unsigned int weight_transform_get_window(void) const;
-    void weight_transform_run(const unsigned int start, const unsigned int stop);
-
-    IInputTransform& input_transform(void);
-    IOutputTransform& output_transform(void);
-
-    /* Get a pointer to the GEMM underlying the Winograd transform. */
-    arm_gemm::IGemmCommon *gemm(void);
-};
-
-}
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp
deleted file mode 100644
index c0f50beb2c..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input.hpp
+++ /dev/null
@@ -1,268 +0,0 @@
-/*
- * Copyright (c) 2017-2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-#include <algorithm>
-
-#include "padding.hpp"
-#include "utils.hpp"
-#include "winograd.hpp"
-
-#define MEMBERFN(RTYPE) template <\
-  int InnerTileRows, int InnerTileCols,\
-  typename TIn, typename TOut, WinogradRoots Roots\
-> RTYPE InputTransform<InnerTileRows, InnerTileCols, TIn, TOut, Roots>
-
-
-#define Nx1MEMBERFN(RTYPE) template <\
-  int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots\
-> RTYPE InputTransform<InnerTileRows, 1, TIn, TOut, Roots>
-
-namespace winograd
-{
-
-MEMBERFN()::InputTransform(
-  const int kernel_rows,
-  const int kernel_cols,
-  const int n_batches,
-  const int n_rows,
-  const int n_cols,
-  const int n_channels,
-  const int padding_top,
-  const int padding_left,
-  const int padding_bottom,
-  const int padding_right
-) : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols), _n_channels(n_channels),
-    _inptr(nullptr), _outptr(nullptr),
-    _overlap_rows(kernel_rows - 1), _overlap_cols(kernel_cols - 1),
-    _padding_top(padding_top), _padding_left(padding_left), _padding_bottom(padding_bottom), _padding_right(padding_right),
-    _tiles_M(iceildiv(padding_top + n_rows + padding_bottom - kernel_rows + 1, InnerTileRows - kernel_rows + 1)),
-    _tiles_N(iceildiv(padding_left + n_cols + padding_right - kernel_cols + 1, InnerTileCols - kernel_cols + 1)),
-    _matrix_stride(0), _matrix_row_stride(0), _matrix_batch_stride(0),
-    _in_col_stride(0), _in_row_stride(0), _in_batch_stride(0),
-    _working_space_col_stride(n_channels),
-    _working_space_row_stride(InnerTileCols * _working_space_col_stride),
-    _working_space(nullptr)
-{
-}
-
-MEMBERFN(void)::set_input_tensor(const void* const inptr)
-{
-  set_input_tensor(inptr, _n_channels);
-}
-
-MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldcol)
-{
-  set_input_tensor(inptr, _n_cols * ldcol, ldcol);
-}
-
-MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldrow, const int ldcol)
-{
-  set_input_tensor(inptr, _n_rows * ldrow, ldrow, ldcol);
-}
-
-MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldbatch, const int ldrow, const int ldcol)
-{
-  _inptr = static_cast<const TIn *>(inptr);
-  _in_batch_stride = ldbatch;
-  _in_row_stride = ldrow;
-  _in_col_stride = ldcol;
-}
-
-MEMBERFN(void)::set_output_matrices(void * const mptr, const int ldmatrix, const int ldrow)
-{
-  _outptr = static_cast<TOut *>(mptr);
-  _matrix_stride = ldmatrix;
-  _matrix_row_stride = ldrow;
-  _matrix_batch_stride = _tiles_M * _tiles_N * ldrow;
-}
-
-Nx1MEMBERFN()::InputTransform(
-  const int kernel_rows,
-  const int kernel_cols,
-  const int n_batches,
-  const int n_rows,
-  const int n_cols,
-  const int n_channels,
-  const int padding_top,
-  const int padding_left,
-  const int padding_bottom,
-  const int padding_right
-) : InputTransform<1, InnerTileRows, TIn, TOut, Roots>::InputTransform(
-      /* Transpose rows and columns */
-      kernel_cols, kernel_rows, n_batches, n_cols, n_rows, n_channels,
-      padding_left, padding_top, padding_right, padding_bottom
-    )
-{
-}
-
-Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr)
-{
-  set_input_tensor(inptr, this->_n_channels);
-}
-
-Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldcol)
-{
-  set_input_tensor(inptr, this->_n_cols * ldcol, ldcol);
-}
-
-Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldrow, const int ldcol)
-{
-  set_input_tensor(inptr, this->_n_rows * ldrow, ldrow, ldcol);
-}
-
-Nx1MEMBERFN(void)::set_input_tensor(const void* const inptr, const int ldbatch, const int ldrow, const int ldcol)
-{
-  // Transpose row and column strides
-  Base::set_input_tensor(inptr, ldbatch, ldcol, ldrow);
-}
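(Aside: the `Nx1` specialisations above reuse the `1xN` implementation by swapping row and column arguments, so only one kernel family has to be written. A minimal, self-contained sketch of that transpose-delegation pattern, with invented class names that are not the library's, is:)

    #include <cstdio>

    // A "1xN" operation parameterised by row and column strides.
    struct Transform1xN
    {
        void set_strides(int ldrow, int ldcol) { std::printf("row=%d col=%d\n", ldrow, ldcol); }
    };

    // The "Nx1" variant swaps the strides before delegating to the 1xN base,
    // so the same kernel walks the tensor transposed.
    struct TransformNx1 : Transform1xN
    {
        void set_strides(int ldrow, int ldcol) { Transform1xN::set_strides(ldcol, ldrow); }
    };

    int main()
    {
        TransformNx1 t;
        t.set_strides(224, 1); // prints "row=1 col=224": rows and columns exchanged
        return 0;
    }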
-
-MEMBERFN(size_t)::get_working_space_size(const unsigned int nthreads) const
-{
-  return sizeof(TIn) * InnerTileRows * _working_space_row_stride * nthreads;
-}
-
-MEMBERFN(void)::set_working_space(void * const buffer)
-{
-  _working_space = static_cast<TIn *>(buffer);
-}
-
-MEMBERFN(unsigned int)::get_window(void) const
-{
-  return iceildiv(_n_channels, WINDOW_BLOCK);
-}
-
-MEMBERFN(void)::run(
-  const unsigned int start,
-  const unsigned int stop,
-  const unsigned int threadid
-)
-{
-  // Determine the channels on which to work
-  if (start >= get_window())
-  {
-    return;  // No work to do beyond the end of the window
-  }
-  const unsigned int start_channel = start * WINDOW_BLOCK;
-  const unsigned int stop_channel = std::min<unsigned int>(_n_channels, stop * WINDOW_BLOCK);
-  const unsigned int n_channels = stop_channel - start_channel;
-
-  // Loop over batches
-  for (int batch = 0; batch < _n_batches; batch++)
-  {
-    const TIn* const inptr_batch = _inptr + start_channel + batch*_in_batch_stride;
-    TOut* const outptr_batch = _outptr + start_channel + batch*_matrix_batch_stride;
-
-    // Loop over rows of tiles
-    for (int tile_i = 0; tile_i < _tiles_M; tile_i++)
-    {
-      // Compute the starting and ending row of pixels within the row of tiles,
-      // hence compute the padding to apply to the top and bottom of each tile.
-      const int row_top = tile_i * (InnerTileRows - _overlap_rows) - _padding_top;
-      const int row_bottom = row_top + InnerTileRows;
-      const int row_pad_top = std::max(0, _padding_top - tile_i * (InnerTileRows - _overlap_rows));
-      const int row_pad_bottom = std::max(0, row_bottom - _n_rows);
-
-      // Get a pointer to the start of the row.
-      const int row_offset = std::min(0, row_pad_top - _padding_top);
-      const TIn* const inptr_row = inptr_batch + _in_row_stride*(row_offset + tile_i*(InnerTileRows - _overlap_rows));
-      TOut* const outptr_row = outptr_batch + tile_i*_tiles_N*_matrix_row_stride;
-
-      // Loop over tiles within the row
-      for (int tile_j = 0; tile_j < _tiles_N; tile_j++)
-      {
-        // Compute the starting and ending column of pixels within the tile,
-        // hence compute the padding to apply to the left and right of the
-        // tile.
-        const int tile_left = tile_j * (InnerTileCols - _overlap_cols) - _padding_left;
-        const int tile_right = tile_left + InnerTileCols;
-        const int tile_pad_left = std::max(0, _padding_left - tile_j * (InnerTileCols - _overlap_cols));
-        const int tile_pad_right = std::max(0, tile_right - _n_cols);
-
-        // Get a pointer to the start of the tile.
-        const int col_offset = std::min(0, tile_pad_left - _padding_left);
-        const TIn* const inptr_tile = inptr_row + _in_col_stride*(col_offset + tile_j*(InnerTileCols - _overlap_cols));
-        TOut* const outptr_tile = outptr_row + tile_j * _matrix_row_stride;
-
-        // Transform the tile, applying padding if necessary.
-        if (row_pad_top || tile_pad_left || row_pad_bottom || tile_pad_right)
-        {
-          transform_padded_tile(
-            threadid, n_channels, outptr_tile, inptr_tile,
-            row_pad_top, tile_pad_left, row_pad_bottom, tile_pad_right
-          );
-        }
-        else
-        {
-          transform_unpadded_tile(threadid, n_channels, outptr_tile, inptr_tile);
-        }
-      }
-    }
-  }
-}
-
-MEMBERFN(void)::transform_unpadded_tile(
-  const unsigned int /* threadid unused */,
-  const int n_channels,
-  TOut * const outptr,
-  const TIn * const inptr
-)
-{
-  transform_tile(
-    n_channels, inptr, _in_row_stride, _in_col_stride, outptr, _matrix_stride
-  );
-}
-
-MEMBERFN(void)::transform_padded_tile(
-  const unsigned int threadid,
-  const int n_channels,
-  TOut * const outptr,
-  const TIn * const inptr,
-  const int padding_top,
-  const int padding_left,
-  const int padding_bottom,
-  const int padding_right
-)
-{
-  padding::copy_and_pad_tile(
-    InnerTileRows, InnerTileCols, n_channels,
-    inptr, _in_row_stride, _in_col_stride,
-    static_cast<TIn *>(get_working_space(threadid)), _working_space_row_stride, _working_space_col_stride,
-    padding_top, padding_left, padding_bottom, padding_right
-  );
-
-  transform_tile(
-    n_channels, static_cast<const TIn *>(get_working_space(threadid)),
-    _working_space_row_stride, _working_space_col_stride,
-    outptr, _matrix_stride
-  );
-}
-
-MEMBERFN(void *)::get_working_space(const unsigned int threadid) const
-{
-  return _working_space + InnerTileRows * _working_space_row_stride * threadid;
-}
-
-}  // namespace winograd
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp
deleted file mode 100644
index 5e6ac97121..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/input_4x4_fp16_fp16_integers.cpp
+++ /dev/null
@@ -1,257 +0,0 @@
-/*
- * Copyright (c) 2020 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-#include "input.hpp"
-#include "arm.hpp"
-
-namespace winograd
-{
-
-template <>
-void InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>::transform_tile(
-  const int n_channels,
-  const __fp16* const input_base,
-  const int input_row_stride,
-  const int input_col_stride,
-  __fp16* outptr,
-  const int matrix_stride
-)
-{
-  constexpr int inner_tile_rows = 4, inner_tile_cols = 4;
-
-  // Get pointers into the input tile
-  const __fp16 *x_ptrs[inner_tile_rows][inner_tile_cols];
-  for (int i = 0, xi = 0; i < inner_tile_rows; i++, xi++)
-  {
-    // Get a pointer into the row
-    const __fp16* const row_ptr = input_base + xi*input_row_stride;
-
-    for (int j = 0, xj = 0; j < inner_tile_cols; j++, xj++)
-    {
-      x_ptrs[i][j] = row_ptr + xj*input_col_stride;
-    }
-  }
-
-  // Matrices used/computed in this kernel.
-  __fp16 x[inner_tile_rows][inner_tile_cols];
-  __fp16 XTx[inner_tile_rows][inner_tile_cols];
-  __fp16 U[inner_tile_rows][inner_tile_cols];
-
-  for (int i = 0; i < inner_tile_rows; i++)
-  {
-    for (int j = 0; j < inner_tile_cols; j++)
-    {
-      x[i][j] = XTx[i][j] = 0.0f;
-    }
-  }
-
-  // Perform the Winograd input transformation for each channel in the input
-  // tensor.
-  int channels_remaining = n_channels;
-#ifdef __aarch64__
-  for (; channels_remaining >= 8; channels_remaining -= 8)
-  {
-    // Matrices used/computed in this kernel.
-    float16x8_t x[inner_tile_rows][inner_tile_cols];
-    float16x8_t XTx[inner_tile_rows][inner_tile_cols];
-    float16x8_t U[inner_tile_rows][inner_tile_cols];
-
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        x[i][j] = vdupq_n_f16(0.0f);
-        XTx[i][j] = vdupq_n_f16(0.0f);
-      }
-    }
-
-    // Load x
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        x[i][j] = vld1q_f16(x_ptrs[i][j]);
-        x_ptrs[i][j] += 8;
-      }
-    }
-
-    // Compute XT . x
-    for (int j = 0; j < inner_tile_cols; j++)
-    {
-      // XTx[0][j] = x[0][j] - x[2][j];
-      XTx[0][j] = vsubq_f16(x[0][j], x[2][j]);
-
-      // XTx[1][j] = x[1][j] + x[2][j];
-      XTx[1][j] = vaddq_f16(x[1][j], x[2][j]);
-
-      // XTx[2][j] = x[2][j] - x[1][j];
-      XTx[2][j] = vsubq_f16(x[2][j], x[1][j]);
-
-      // XTx[3][j] = x[1][j] - x[3][j];
-      XTx[3][j] = vsubq_f16(x[1][j], x[3][j]);
-    }
-
-    // Compute U = XT . x . X
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      // U[i][0] = XTx[i][0] - XTx[i][2];
-      U[i][0] = vsubq_f16(XTx[i][0], XTx[i][2]);
-
-      // U[i][1] = XTx[i][1] + XTx[i][2];
-      U[i][1] = vaddq_f16(XTx[i][1], XTx[i][2]);
-
-      // U[i][2] = XTx[i][2] - XTx[i][1];
-      U[i][2] = vsubq_f16(XTx[i][2], XTx[i][1]);
-
-      // U[i][3] = XTx[i][1] - XTx[i][3];
-      U[i][3] = vsubq_f16(XTx[i][1], XTx[i][3]);
-    }
-
-    // Store the transformed matrix
-    for (int i = 0, m = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++, m++)
-      {
-        vst1q_f16(outptr + m*matrix_stride, U[i][j]);
-      }
-    }
-    outptr += 8;
-  }
-#endif  // __aarch64__
-#ifdef __arm_any__
-  for (; channels_remaining >= 4; channels_remaining -= 4)
-  {
-    // Matrices used/computed in this kernel.
-    float16x4_t x[inner_tile_rows][inner_tile_cols];
-    float16x4_t XTx[inner_tile_rows][inner_tile_cols];
-    float16x4_t U[inner_tile_rows][inner_tile_cols];
-
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        x[i][j] = vdup_n_f16(0.0f);
-        XTx[i][j] = vdup_n_f16(0.0f);
-      }
-    }
-
-    // Load x
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        x[i][j] = vld1_f16(x_ptrs[i][j]);
-        x_ptrs[i][j] += 4;
-      }
-    }
-
-    // Compute XT . x
-    for (int j = 0; j < inner_tile_cols; j++)
-    {
-      // XTx[0][j] = x[0][j] - x[2][j];
-      XTx[0][j] = vsub_f16(x[0][j], x[2][j]);
-
-      // XTx[1][j] = x[1][j] + x[2][j];
-      XTx[1][j] = vadd_f16(x[1][j], x[2][j]);
-
-      // XTx[2][j] = x[2][j] - x[1][j];
-      XTx[2][j] = vsub_f16(x[2][j], x[1][j]);
-
-      // XTx[3][j] = x[1][j] - x[3][j];
-      XTx[3][j] = vsub_f16(x[1][j], x[3][j]);
-    }
-
-    // Compute U = XT . x . X
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      // U[i][0] = XTx[i][0] - XTx[i][2];
-      U[i][0] = vsub_f16(XTx[i][0], XTx[i][2]);
-
-      // U[i][1] = XTx[i][1] + XTx[i][2];
-      U[i][1] = vadd_f16(XTx[i][1], XTx[i][2]);
-
-      // U[i][2] = XTx[i][2] - XTx[i][1];
-      U[i][2] = vsub_f16(XTx[i][2], XTx[i][1]);
-
-      // U[i][3] = XTx[i][1] - XTx[i][3];
-      U[i][3] = vsub_f16(XTx[i][1], XTx[i][3]);
-    }
-
-    // Store the transformed matrix
-    for (int i = 0, m = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++, m++)
-      {
-        vst1_f16(outptr + m*matrix_stride, U[i][j]);
-      }
-    }
-    outptr += 4;
-  }
-#endif  // __arm_any__
-  for (; channels_remaining; channels_remaining--)
-  {
-    // Load x
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        x[i][j] = *(x_ptrs[i][j]++);
-      }
-    }
-
-    // Compute XT . x
-    for (int j = 0; j < inner_tile_cols; j++)
-    {
-      XTx[0][j] = x[0][j] - x[2][j];
-      XTx[1][j] = x[1][j] + x[2][j];
-      XTx[2][j] = x[2][j] - x[1][j];
-      XTx[3][j] = x[1][j] - x[3][j];
-    }
-
-    // Compute U = XT . x . X
-    for (int i = 0; i < inner_tile_rows; i++)
-    {
-      U[i][0] = XTx[i][0] - XTx[i][2];
-      U[i][1] = XTx[i][1] + XTx[i][2];
-      U[i][2] = XTx[i][2] - XTx[i][1];
-      U[i][3] = XTx[i][1] - XTx[i][3];
-    }
-
-    // Store the transformed matrix
-    for (int i = 0, m = 0; i < inner_tile_rows; i++)
-    {
-      for (int j = 0; j < inner_tile_cols; j++, m++)
-      {
-        *(outptr + m*matrix_stride) = U[i][j];
-      }
-    }
-    outptr++;
-  }
-}
-
-template class InputTransform<4, 4, __fp16, __fp16, WinogradRoots::Integers>;
-
-}  // namespace
-#endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/kernel.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/kernel.hpp
deleted file mode 100644
index 27d20811d6..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/kernel.hpp
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (c) 2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-#include "winograd.hpp"
-using namespace winograd;
-
-#define MEMBERFN(RTYPE) template <\
-  int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols, typename TIn, typename TOut, WinogradRoots Roots\
-> RTYPE WeightTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, TIn, TOut, Roots>
-
-MEMBERFN()::WeightTransform(
-  const int n_output_channels,
-  const int n_input_channels
-) : _n_output_channels(n_output_channels), _n_input_channels(n_input_channels),
-    _matrices(nullptr), _matrix_stride(0), _matrix_row_stride(0), _weights(nullptr)
-{
-
-}
-
-MEMBERFN(void)::set_weight_tensor(const void * const weights)
-{
-  _weights = static_cast<const TIn *>(weights);
-}
-
-MEMBERFN(void)::set_output_matrices(void * const mptr, const int ldmatrix, const int ldrow)
-{
-  _matrices = static_cast<TOut *>(mptr);
-  _matrix_stride = ldmatrix;
-  _matrix_row_stride = ldrow;
-}
-
-MEMBERFN(size_t)::get_working_space_size(unsigned int) const
-{
-  return 0;
-}
-
-MEMBERFN(void)::set_working_space(void *)
-{
-}
-
-MEMBERFN(unsigned int)::get_window(void) const
-{
-  // TODO When the weights transform supports multithreading, return the number
-  // of output channels. For now we return 1 to indicate that the weights must
-  // be transformed as a single block.
-  // return n_output_channels;
-  return 1;
-}
-
-MEMBERFN(void)::run(const unsigned int, const unsigned int, unsigned int)
-{
-  execute(
-    _n_output_channels, _n_input_channels, _weights,
-    _matrices, _matrix_stride, _matrix_row_stride
-  );
-}
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
deleted file mode 100644
index c1fb559b1d..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/output.hpp
+++ /dev/null
@@ -1,252 +0,0 @@
-/*
- * Copyright (c) 2017-2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#pragma once
-
-#include <algorithm>
-#include "winograd.hpp"
-#include "padding.hpp"
-#include "utils.hpp"
-
-#define MEMBERFN(RTYPE) template<\
-  int KernelRows, int KernelCols, int InnerTileRows, int InnerTileCols,\
-  typename TIn, typename TOut, WinogradRoots Roots\
-> RTYPE OutputTransform<KernelRows, KernelCols, InnerTileRows, InnerTileCols, TIn, TOut, Roots>
-
-#define Nx1MEMBERFN(RTYPE) template<\
-  int KernelRows, int InnerTileRows, typename TIn, typename TOut, WinogradRoots Roots\
-> RTYPE OutputTransform<KernelRows, 1, InnerTileRows, 1, TIn, TOut, Roots>
-
-namespace winograd
-{
-
-MEMBERFN()
-::OutputTransform(const int n_batches, const int n_rows, const int n_cols,
-                  const int n_channels, const arm_gemm::Activation &activation)
-    : _n_batches(n_batches), _n_rows(n_rows), _n_cols(n_cols),
-      _n_channels(n_channels),
-      _output_min((activation.type == arm_gemm::Activation::Type::ReLU ||
-                   activation.type == arm_gemm::Activation::Type::BoundedReLU)
-                      ? static_cast<TOut>(0.0f) : TypeBounds<TOut>::lower()),
-      _output_max((activation.type == arm_gemm::Activation::Type::BoundedReLU)
-                      ? static_cast<TOut>(activation.param1) : TypeBounds<TOut>::upper()),
-      _matrix_base(nullptr), _biases(nullptr), _matrix_stride(0),
-      _matrix_row_stride(0), _matrix_batch_stride(0), _outptr(nullptr),
-      _tiles_M(iceildiv(n_rows, output_tile_rows)),
-      _tiles_N(iceildiv(n_cols, output_tile_cols)), _out_col_stride(0),
-      _out_row_stride(0), _out_batch_stride(0),
-      _working_space_col_stride(n_channels),
-      _working_space_row_stride(output_tile_cols * _working_space_col_stride),
-      _working_space(nullptr) {}
-
-MEMBERFN(void)::set_input_matrices(const void * const mptr, const int ldmatrix, const int ldrow)
-{
-  _matrix_base = static_cast<const TIn *>(mptr);
-  _matrix_stride = ldmatrix;
-  _matrix_row_stride = ldrow;
-  _matrix_batch_stride = _tiles_M * _tiles_N * ldrow;
-}
-
-MEMBERFN(void)::set_bias(const void * const bias)
-{
-  _biases = static_cast<const TOut *>(bias);
-}
-
-MEMBERFN(void)::set_output_tensor(void * const outptr)
-{
-  set_output_tensor(outptr, _n_channels);
-}
-
-MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldcol)
-{
-  set_output_tensor(outptr, _n_cols * ldcol, ldcol);
-}
-
-MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldrow, const int ldcol)
-{
-  set_output_tensor(outptr, _n_rows * ldrow, ldrow, ldcol);
-}
-
-MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldbatch, const int ldrow, const int ldcol)
-{
-  _outptr = static_cast<TOut *>(outptr);
-  _out_batch_stride = ldbatch;
-  _out_row_stride = ldrow;
-  _out_col_stride = ldcol;
-}
-
-Nx1MEMBERFN()::OutputTransform(
-  const int n_batches,
-  const int n_rows,
-  const int n_cols,
-  const int n_channels,
-  const arm_gemm::Activation &activation
-) : OutputTransform<1, KernelRows, 1, InnerTileRows, TIn, TOut, Roots>::OutputTransform(
-      n_batches, n_cols, n_rows, n_channels, activation  /* Transpose rows and columns */
-    )
-{
-}
-
-Nx1MEMBERFN(void)::set_output_tensor(void * const outptr)
-{
-  set_output_tensor(outptr, this->_n_channels);
-}
-
-Nx1MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldcol)
-{
-  set_output_tensor(outptr, this->_n_cols * ldcol, ldcol);
-}
-
-Nx1MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldrow, const int ldcol)
-{
-  set_output_tensor(outptr, this->_n_rows * ldrow, ldrow, ldcol);
-}
-
-Nx1MEMBERFN(void)::set_output_tensor(void * const outptr, const int ldbatch, const int ldrow, const int ldcol)
-{
-  // Transpose rows and columns
-  Base::set_output_tensor(outptr, ldbatch, ldcol, ldrow);
-}
-
-MEMBERFN(size_t)::get_working_space_size(const unsigned int nthreads) const
-{
-  return sizeof(TOut) * output_tile_rows * _working_space_row_stride * nthreads;
-}
-
-MEMBERFN(void)::set_working_space(void * const buffer)
-{
-  _working_space = static_cast<TOut *>(buffer);
-}
-
-MEMBERFN(unsigned int)::get_window(void) const
-{
-  return iceildiv(_n_channels, WINDOW_BLOCK);
-}
-
-MEMBERFN(void)::run(
-  const unsigned int start,
-  const unsigned int stop,
-  const unsigned int threadid
-)
-{
-  // Determine the channels on which to work
-  if (start >= get_window())
-  {
-    return;  // No work to do beyond the end of the window
-  }
-  const unsigned int start_channel = start * WINDOW_BLOCK;
-  const unsigned int stop_channel = std::min<unsigned int>(_n_channels, stop * WINDOW_BLOCK);
-  const unsigned int n_channels = stop_channel - start_channel;
-
-  const auto matrix_tile_col_stride = _matrix_row_stride;
-  const auto matrix_tile_row_stride = _tiles_N * matrix_tile_col_stride;
-
-  const TOut* const bptr = (_biases == nullptr) ? nullptr : _biases + start_channel;
-
-  // Loop over batches
-  for (int batch = 0; batch < _n_batches; batch++)
-  {
-    const TIn* const matrix_batch = _matrix_base + start_channel + batch * _matrix_batch_stride;
-    TOut* const outptr_batch = _outptr + start_channel + batch * _out_batch_stride;
-
-    for (int tile_i = 0; tile_i < _tiles_M; tile_i++)
-    {
-      // Compute properties of the row of output tiles
-      const int row_pad_bottom = std::max(0, (tile_i + 1)*output_tile_rows - _n_rows);
-      const TIn* const matrix_tile_row = matrix_batch + tile_i * matrix_tile_row_stride;
-      TOut* const outptr_row = outptr_batch + tile_i * output_tile_rows * _out_row_stride;
-
-      for (int tile_j = 0; tile_j < _tiles_N; tile_j++)
-      {
-        // Compute property of this specific tile
-        const int tile_pad_right = std::max(0, (tile_j + 1)*output_tile_cols - _n_cols);
-        const TIn* const matrix_tile = matrix_tile_row + tile_j * matrix_tile_col_stride;
-        TOut* const outptr_tile = outptr_row + tile_j * output_tile_cols * _out_col_stride;
-
-        // Perform the transformation
-        if (row_pad_bottom || tile_pad_right)
-        {
-          transform_cropped_tile(
-            threadid, n_channels, outptr_tile, matrix_tile, bptr,
-            row_pad_bottom, tile_pad_right
-          );
-        }
-        else
-        {
-          transform_uncropped_tile(
-            threadid, n_channels, outptr_tile, matrix_tile, bptr
-          );
-        }
-      }
-    }
-  }
-}
-
-MEMBERFN(void)::transform_uncropped_tile(
-  const unsigned int /* threadid unused */,
-  const int n_channels,
-  TOut * const outptr,
-  const TIn * const inptr,
-  const TOut * const biases
-)
-{
-  transform_tile(
-    n_channels, inptr, _matrix_stride, biases,
-    outptr, _out_row_stride, _out_col_stride,
-    _output_min, _output_max
-  );
-}
-
-MEMBERFN(void)::transform_cropped_tile(
-  const unsigned int threadid,
-  const int n_channels,
-  TOut * const outptr,
-  const TIn * const inptr,
-  const TOut * const biases,
-  const int pad_bottom,
-  const int pad_right
-)
-{
-  // Transform into working space and then copy the relevant section out.
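(Aside: a hedged illustration of the edge-tile arithmetic used by `run()` above, with made-up sizes: the final row of tiles is computed full-size into working space and only the valid rows are copied back out.)

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int n_rows = 7, output_tile_rows = 4; // invented example sizes
        const int tiles_M = (n_rows + output_tile_rows - 1) / output_tile_rows; // iceildiv -> 2
        for (int tile_i = 0; tile_i < tiles_M; tile_i++)
        {
            // Same expression as run(): rows the tile would write past the output.
            const int pad_bottom = std::max(0, (tile_i + 1) * output_tile_rows - n_rows);
            std::printf("tile %d: crop %d trailing row(s)\n", tile_i, pad_bottom); // 0, then 1
        }
        return 0;
    }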
-  TOut *wsptr = static_cast<TOut *>(get_working_space(threadid));
-  transform_tile(
-    n_channels, inptr, _matrix_stride, biases,
-    wsptr, _working_space_row_stride, _working_space_col_stride,
-    _output_min, _output_max
-  );
-
-  padding::crop_and_copy_tile(
-    output_tile_rows, output_tile_cols, n_channels,
-    wsptr, _working_space_row_stride, _working_space_col_stride,
-    outptr, _out_row_stride, _out_col_stride,
-    0u, 0u, pad_bottom, pad_right
-  );
-}
-
-MEMBERFN(void *)::get_working_space(const unsigned int threadid) const
-{
-  return _working_space + output_tile_rows * _working_space_row_stride * threadid;
-}
-
-}  // namespace winograd
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp
deleted file mode 100644
index 2ee377ceca..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2_7_fp32_fp32_integers.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*
- * Copyright (c) 2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm.hpp"
-#include "kernel.hpp"
-
-namespace winograd
-{
-
-template <>
-void WeightTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>::execute(
-  const int n_output_channels,
-  const int n_input_channels,
-  const float* const input,  // NOTE: Data in HWIO order
-  float* const output,
-  const int matrix_stride,
-  const int matrix_row_stride
-)
-{
-  // Get pointers to each cell of the weight tensor
-  const auto weight_col_stride = n_input_channels * n_output_channels;
-  const float *inptrs[kernel_cols];
-  for (int j = 0; j < kernel_cols; j++)
-  {
-    inptrs[j] = input + j*weight_col_stride;
-  }
-
-  // For each input channel
-  for (int ic = 0; ic < n_input_channels; ic++)
-  {
-    float *outptr = output + ic * matrix_row_stride;
-
-    // For each output channel
-    int channels_remaining = n_output_channels;
-    for (; channels_remaining; channels_remaining--)
-    {
-      // Matrices used and computed in this kernel
-      float w[kernel_cols], V[inner_tile_cols];
-
-      // Read weights
-      for (int j = 0; j < kernel_cols; j++)
-      {
-        w[j] = *(inptrs[j]++);
-      }
-
-      // Compute V = w WT
-      V[0] = (w[0]*-1) / 36.0f;
-      V[1] = (w[1]*-1 + w[3]*-1 + w[5]*-1 + w[0]*1 + w[2]*1 + w[4]*1 + w[6]*1) / 48.0f;
-      V[2] = (w[0]*1 + w[1]*1 + w[2]*1 + w[3]*1 + w[4]*1 + w[5]*1 + w[6]*1) / 48.0f;
-      V[3] = (w[0]*-1 + w[6]*-64 + w[4]*-16 + w[2]*-4 + w[1]*2 + w[3]*8 + w[5]*32) / 120.0f;
-      V[4] = (w[0]*-1 + w[6]*-64 + w[5]*-32 + w[4]*-16 + w[3]*-8 + w[2]*-4 + w[1]*-2) / 120.0f;
-      V[5] = (w[5]*-243 + w[3]*-27 + w[1]*-3 + w[2]*9 + w[4]*81 + w[6]*729 + w[0]*1) / 720.0f;
-      V[6] = (w[1]*3 + w[2]*9 + w[3]*27 + w[4]*81 + w[5]*243 + w[6]*729 + w[0]*1) / 720.0f;
-      V[7] = (w[6]*1) / 1.0f;
-
-      // Store the transformed weights
-      for (int j = 0; j < inner_tile_cols; j++)
-      {
-        *(outptr + j*matrix_stride) = V[j];
-      }
-      outptr++;
-    }
-  }
-}
-
-template class WeightTransform<1, 7, 1, 8, float, float, WinogradRoots::Integers>;
-template class WeightTransform<7, 1, 8, 1, float, float, WinogradRoots::Integers>;
-
-}  // namespace winograd
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp
deleted file mode 100644
index 3fde4a7a6b..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_3x3_fp32_fp32_integers.cpp
+++ /dev/null
@@ -1,220 +0,0 @@
-/*
- * Copyright (c) 2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm.hpp"
-#include "kernel.hpp"
-
-namespace winograd
-{
-
-template <>
-void WeightTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>::execute(
-  const int n_output_channels,
-  const int n_input_channels,
-  const float* const input,
-  float* const output,
-  const int matrix_stride,
-  const int matrix_row_stride
-)
-{
-  constexpr int inner_tile_i = 4;
-  constexpr int inner_tile_j = 4;
-
-  // Get pointers to each cell of the weight tensor
-  const auto weight_col_stride = n_input_channels * n_output_channels;
-  const auto weight_row_stride = 3 * weight_col_stride;
-  const float *inptrs[3][3];
-  for (int i = 0; i < 3; i++)
-  {
-    for (int j = 0; j < 3; j++)
-    {
-      inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride;
-    }
-  }
-
-  // For each input channel
-  for (int ic = 0; ic < n_input_channels; ic++)
-  {
-    float *outptr = output + ic * matrix_row_stride;
-
-    // For each output channel
-    int channels_remaining = n_output_channels;
-#ifdef __aarch64__
-    for (; channels_remaining >= 4; channels_remaining -= 4)
-    {
-      // Matrices used and computed in this kernel
-      float32x4_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
-
-      // Read weights
-      for (int i = 0; i < 3; i++)
-      {
-        for (int j = 0; j < 3; j++)
-        {
-          w[i][j] = vld1q_f32(inptrs[i][j]);
-          inptrs[i][j] += 4;
-        }
-      }
-
-      // Compute the matrix W w
-      for (int j = 0; j < 3; j++)
-      {
-        Ww[0][j] = w[0][j];
-
-        // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
-        Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
-
-        // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
-        Ww[2][j] = vmulq_n_f32(vaddq_f32(vsubq_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
-
-        Ww[3][j] = w[2][j];
-      }
-
-      // Compute V = W w WT
-      for (int i = 0; i < inner_tile_i; i++)
-      {
-        V[i][0] = Ww[i][0];
-
-        // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
-        V[i][1] = vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
-
-        // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
-        V[i][2] = vmulq_n_f32(vaddq_f32(vsubq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
-
-        V[i][3] = Ww[i][2];
-      }
-
-      // Store the transformed weights
-      for (int i = 0, m = 0; i < inner_tile_i; i++)
-      {
-        for (int j = 0; j < inner_tile_j; j++, m++)
-        {
-          vst1q_f32(outptr + m*matrix_stride, V[i][j]);
-        }
-      }
-      outptr += 4;
-    }
-#endif  // __aarch64__
-#ifdef __arm_any__
-    for (; channels_remaining >= 2; channels_remaining -= 2)
-    {
-      // Matrices used and computed in this kernel
-      float32x2_t w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
-
-      // Read weights
-      for (int i = 0; i < 3; i++)
-      {
-        for (int j = 0; j < 3; j++)
-        {
-          w[i][j] = vld1_f32(inptrs[i][j]);
-          inptrs[i][j] += 2;
-        }
-      }
-
-      // Compute the matrix W w
-      for (int j = 0; j < 3; j++)
-      {
-        Ww[0][j] = w[0][j];
-
-        // Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
-        Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
-
-        // Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
-        Ww[2][j] = vmul_n_f32(vadd_f32(vsub_f32(w[0][j], w[1][j]), w[2][j]), 0.5f);
-
-        Ww[3][j] = w[2][j];
-      }
-
-      // Compute V = W w WT
-      for (int i = 0; i < inner_tile_i; i++)
-      {
-        V[i][0] = Ww[i][0];
-
-        // V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
-        V[i][1] = vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
-
-        // V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
-        V[i][2] = vmul_n_f32(vadd_f32(vsub_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), 0.5f);
-
-        V[i][3] = Ww[i][2];
-      }
-
-      // Store the transformed weights
-      for (int i = 0, m = 0; i < inner_tile_i; i++)
-      {
-        for (int j = 0; j < inner_tile_j; j++, m++)
-        {
-          vst1_f32(outptr + m*matrix_stride, V[i][j]);
-        }
-      }
-      outptr += 2;
-    }
-#endif  // __arm_any__
-    for (; channels_remaining; channels_remaining--)
-    {
-      // Matrices used and computed in this kernel
-      float w[3][3], Ww[inner_tile_i][3], V[inner_tile_i][inner_tile_j];
-
-      // Read weights
-      for (int i = 0; i < 3; i++)
-      {
-        for (int j = 0; j < 3; j++)
-        {
-          w[i][j] = *(inptrs[i][j]++);
-        }
-      }
-
-      // Compute the matrix W w
-      for (int j = 0; j < 3; j++)
-      {
-        Ww[0][j] = w[0][j];
-        Ww[1][j] = 0.5*(w[0][j] + w[1][j] + w[2][j]);
-        Ww[2][j] = 0.5*(w[0][j] - w[1][j] + w[2][j]);
-        Ww[3][j] = w[2][j];
-      }
-
-      // Compute V = W w WT
-      for (int i = 0; i < inner_tile_i; i++)
-      {
-        V[i][0] = Ww[i][0];
-        V[i][1] = 0.5*(Ww[i][0] + Ww[i][1] + Ww[i][2]);
-        V[i][2] = 0.5*(Ww[i][0] - Ww[i][1] + Ww[i][2]);
-        V[i][3] = Ww[i][2];
-      }
-
-      // Store the transformed weights
-      for (int i = 0, m = 0; i < inner_tile_i; i++)
-      {
-        for (int j = 0; j < inner_tile_j; j++, m++)
-        {
-          *(outptr + m*matrix_stride) = V[i][j];
-        }
-      }
-      outptr++;
-    }
-  }
-}
-
-template class WeightTransform<3, 3, 4, 4, float, float, WinogradRoots::Integers>;
-
-}  // namespace
diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp
deleted file mode 100644
index 26ab56f24e..0000000000
--- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_2x2_5x5_fp32_fp32_integers.cpp
+++ /dev/null
@@ -1,401 +0,0 @@
-/*
- * Copyright (c) 2019 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */ - -#include "arm.hpp" -#include "kernel.hpp" - -namespace winograd -{ - -template <> -void WeightTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>::execute( - const int n_output_channels, - const int n_input_channels, - const float* const input, - float* const output, - const int matrix_stride, - const int matrix_row_stride -) -{ - // Get pointers to each cell of the weight tensor - const auto weight_col_stride = n_input_channels * n_output_channels; - const auto weight_row_stride = 5 * weight_col_stride; - const float *inptrs[5][5]; - for (int i = 0; i < 5; i++) - { - for (int j = 0; j < 5; j++) - { - inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; - } - } - - // For each input channel - for (int ic = 0; ic < n_input_channels; ic++) - { - float *outptr = output + ic * matrix_row_stride; - - // For each output channel - int channels_remaining = n_output_channels; -#ifdef __aarch64__ - for (; channels_remaining >= 4; channels_remaining -= 4) - { - // Matrices used and computed in this kernel - float32x4_t w[5][5], Ww[6][5], V[6][6]; - - // Read weights - for (int i = 0; i < 5; i++) - { - for (int j = 0; j < 5; j++) - { - w[i][j] = vld1q_f32(inptrs[i][j]); - inptrs[i][j] += 4; - } - } - - // Compute the matrix W w - for (int j = 0; j < 5; j++) - { - // Ww[0][j] = w[0][j]/4.0f; - Ww[0][j] = vmulq_n_f32(w[0][j], 1.0f/4.0f); - - // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; - Ww[1][j] = vmulq_n_f32( - vaddq_f32( - vaddq_f32( - vaddq_f32(w[1][j], w[0][j]), - vaddq_f32(w[3][j], w[2][j]) - ), - w[4][j] - ), - -1.0f/6.0f - ); - - // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; - // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f; - Ww[2][j] = vmulq_n_f32( - vsubq_f32( - vaddq_f32( - vsubq_f32(w[1][j], w[0][j]), - vsubq_f32(w[3][j], w[2][j]) - ), - w[4][j] - ), - 1.0f/6.0f - ); - - // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; - Ww[3][j] = vmulq_n_f32( - vmlaq_n_f32( - vaddq_f32( - vaddq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)), - vaddq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) - ), - w[4][j], 2.0f - ), - 1.0f/3.0f - ); - - // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; - Ww[4][j] = vmulq_n_f32( - vmlaq_n_f32( - vaddq_f32( - vsubq_f32(vmulq_n_f32(w[0][j], 1.0f/8.0f), vmulq_n_f32(w[1][j], 1.0f/4.0f)), - vsubq_f32(vmulq_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) - ), - w[4][j], 2.0f - ), - 1.0f/3.0f - ); - - // Ww[5][j] = w[4][j]; - Ww[5][j] = w[4][j]; - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - // V[i][0] = Ww[i][0]/4.0f; - V[i][0] = vmulq_n_f32(Ww[i][0], 1.0f/4.0f); - - // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; - V[i][1] = vmulq_n_f32( - vaddq_f32( - vaddq_f32( - vaddq_f32(Ww[i][1], Ww[i][0]), - vaddq_f32(Ww[i][3], Ww[i][2]) - ), - Ww[i][4] - ), - -1.0f/6.0f - ); - - // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; - // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f; - V[i][2] = vmulq_n_f32( - vsubq_f32( - vaddq_f32( - vsubq_f32(Ww[i][1], Ww[i][0]), - vsubq_f32(Ww[i][3], Ww[i][2]) - ), - Ww[i][4] - ), - 1.0f/6.0f - ); - - // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][3] = vmulq_n_f32( - vmlaq_n_f32( - vaddq_f32( - vaddq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)), - vaddq_f32(vmulq_n_f32(Ww[i][2], 
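// Editor's note (added commentary, not original kernel code): this file evaluates the
// F(2x2, 5x5) weight transform V = G w G^T; each row of G below can be read off the
// scalar fallback at the end of the function:
//     G = [  1/4     0     0     0    0 ]
//         [ -1/6  -1/6  -1/6  -1/6 -1/6 ]
//         [ -1/6   1/6  -1/6   1/6 -1/6 ]
//         [ 1/24  1/12   1/6   1/3  2/3 ]
//         [ 1/24 -1/12   1/6  -1/3  2/3 ]
//         [    0     0     0     0    1 ]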
1.0f/2.0f), Ww[i][3]) - ), - Ww[i][4], 2.0f - ), - 1.0f/3.0f - ); - - // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][4] = vmulq_n_f32( - vmlaq_n_f32( - vaddq_f32( - vsubq_f32(vmulq_n_f32(Ww[i][0], 1.0f/8.0f), vmulq_n_f32(Ww[i][1], 1.0f/4.0f)), - vsubq_f32(vmulq_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) - ), - Ww[i][4], 2.0f - ), - 1.0f/3.0f - ); - - // V[i][5] = Ww[i][4]; - V[i][5] = Ww[i][4]; - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - vst1q_f32(outptr + m*matrix_stride, V[i][j]); - } - } - outptr += 4; - } -#endif // __aarch64__ -#ifdef __arm_any__ - for (; channels_remaining >= 2; channels_remaining -= 2) - { - // Matrices used and computed in this kernel - float32x2_t w[5][5], Ww[6][5], V[6][6]; - - // Read weights - for (int i = 0; i < 5; i++) - { - for (int j = 0; j < 5; j++) - { - w[i][j] = vld1_f32(inptrs[i][j]); - inptrs[i][j] += 2; - } - } - - // Compute the matrix W w - for (int j = 0; j < 5; j++) - { - // Ww[0][j] = w[0][j]/4.0f; - Ww[0][j] = vmul_n_f32(w[0][j], 1.0f/4.0f); - - // Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; - Ww[1][j] = vmul_n_f32( - vadd_f32( - vadd_f32( - vadd_f32(w[1][j], w[0][j]), - vadd_f32(w[3][j], w[2][j]) - ), - w[4][j] - ), - -1.0f/6.0f - ); - - // Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; - // Ww[2][j] = ((w[1][j] - w[0][j]) + (w[3][j] - w[2][j]) - w[4][j])/6.0f; - Ww[2][j] = vmul_n_f32( - vsub_f32( - vadd_f32( - vsub_f32(w[1][j], w[0][j]), - vsub_f32(w[3][j], w[2][j]) - ), - w[4][j] - ), - 1.0f/6.0f - ); - - // Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; - Ww[3][j] = vmul_n_f32( - vmla_n_f32( - vadd_f32( - vadd_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)), - vadd_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) - ), - w[4][j], 2.0f - ), - 1.0f/3.0f - ); - - // Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; - Ww[4][j] = vmul_n_f32( - vmla_n_f32( - vadd_f32( - vsub_f32(vmul_n_f32(w[0][j], 1.0f/8.0f), vmul_n_f32(w[1][j], 1.0f/4.0f)), - vsub_f32(vmul_n_f32(w[2][j], 1.0f/2.0f), w[3][j]) - ), - w[4][j], 2.0f - ), - 1.0f/3.0f - ); - - // Ww[5][j] = w[4][j]; - Ww[5][j] = w[4][j]; - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - // V[i][0] = Ww[i][0]/4.0f; - V[i][0] = vmul_n_f32(Ww[i][0], 1.0f/4.0f); - - // V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; - V[i][1] = vmul_n_f32( - vadd_f32( - vadd_f32( - vadd_f32(Ww[i][1], Ww[i][0]), - vadd_f32(Ww[i][3], Ww[i][2]) - ), - Ww[i][4] - ), - -1.0f/6.0f - ); - - // V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; - // V[i][2] = ((Ww[i][1] - Ww[i][0]) + (Ww[i][3] - Ww[i][2]) - Ww[i][4])/6.0f; - V[i][2] = vmul_n_f32( - vsub_f32( - vadd_f32( - vsub_f32(Ww[i][1], Ww[i][0]), - vsub_f32(Ww[i][3], Ww[i][2]) - ), - Ww[i][4] - ), - 1.0f/6.0f - ); - - // V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][3] = vmul_n_f32( - vmla_n_f32( - vadd_f32( - vadd_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)), - vadd_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) - ), - Ww[i][4], 2.0f - ), - 1.0f/3.0f - ); - - // V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][4] = vmul_n_f32( - vmla_n_f32( - vadd_f32( - vsub_f32(vmul_n_f32(Ww[i][0], 1.0f/8.0f), vmul_n_f32(Ww[i][1], 1.0f/4.0f)), - 
vsub_f32(vmul_n_f32(Ww[i][2], 1.0f/2.0f), Ww[i][3]) - ), - Ww[i][4], 2.0f - ), - 1.0f/3.0f - ); - - // V[i][5] = Ww[i][4]; - V[i][5] = Ww[i][4]; - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - vst1_f32(outptr + m*matrix_stride, V[i][j]); - } - } - outptr += 2; - } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) - { - // Matrices used and computed in this kernel - float w[5][5], Ww[6][5], V[6][6]; - - // Read weights - for (int i = 0; i < 5; i++) - { - for (int j = 0; j < 5; j++) - { - w[i][j] = *(inptrs[i][j]++); - } - } - - // Compute the matrix W w - for (int j = 0; j < 5; j++) - { - Ww[0][j] = w[0][j]/4.0f; - Ww[1][j] = -( w[0][j] + w[1][j] + w[2][j] + w[3][j] + w[4][j])/6.0f; - Ww[2][j] = +(-w[0][j] + w[1][j] - w[2][j] + w[3][j] - w[4][j])/6.0f; - Ww[3][j] = (w[0][j]/8.0f + w[1][j]/4.0f + w[2][j]/2.0f + w[3][j] + 2*w[4][j])/3.0f; - Ww[4][j] = (w[0][j]/8.0f - w[1][j]/4.0f + w[2][j]/2.0f - w[3][j] + 2*w[4][j])/3.0f; - Ww[5][j] = w[4][j]; - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - V[i][0] = Ww[i][0]/4.0f; - V[i][1] = -( Ww[i][0] + Ww[i][1] + Ww[i][2] + Ww[i][3] + Ww[i][4])/6.0f; - V[i][2] = +(-Ww[i][0] + Ww[i][1] - Ww[i][2] + Ww[i][3] - Ww[i][4])/6.0f; - V[i][3] = (Ww[i][0]/8.0f + Ww[i][1]/4.0f + Ww[i][2]/2.0f + Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][4] = (Ww[i][0]/8.0f - Ww[i][1]/4.0f + Ww[i][2]/2.0f - Ww[i][3] + 2*Ww[i][4])/3.0f; - V[i][5] = Ww[i][4]; - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - *(outptr + m*matrix_stride) = V[i][j]; - } - } - outptr++; - } - } -} - -template class WeightTransform<5, 5, 6, 6, float, float, WinogradRoots::Integers>; - -} // namespace winograd diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp deleted file mode 100644 index eeda274453..0000000000 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4_5_fp32_fp32_integers.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ - -#include "arm.hpp" -#include "kernel.hpp" - -namespace winograd -{ - -template <> -void WeightTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>::execute( - const int n_output_channels, - const int n_input_channels, - const float* const input, // NOTE: Data in HWIO order - float* const output, - const int matrix_stride, - const int matrix_row_stride -) -{ - // Get pointers to each cell of the weight tensor - const auto weight_col_stride = n_input_channels * n_output_channels; - const float *inptrs[kernel_cols]; - for (int j = 0; j < kernel_cols; j++) - { - inptrs[j] = input + j*weight_col_stride; - } - - // For each input channel - for (int ic = 0; ic < n_input_channels; ic++) - { - float *outptr = output + ic * matrix_row_stride; - - // For each output channel - int channels_remaining = n_output_channels; - for (; channels_remaining; channels_remaining--) - { - // Matrices used and computed in this kernel - float w[kernel_cols], V[inner_tile_cols]; - - // Read weights - for (int j = 0; j < kernel_cols; j++) - { - w[j] = *(inptrs[j]++); - } - - // Compute V = w WT - V[0] = (w[0]*-1) / 36; - V[1] = (w[1]*-1 + w[3]*-1 + w[0]*1 + w[2]*1 + w[4]*1) / 48; - V[2] = (w[0]*1 + w[1]*1 + w[2]*1 + w[3]*1 + w[4]*1) / 48; - V[3] = (w[0]*-1 + w[4]*-16 + w[2]*-4 + w[1]*2 + w[3]*8) / 120; - V[4] = (w[0]*-1 + w[4]*-16 + w[3]*-8 + w[2]*-4 + w[1]*-2) / 120; - V[5] = (w[3]*-27 + w[1]*-3 + w[2]*9 + w[4]*81 + w[0]*1) / 720; - V[6] = (w[1]*3 + w[2]*9 + w[3]*27 + w[4]*81 + w[0]*1) / 720; - V[7] = (w[4]*1) / 1; - - // Store the transformed weights - for (int j = 0; j < inner_tile_cols; j++) - { - *(outptr + j*matrix_stride) = V[j]; - } - outptr++; - } - } -} - -template class WeightTransform<1, 5, 1, 8, float, float, WinogradRoots::Integers>; -template class WeightTransform<5, 1, 8, 1, float, float, WinogradRoots::Integers>; - -} // namespace winograd diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp deleted file mode 100644 index 7c2c718bd5..0000000000 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_4x4_3x3_fp32_fp32_integers.cpp +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
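// Editor's note (added commentary, not original kernel code): the file below implements the
// F(4x4, 3x3) weight transform. It first forms Ww = G' w with the integer matrix
//     G' = [  6   0   0 ]
//          [ -4  -4  -4 ]
//          [ -4   4  -4 ]
//          [  1   2   4 ]
//          [  1  -2   4 ]
//          [  0   0  24 ]
// and then computes V = (G' w G'^T) / 576; since 576 = 24^2, this is the usual transform
// with G = G'/24, kept in integers for as long as possible.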
- */ - -#include "arm.hpp" -#include "kernel.hpp" - -namespace winograd -{ - -template <> -void WeightTransform<3, 3, 6, 6, float, float, WinogradRoots::Integers>::execute( - const int n_output_channels, - const int n_input_channels, - const float* const input, // NOTE: Data in HWIO order - float* const output, - const int matrix_stride, - const int matrix_row_stride -) -{ - // Get pointers to each cell of the weight tensor - const auto weight_col_stride = n_input_channels * n_output_channels; - const auto weight_row_stride = 3 * weight_col_stride; - const float *inptrs[3][3]; - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - inptrs[i][j] = input + i*weight_row_stride + j*weight_col_stride; - } - } - - // For each input channel - for (int ic = 0; ic < n_input_channels; ic++) - { - float *outptr = output + ic * matrix_row_stride; - - // For each output channel - int channels_remaining = n_output_channels; -#ifdef __aarch64__ - for (; channels_remaining >= 4; channels_remaining -= 4) - { - // Matrices used and computed in this kernel - float32x4_t w[3][3], Ww[6][3], V[6][6]; - - // Read weights - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - w[i][j] = vld1q_f32(inptrs[i][j]); - inptrs[i][j] += 4; - } - } - - // Compute the matrix W w - for (int j = 0; j < 3; j++) - { - // Ww[0][j] = 6*w[0][j]; - Ww[0][j] = vmulq_n_f32(w[0][j], 6.0); - - // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; - Ww[1][j] = vmulq_n_f32(vaddq_f32(vaddq_f32(w[0][j], w[1][j]), w[2][j]), -4.0); - - // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; - Ww[2][j] = vmulq_n_f32(vsubq_f32(vsubq_f32(w[1][j], w[0][j]), w[2][j]), 4.0); - - // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; - Ww[3][j] = vmlaq_n_f32(vmlaq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); - - // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; - Ww[4][j] = vmlaq_n_f32(vmlsq_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); - - // Ww[5][j] = 24*w[2][j]; - Ww[5][j] = vmulq_n_f32(w[2][j], 24.0f); - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - const float recip576 = 1.0f / 576.0f; - - // V[i][0] = 6*Ww[i][0]; - V[i][0] = vmulq_n_f32(vmulq_n_f32(Ww[i][0], 6.0), recip576); - - // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; - V[i][1] = vmulq_n_f32(vmulq_n_f32(vaddq_f32(vaddq_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); - - // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; - V[i][2] = vmulq_n_f32(vmulq_n_f32(vsubq_f32(vsubq_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); - - // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; - V[i][3] = vmulq_n_f32(vmlaq_n_f32(vmlaq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); - - // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; - V[i][4] = vmulq_n_f32(vmlaq_n_f32(vmlsq_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); - - // V[i][5] = 24*Ww[i][2]; - V[i][5] = vmulq_n_f32(vmulq_n_f32(Ww[i][2], 24.0f), recip576); - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - vst1q_f32(outptr + m*matrix_stride, V[i][j]); - } - } - outptr += 4; - } -#endif // __aarch64__ -#ifdef __arm_any__ - for (; channels_remaining >= 2; channels_remaining -= 2) - { - // Matrices used and computed in this kernel - float32x2_t w[3][3], Ww[6][3], V[6][6]; - - // Read weights - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - w[i][j] = vld1_f32(inptrs[i][j]); - inptrs[i][j] += 2; - } - } - - // Compute the matrix W w - for (int j = 0; j < 3; j++) - { - // 
Ww[0][j] = 6*w[0][j]; - Ww[0][j] = vmul_n_f32(w[0][j], 6.0); - - // Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; - Ww[1][j] = vmul_n_f32(vadd_f32(vadd_f32(w[0][j], w[1][j]), w[2][j]), -4.0); - - // Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; - Ww[2][j] = vmul_n_f32(vsub_f32(vsub_f32(w[1][j], w[0][j]), w[2][j]), 4.0); - - // Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; - Ww[3][j] = vmla_n_f32(vmla_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); - - // Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; - Ww[4][j] = vmla_n_f32(vmls_n_f32(w[0][j], w[1][j], 2.0f), w[2][j], 4.0f); - - // Ww[5][j] = 24*w[2][j]; - Ww[5][j] = vmul_n_f32(w[2][j], 24.0f); - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - const float recip576 = 1.0f / 576.0f; - - // V[i][0] = 6*Ww[i][0]; - V[i][0] = vmul_n_f32(vmul_n_f32(Ww[i][0], 6.0), recip576); - - // V[i][1] = -4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]; - V[i][1] = vmul_n_f32(vmul_n_f32(vadd_f32(vadd_f32(Ww[i][0], Ww[i][1]), Ww[i][2]), -4.0), recip576); - - // V[i][2] = -4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]; - V[i][2] = vmul_n_f32(vmul_n_f32(vsub_f32(vsub_f32(Ww[i][1], Ww[i][0]), Ww[i][2]), 4.0), recip576); - - // V[i][3] = 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]; - V[i][3] = vmul_n_f32(vmla_n_f32(vmla_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); - - // V[i][4] = 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]; - V[i][4] = vmul_n_f32(vmla_n_f32(vmls_n_f32(Ww[i][0], Ww[i][1], 2.0f), Ww[i][2], 4.0f), recip576); - - // V[i][5] = 24*Ww[i][2]; - V[i][5] = vmul_n_f32(vmul_n_f32(Ww[i][2], 24.0f), recip576); - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - vst1_f32(outptr + m*matrix_stride, V[i][j]); - } - } - outptr += 2; - } -#endif // __arm_any__ - for (; channels_remaining; channels_remaining--) - { - // Matrices used and computed in this kernel - float w[3][3], Ww[6][3], V[6][6]; - - // Read weights - for (int i = 0; i < 3; i++) - { - for (int j = 0; j < 3; j++) - { - w[i][j] = *(inptrs[i][j]++); - } - } - - // Compute the matrix W w - for (int j = 0; j < 3; j++) - { - Ww[0][j] = 6*w[0][j]; - Ww[1][j] = -4*w[0][j] + -4*w[1][j] + -4*w[2][j]; - Ww[2][j] = -4*w[0][j] + 4*w[1][j] + -4*w[2][j]; - Ww[3][j] = 1*w[0][j] + 2*w[1][j] + 4*w[2][j]; - Ww[4][j] = 1*w[0][j] + -2*w[1][j] + 4*w[2][j]; - Ww[5][j] = 24*w[2][j]; - } - - // Compute V = W w WT - for (int i = 0; i < 6; i++) - { - V[i][0] = ( 6*Ww[i][0]) / 576.0; - V[i][1] = (-4*Ww[i][0] + -4*Ww[i][1] + -4*Ww[i][2]) / 576.0; - V[i][2] = (-4*Ww[i][0] + 4*Ww[i][1] + -4*Ww[i][2]) / 576.0; - V[i][3] = ( 1*Ww[i][0] + 2*Ww[i][1] + 4*Ww[i][2]) / 576.0; - V[i][4] = ( 1*Ww[i][0] + -2*Ww[i][1] + 4*Ww[i][2]) / 576.0; - V[i][5] = (24*Ww[i][2]) / 576.0; - } - - // Store the transformed weights - for (int i = 0, m = 0; i < 6; i++) - { - for (int j = 0; j < 6; j++, m++) - { - *(outptr + m*matrix_stride) = V[i][j]; - } - } - outptr++; - } - } -} - -template class WeightTransform<3, 3, 6, 6, float, float, WinogradRoots::Integers>; - -} // namespace diff --git a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp b/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp deleted file mode 100644 index 9b42224eaf..0000000000 --- a/src/core/NEON/kernels/convolution/winograd/winograd_transforms/weights_6_3_fp32_fp32_integers.cpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2019 Arm Limited. 
- * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -#include "arm.hpp" -#include "kernel.hpp" - -namespace winograd -{ - -template <> -void WeightTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>::execute( - const int n_output_channels, - const int n_input_channels, - const float* const input, // NOTE: Data in HWIO order - float* const output, - const int matrix_stride, - const int matrix_row_stride -) -{ - // Get pointers to each cell of the weight tensor - const auto weight_col_stride = n_input_channels * n_output_channels; - const float *inptrs[3]; - for (int j = 0; j < 3; j++) - { - inptrs[j] = input + j*weight_col_stride; - } - - // For each input channel - for (int ic = 0; ic < n_input_channels; ic++) - { - float *outptr = output + ic * matrix_row_stride; - - // For each output channel - int channels_remaining = n_output_channels; - for (; channels_remaining; channels_remaining--) - { - // Matrices used and computed in this kernel - float w[3], V[inner_tile_cols]; - - // Read weights - for (int j = 0; j < 3; j++) - { - w[j] = *(inptrs[j]++); - } - - // Compute V = w WT - V[0] = (w[0]*-1) / 36.0f; - V[1] = (w[1]*-1 + w[0]*1 + w[2]*1) / 48.0f; - V[2] = (w[0]*1 + w[1]*1 + w[2]*1) / 48.0f; - V[3] = (w[0]*-1 + w[2]*-4 + w[1]*2) / 120.0f; - V[4] = (w[0]*-1 + w[2]*-4 + w[1]*-2) / 120.0f; - V[5] = (w[1]*-3 + w[2]*9 + w[0]*1) / 720.0f; - V[6] = (w[1]*3 + w[2]*9 + w[0]*1) / 720.0f; - V[7] = (w[2]*1) / 1; - - // Store the transformed weights - for (int j = 0; j < inner_tile_cols; j++) - { - *(outptr + j*matrix_stride) = V[j]; - } - outptr++; - } - } -} - -template class WeightTransform<1, 3, 1, 8, float, float, WinogradRoots::Integers>; -template class WeightTransform<3, 1, 8, 1, float, float, WinogradRoots::Integers>; - -} // namespace diff --git a/src/core/NEON/wrapper/intrinsics/cvt.h b/src/core/NEON/wrapper/intrinsics/cvt.h index 6e79a92bc2..e52e3dd0c4 100644 --- a/src/core/NEON/wrapper/intrinsics/cvt.h +++ b/src/core/NEON/wrapper/intrinsics/cvt.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020, 2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -72,7 +72,7 @@ vcvt(const float32x4_t &a) return vcvtq_s32_f32(a); } -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) /** Convert 2x128-bit floating point vectors into 1x128-bit bfloat16 vector * * @param[in] inptr Pointer to the input memory to load values from @@ -89,7 +89,7 @@ inline void vcvt_bf16_f32(const float *inptr, uint16_t *outptr) : [outptr] "r"(outptr) : "v0", "v1", "memory"); } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ } // namespace wrapper } // namespace arm_compute diff --git a/src/core/NEON/wrapper/intrinsics/svdup_n.h b/src/core/NEON/wrapper/intrinsics/svdup_n.h index b1aed97d9c..9c42c86db7 100644 --- a/src/core/NEON/wrapper/intrinsics/svdup_n.h +++ b/src/core/NEON/wrapper/intrinsics/svdup_n.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020 Arm Limited. + * Copyright (c) 2020, 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -46,7 +46,9 @@ SVDUP_N_IMPL(uint64_t, svuint64_t, u64) SVDUP_N_IMPL(float16_t, svfloat16_t, f16) SVDUP_N_IMPL(float, svfloat32_t, f32) SVDUP_N_IMPL(float64_t, svfloat64_t, f64) +#if __ARM_FEATURE_SVE_BF16 SVDUP_N_IMPL(bfloat16_t, svbfloat16_t, bf16) +#endif // #if __ARM_FEATURE_SVE_BF16 #undef SVDUP_N_IMPL @@ -54,4 +56,4 @@ SVDUP_N_IMPL(bfloat16_t, svbfloat16_t, bf16) } // namespace arm_compute #endif /* defined(__ARM_FEATURE_SVE) */ -#endif /* SRC_CORE_NEON_WRAPPER_INTRINSICS_SVDUP_N_H */ \ No newline at end of file +#endif /* SRC_CORE_NEON_WRAPPER_INTRINSICS_SVDUP_N_H */ diff --git a/src/core/NEON/wrapper/svtraits.h b/src/core/NEON/wrapper/svtraits.h index 1d599a246c..5ccd0ba8f1 100644 --- a/src/core/NEON/wrapper/svtraits.h +++ b/src/core/NEON/wrapper/svtraits.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. 
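// Editor's sketch (added commentary, not part of the patch): the guard rewrites in these
// wrapper headers replace compiler-probed gating with the single build-system macro, so a
// caller now needs only, e.g.:
//
//     #if defined(ARM_COMPUTE_ENABLE_BF16)
//     arm_compute::wrapper::vcvt_bf16_f32(inptr, outptr); // declared in cvt.h above
//     #endif
//
// while the SVE bf16 specialisations stay additionally fenced by __ARM_FEATURE_SVE_BF16.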
* * SPDX-License-Identifier: MIT * @@ -59,7 +59,10 @@ DEFINE_TYPES(uint64_t) DEFINE_TYPES(float16_t) DEFINE_TYPES(float32_t) DEFINE_TYPES(float64_t) + +#if __ARM_FEATURE_SVE_BF16 DEFINE_TYPES(bfloat16_t) +#endif // #if __ARM_FEATURE_SVE_BF16 #undef DEFINE_TYPES diff --git a/src/core/common/Registrars.h b/src/core/common/Registrars.h index cc76de2be5..42c1aaa9fa 100644 --- a/src/core/common/Registrars.h +++ b/src/core/common/Registrars.h @@ -167,10 +167,10 @@ #define REGISTER_INTEGER_SVE2(func_name) nullptr #endif /* defined(ENABLE_INTEGER_KERNELS) */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) #define REGISTER_BF16_NEON(func_name) &(func_name) -#else /* !(defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16))*/ +#else /* !(defined(ARM_COMPUTE_ENABLE_BF16))*/ #define REGISTER_BF16_NEON(func_name) nullptr -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)*/ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16)*/ #endif /* SRC_CORE_COMMON_REGISTRARS_H */ diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp index 0be2ba02b5..9b6daae619 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.cpp @@ -55,18 +55,32 @@ Status add_tensor(ClKernelBlueprint &kernel_blueprint, ITensorInfo *tensor_info, return Status{}; } -Status add_kcomp_eltwise_add(ClKernelBlueprint &kernel_blueprint, const ClEltwiseAddKernelDescriptor &, - ArgumentID src0_id, ArgumentID src1_id, ArgumentID &dst_id) +Status add_kcomp_eltwise_op(ClKernelBlueprint &kernel_blueprint, const ClElementwiseKernelDescriptor &desc, + ArgumentID src0_id, ArgumentID src1_id, ArgumentID &dst_id) { kernel_blueprint.impl().add_component( - std::make_unique( + std::make_unique( &kernel_blueprint, + desc, SharedVarLink{ src0_id, SharedVarIO::Input }, SharedVarLink{ src1_id, SharedVarIO::Input }, SharedVarLink{ dst_id, SharedVarIO::Output })); return Status{}; } + +Status add_kcomp_floor(ClKernelBlueprint &kernel_blueprint, const ClFloorKernelDescriptor &, + ArgumentID src_id, ArgumentID &dst_id) +{ + kernel_blueprint.impl().add_component( + std::make_unique( + &kernel_blueprint, + SharedVarLink{ src_id, SharedVarIO::Input }, + SharedVarLink{ dst_id, SharedVarIO::Output })); + + return Status{}; +} + Status add_kcomp_activation(ClKernelBlueprint &, const ClActivationKernelDescriptor &, ArgumentID, ArgumentID &) { return Status{}; diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h index 067e9737e3..463fc5e7cf 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h @@ -59,9 +59,13 @@ class ClKernelBlueprint }; ///// Kernel Components ///// -/** Component: Eltwise Add */ -Status add_kcomp_eltwise_add(ClKernelBlueprint &, const ClEltwiseAddKernelDescriptor &, ArgumentID src0_id, - ArgumentID src1_id, ArgumentID &dst_id); +/** Component: Eltwise Operator */ +Status add_kcomp_eltwise_op(ClKernelBlueprint &, const ClElementwiseKernelDescriptor &, ArgumentID src0_id, + ArgumentID src1_id, ArgumentID &dst_id); + +/** Component: Floor */ +Status add_kcomp_floor(ClKernelBlueprint &, const ClFloorKernelDescriptor &, ArgumentID src_id, + ArgumentID &dst_id); /** Component: Activation */ Status 
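// Editor's sketch (not part of the patch; it mirrors ClFloorKernel::generate() later in
// this diff, so the call pattern below is taken from there): composing a floor component
// into a blueprint follows the same shape as the eltwise component:
//
//     ArgumentID src_id, dst_id;
//     add_tensor(bp, src->desc, src_id, src->id); // bp is a ClKernelBlueprint
//     add_tensor(bp, dst->desc, dst_id, dst->id);
//     add_kcomp_floor(bp, desc, src_id, dst_id);  // desc: a ClFloorKernelDescriptor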
add_kcomp_activation(ClKernelBlueprint &, const ClActivationKernelDescriptor &, ArgumentID src_id, ArgumentID &dst_id); diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h index 57ac70aa22..04919acb83 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h @@ -371,6 +371,7 @@ class IClKernelComponent { return Window{}; } + /** Get the tag look-up table used to instantiate the component code. * * @param vtable @@ -557,7 +558,7 @@ struct ClKernelBlueprint::Implementation std::string build_code() { - ARM_COMPUTE_ERROR_ON_MSG(_graph_root < 0, "No root found in the component graph"); + ARM_COMPUTE_ERROR_ON_MSG(_graph_root == -1, "No root found in the component graph"); // These data structures will hold the data from all the components in the blueprint std::set headers_list{}; @@ -666,9 +667,10 @@ struct ClKernelBlueprint::Implementation return _tile_info; } + // Get the global execution window, i.e. that of the root component Window get_execution_window() const { - ARM_COMPUTE_ERROR_ON_MSG(_graph_root < 0, "No root found in the component graph"); + ARM_COMPUTE_ERROR_ON_MSG(_graph_root == -1, "No root found in the component graph"); ARM_COMPUTE_ERROR_ON_MSG(_dst_id == -1, "Destination Tensor Id should be ready before calling get_execution_window()"); return _components.find(_graph_root)->second->get_window(); @@ -925,4 +927,4 @@ struct ClKernelBlueprint::Implementation } // namespace experimental } // namespace arm_compute #endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMMON_H -#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ \ No newline at end of file +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp index b63e2167b7..811cd79811 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.cpp @@ -237,7 +237,7 @@ std::string ClDirectConvolutionKernelComponent::get_component_code() const T_LOAD({{BIA_DATA_TYPE}}, 1, N0, BUFFER, {{bias}}, cout, 0, 1, 0, bias0); // c = c + bias[broadcasted] - T_ADD_BROADCAST_X({{ACC_DATA_TYPE}}, M0, N0, {{dst}}, bias0, {{dst}}); + T_ELTWISE_BROADCAST_ADD_X({{ACC_DATA_TYPE}}, M0, N0, {{dst}}, bias0, {{dst}}); )_"; } diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.cpp similarity index 60% rename from src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp rename to src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.cpp index 965a68f51d..7515aec27a 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.cpp @@ -23,7 +23,8 @@ */ #ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION -#include 
"src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h" +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h" +#include "arm_compute/core/Error.h" #include "arm_compute/core/Validate.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -34,17 +35,17 @@ namespace experimental { namespace dynamic_fusion { -ComponentType ClElementwiseAddKernelComponent::get_component_type() const +ComponentType ClElementwiseKernelComponent::get_component_type() const { return ComponentType::Simple; } -std::set ClElementwiseAddKernelComponent::get_headers_list() const +std::set ClElementwiseKernelComponent::get_headers_list() const { return std::set { "common/experimental/gemm_fused_post_ops/fp_mixed_precision_helpers.h", "tile_helpers.h" }; } -Window ClElementwiseAddKernelComponent::get_window() const +Window ClElementwiseKernelComponent::get_window() const { const ITensorInfo *lhs_info = _blueprint->impl().get_kernel_argument_info(_lhs.arg_id); const ITensorInfo *rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id); @@ -57,14 +58,18 @@ Window ClElementwiseAddKernelComponent::get_window() const auto_init_if_empty(*dst_info, out_shape, 1, lhs_info->data_type()); + TensorShape output_shape = dst_info->tensor_shape(); + // Collapse Dim 1 (W) and Dim 2 (H) together, leave Dim 0 (C) and upper dimensions unchanged + // This is in line with the collapsing convention used by Conv2d + output_shape.collapse(2U, 1U); const unsigned int vector_size_byte_opencl = 16; const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst_info->element_size(), dst_info->dimension(0)); - Window win = calculate_max_window(*dst_info, Steps(num_elems_processed_per_iteration)); + Window win = calculate_max_window(output_shape, Steps(num_elems_processed_per_iteration)); return win; } -std::string ClElementwiseAddKernelComponent::get_component_code() const +std::string ClElementwiseKernelComponent::get_component_code() const { std::string code; const bool is_root = _blueprint->impl().group(_lhs.arg_id) == SharedVarGroup::Argument && _blueprint->impl().group(_rhs.arg_id) == SharedVarGroup::Argument; @@ -72,7 +77,7 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const if(is_root) { return R"_( - //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_ADD --------------------- + //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- // IN_0(LHS) {{lhs}} // IN_1(RHS) {{rhs}} // OUT(dst, accum) {{dst}} @@ -83,23 +88,27 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const TILE({{DATA_TYPE}}, M0, N0, lhs_tile); TILE({{DATA_TYPE}}, M0, N0, rhs_tile); + // Since mout maps to dimensions 1 (y) and dimension 2 (z) of the input tensor because of the collapsed window, bout maps to dimension 3 (w) + {{lhs}}_offset_first_element_in_bytes += bout * {{lhs}}_stride_w; + {{rhs}}_offset_first_element_in_bytes += bout * {{rhs}}_stride_w; + T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{lhs}}, cout, mout, 1, {{lhs}}_stride_y, lhs_tile); - T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{rhs}}, cout, mout, 1, {{rhs}}_stride_y, rhs_tile); + T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{rhs}}, {{rhs_start_x}}, {{rhs_start_y}}, 1, {{rhs}}_stride_y, rhs_tile); #if defined(IS_BROADCAST) - T_ADD_BROADCAST_X({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}}); + 
T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}}); #else // !defined(IS_BROADCAST) - T_ADD({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}}); + T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, lhs_tile, rhs_tile, {{dst}}); #endif // defined(IS_BROADCAST) } - //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_ADD --------------------- + //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- )_"; } else { return R"_( - //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_ADD --------------------- + //------------------ START KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- // IN_0/Out(Accumulator) {{acc}} // IN_1(Addend) {{addend}} @@ -107,37 +116,39 @@ std::string ClElementwiseAddKernelComponent::get_component_code() const { TILE({{DATA_TYPE}}, M0, N0, addend_tile); - T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{addend}}, cout, mout, 1, {{addend}}_stride_y, addend_tile); + T_LOAD({{DATA_TYPE}}, {{rhs_m0}}, {{rhs_n0}}, BUFFER, {{addend}}, {{rhs_start_x}}, {{rhs_start_y}}, 1, {{addend}}_stride_y, addend_tile); #if defined(IS_BROADCAST) - T_ADD_BROADCAST_X({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}}); + T_ELTWISE_BROADCAST_{{ELTWISE_OP}}_X({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}}); #else // !defined(IS_BROADCAST) - T_ADD({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}}); + T_ELTWISE_{{ELTWISE_OP}}({{DATA_TYPE}}, M0, N0, {{acc}}, addend_tile, {{acc}}); #endif // defined(IS_BROADCAST) } - //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_ADD --------------------- + //------------------ END KERNEL {{meta_kernel_id}} ELTWISE_OP --------------------- )_"; } } -CLBuildOptions ClElementwiseAddKernelComponent::generate_build_options() const +CLBuildOptions ClElementwiseKernelComponent::generate_build_options() const { - const auto t_src_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id); + const auto t_rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id); const auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); - CLBuildOptions build_opts{}; - const auto n0 = _blueprint->impl().get_execution_window().x().step(); - const auto m0 = _blueprint->impl().get_execution_window().y().step(); - const bool is_broadcast = t_src_info->tensor_shape() != t_dst_info->tensor_shape(); + CLBuildOptions build_opts{}; + const auto n0 = _blueprint->impl().get_execution_window().x().step(); + const auto m0 = _blueprint->impl().get_execution_window().y().step(); + const unsigned int partial_store_n0 = t_dst_info->dimension(0) % n0; + const bool is_broadcast = t_rhs_info->tensor_shape() != t_dst_info->tensor_shape(); build_opts.add_option("-DM0=" + support::cpp11::to_string(m0)); build_opts.add_option("-DN0=" + support::cpp11::to_string(n0)); + build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); build_opts.add_option_if(is_broadcast, "-DIS_BROADCAST"); return build_opts; } -std::string ClElementwiseAddKernelComponent::generate_config_id() const +std::string ClElementwiseKernelComponent::generate_config_id() const { auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); std::string config_id{}; @@ -151,7 +162,7 @@ std::string ClElementwiseAddKernelComponent::generate_config_id() const return config_id; } -void ClElementwiseAddKernelComponent::allocate_shared_vars(SharedVarTable &vtable) const +void 
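// Editor's note (added commentary, not in the original patch): get_tag_lut() below selects
// the RHS tile geometry consumed by T_LOAD in the kernel template above:
//     no broadcast          -> rhs tile M0xN0, loaded from (cout, mout)
//     broadcast in Y and Z  -> rhs tile 1xN0,  loaded from (cout, 0)
//     broadcast in X, Y, Z  -> rhs tile 1x1,   loaded from (0, 0)
// Any other broadcast pattern is rejected with ARM_COMPUTE_ERROR.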
ClElementwiseKernelComponent::allocate_shared_vars(SharedVarTable &vtable) const { const bool is_root = _blueprint->impl().group(_lhs.arg_id) == SharedVarGroup::Argument && _blueprint->impl().group(_rhs.arg_id) == SharedVarGroup::Argument; vtable.add(_lhs, _blueprint->impl().group(_lhs.arg_id), ClKernelArgDescriptor(_lhs.arg_id, ClKernelTensorArgType::Tensor_4D_t_Buffer), "lhs"); @@ -162,10 +173,11 @@ void ClElementwiseAddKernelComponent::allocate_shared_vars(SharedVarTable &vtabl } } -ClElementwiseAddKernelComponent::TagLUT ClElementwiseAddKernelComponent::get_tag_lut(const SharedVarTable &vtable) const +ClElementwiseKernelComponent::TagLUT ClElementwiseKernelComponent::get_tag_lut(const SharedVarTable &vtable) const { TagLUT lut{}; const auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); + const auto t_rhs_info = _blueprint->impl().get_kernel_argument_info(_rhs.arg_id); // Arguments and global shared variables const bool is_root = _blueprint->impl().group(_lhs.arg_id) == SharedVarGroup::Argument && _blueprint->impl().group(_rhs.arg_id) == SharedVarGroup::Argument; if(is_root) @@ -199,6 +211,51 @@ ClElementwiseAddKernelComponent::TagLUT ClElementwiseAddKernelComponent::get_tag // Local build options lut["meta_kernel_id"] = id(); lut["DATA_TYPE"] = get_cl_type_from_data_type(t_dst_info->data_type()); + + switch(_desc.eltwise.op) + { + case ArithmeticOperation::DIV: + lut["ELTWISE_OP"] = "DIV"; + break; + case ArithmeticOperation::ADD: + lut["ELTWISE_OP"] = "ADD"; + break; + default: + ARM_COMPUTE_ERROR("Arithmetic Operation not supported"); + } + + // Set broadcast parameters + // PRE: All tensors are broadcast-compatible + const bool is_broadcast = t_rhs_info->tensor_shape() != t_dst_info->tensor_shape(); + if(is_broadcast) + { + // Note that n0 maps to input tensor dimension 0, m0 maps to input dimensions 1 and 2 because of our collapse strategy + if(t_rhs_info->dimension(0) == 1U && t_rhs_info->dimension(1) == 1U && t_rhs_info->dimension(2) == 1U) // Broadcast in X, Y, Z: collapsed rhs win [M0xN0] = [1x1] + { + lut["rhs_m0"] = "1"; + lut["rhs_n0"] = "1"; + lut["rhs_start_y"] = "0"; + lut["rhs_start_x"] = "0"; + } + else if(t_rhs_info->dimension(1) == 1U && t_rhs_info->dimension(2) == 1U) // Broadcast in Y and Z: collapsed rhs win [M0xN0] = [1xN] + { + lut["rhs_m0"] = "1"; + lut["rhs_n0"] = "N0"; + lut["rhs_start_y"] = "0"; + lut["rhs_start_x"] = "cout"; + } + else + { + ARM_COMPUTE_ERROR("Only support rhs broadcasting in all X, Y, Z dimensions, or just in Y and Z dimensions"); + } + } + else + { + lut["rhs_m0"] = "M0"; + lut["rhs_n0"] = "N0"; + lut["rhs_start_y"] = "mout"; + lut["rhs_start_x"] = "cout"; + } return lut; } } // namespace dynamic_fusion diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h new file mode 100644 index 0000000000..f8377457d3 --- /dev/null +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION + +#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H +#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H + +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h" + +namespace arm_compute +{ +namespace experimental +{ +namespace dynamic_fusion +{ +class ClElementwiseKernelComponent : public IClKernelComponent +{ +public: + /** Construct a new Cl Elementwise Kernel Component object + * + * @param[in] blueprint Blueprint to which this component is added + * @param[in] desc Component descriptor + * @param[in] lhs Link to LHS tensor + * @param[in] rhs Link to RHS tensor + * @param[out] dst Link to DST tensor + * + * Support Level + * Data Type: F16, F32 + * Tensor Shape: Any shape of arbitrary dimension >= 1 and <= 4 + * Value Range: All + * Broadcasting: Only RHS tensor can be broadcasted into LHS. 
Broadcasting is supported either in dimensions 1 and 2 only, or in all of dimensions 0, 1 and 2 + */ + ClElementwiseKernelComponent(ClKernelBlueprint *blueprint, const ClElementwiseKernelDescriptor &desc, const Link &lhs, const Link &rhs, const Link &dst) + : IClKernelComponent(blueprint), _desc{ desc }, _lhs{ lhs }, _rhs{ rhs }, _dst{ dst } + { + } + + ComponentType get_component_type() const override; + std::set get_headers_list() const override; + std::string get_component_code() const override; + Window get_window() const override; + CLBuildOptions generate_build_options() const override; + std::string generate_config_id() const override; + + virtual std::vector get_links() const override + { + return { _lhs, _rhs, _dst }; + } + + virtual TagLUT get_tag_lut(const SharedVarTable &vtable) const override; + virtual void allocate_shared_vars(SharedVarTable &vtable) const override; + + virtual std::string name() const override + { + return "eltwise_add_" + std::to_string(id()); + } + +private: + ClElementwiseKernelDescriptor _desc{}; + Link _lhs{}; + Link _rhs{}; + Link _dst{}; +}; + +} // namespace dynamic_fusion +} // namespace experimental +} // namespace arm_compute +#endif // ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ \ No newline at end of file diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp new file mode 100644 index 0000000000..0a20a8f600 --- /dev/null +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.cpp @@ -0,0 +1,153 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.h" +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Validate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" + +namespace arm_compute +{ +namespace experimental +{ +namespace dynamic_fusion +{ +ComponentType ClFloorKernelComponent::get_component_type() const +{ + return ComponentType::Simple; +} +std::set ClFloorKernelComponent::get_headers_list() const +{ + return std::set { "common/experimental/gemm_fused_post_ops/fp_mixed_precision_helpers.h", "tile_helpers.h" }; +} +Window ClFloorKernelComponent::get_window() const +{ + const ITensorInfo *src_info = _blueprint->impl().get_kernel_argument_info(_src.arg_id); + ITensorInfo *dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src_info, dst_info); + auto_init_if_empty(*dst_info, src_info->tensor_shape(), 1, src_info->data_type()); + + TensorShape output_shape = dst_info->tensor_shape(); + // Collapse Dim 1 (W) and Dim 2 (H) together, leave Dim 0 (C) and upper dimensions unchanged + // This is in line with the collapsing convention used by Conv2d + output_shape.collapse(2U, 1U); + const unsigned int vector_size_byte_opencl = 16; + const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst_info->element_size(), dst_info->dimension(0)); + Window win = calculate_max_window(output_shape, Steps(num_elems_processed_per_iteration)); + + return win; +} +std::string ClFloorKernelComponent::get_component_code() const +{ + bool is_root = _blueprint->impl().group(_src.arg_id) == SharedVarGroup::Argument; + if(is_root) + { + return R"_( + //------------------ START KERNEL {{meta_kernel_id}} FLOOR --------------------- + // IN_0(src) {{src}} + // OUT(dst, accum) {{dst}} + TILE({{DATA_TYPE}}, M0, N0, {{dst}}); + { + TILE({{DATA_TYPE}}, M0, N0, src_tile); + + // Since mout maps to dimensions 1 (y) and dimension 2 (z) of the input tensor because of the collapsed window, bout maps to dimension 3 (w) + {{src}}_offset_first_element_in_bytes += bout * {{src}}_stride_w; + T_LOAD({{DATA_TYPE}}, M0, N0, BUFFER, {{src}}, cout, mout, 1, {{src}}_stride_y, src_tile); + + T_FLOOR({{DATA_TYPE}}, M0, N0, src_tile, {{dst}}); + } + //------------------ END KERNEL {{meta_kernel_id}} FLOOR --------------------- +)_"; + } + else + { + return R"_( + //------------------ START KERNEL {{meta_kernel_id}} FLOOR --------------------- + // IN_0/Out(Accumulator) {{acc}} + // output = floor(input) + { + T_FLOOR({{DATA_TYPE}}, M0, N0, {{acc}}, {{acc}}); + } + //------------------ END KERNEL {{meta_kernel_id}} FLOOR --------------------- +)_"; + } +} +CLBuildOptions ClFloorKernelComponent::generate_build_options() const +{ + CLBuildOptions build_opts{}; + const auto n0 = _blueprint->impl().get_execution_window().x().step(); + const auto m0 = _blueprint->impl().get_execution_window().y().step(); + const auto dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); + const unsigned int partial_store_n0 = dst_info->dimension(0) % n0; + build_opts.add_option("-DM0=" + support::cpp11::to_string(m0)); + build_opts.add_option("-DN0=" + support::cpp11::to_string(n0)); + build_opts.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); + return build_opts; +} +std::string ClFloorKernelComponent::generate_config_id() const 
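// Editor's sketch (example shapes assumed, not taken from the patch): with the collapse
// convention used in get_window() above, an NHWC dst of shape (C, W, H, N) = (32, 7, 7, 2)
// is windowed as (32, 49, 2): N0 steps along channels, M0 along the merged W*H axis, and
// PARTIAL_N0 = 32 % N0 is the channel tail covered by the -DPARTIAL_N0 build option.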
+{ + auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); + std::string config_id{}; + config_id += lower_string(string_from_data_type(t_dst_info->data_type())); + config_id += "_"; + config_id += support::cpp11::to_string(t_dst_info->dimension(0)); + config_id += "_"; + config_id += support::cpp11::to_string(t_dst_info->dimension(1)); + config_id += "_"; + config_id += lower_string(string_from_data_layout(t_dst_info->data_layout())); + return config_id; +} +void ClFloorKernelComponent::allocate_shared_vars(SharedVarTable &vtable) const +{ + vtable.add(_src, _blueprint->impl().group(_src.arg_id), ClKernelArgDescriptor(_src.arg_id, ClKernelTensorArgType::Tensor_4D_t_Buffer), "src"); + vtable.add(_dst, _blueprint->impl().group(_dst.arg_id), ClKernelArgDescriptor(_dst.arg_id, ClKernelTensorArgType::Tensor_4D_t_Buffer), "dst"); +} +ClFloorKernelComponent::TagLUT ClFloorKernelComponent::get_tag_lut(const SharedVarTable &vtable) const +{ + TagLUT lut{}; + const auto t_dst_info = _blueprint->impl().get_kernel_argument_info(_blueprint->impl().get_dst_id()); + // Arguments and global shared variables + const bool is_root = _blueprint->impl().group(_src.arg_id) == SharedVarGroup::Argument; + + if(is_root) + { + lut["src"] = vtable.get(_src); + lut["dst"] = vtable.get(_dst); + } + else + { + lut["acc"] = vtable.get(_src); + } + + lut["meta_kernel_id"] = id(); + lut["DATA_TYPE"] = get_cl_type_from_data_type(t_dst_info->data_type()); + return lut; +} +} // namespace dynamic_fusion +} // namespace experimental +} // namespace arm_compute +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ \ No newline at end of file diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.h similarity index 74% rename from src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h rename to src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.h index 5f8b1569ac..e791b36382 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.h @@ -23,8 +23,8 @@ */ #ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION -#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H -#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H +#ifndef ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLFLOORKERNELCOMPONENT_H +#define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLFLOORKERNELCOMPONENT_H #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/Common.h" @@ -34,11 +34,22 @@ namespace experimental { namespace dynamic_fusion { -class ClElementwiseAddKernelComponent : public IClKernelComponent +class ClFloorKernelComponent : public IClKernelComponent { public: - ClElementwiseAddKernelComponent(ClKernelBlueprint *blueprint, const Link &lhs, const Link &rhs, const Link &dst) - : IClKernelComponent(blueprint), _lhs{ lhs }, _rhs{ rhs }, _dst{ dst } + /** Construct a new Cl Floor Kernel Component object + * + * @param blueprint Blueprint to which this component is added + * @param src Link to SRC tensor + * @param dst Link to DST tensor + * + * Support Level + * Data Type: F16, F32 + * Tensor Shape: Any shape of arbitrary 
dimension >= 1 and <= 4 + * Value Range: All + */ + ClFloorKernelComponent(ClKernelBlueprint *blueprint, const Link &src, const Link &dst) + : IClKernelComponent(blueprint), _src{ src }, _dst{ dst } { } @@ -51,7 +62,7 @@ class ClElementwiseAddKernelComponent : public IClKernelComponent virtual std::vector get_links() const override { - return { _lhs, _rhs, _dst }; + return { _src, _dst }; } virtual TagLUT get_tag_lut(const SharedVarTable &vtable) const override; @@ -59,17 +70,16 @@ class ClElementwiseAddKernelComponent : public IClKernelComponent virtual std::string name() const override { - return "eltwise_add_" + std::to_string(id()); + return "floor_" + std::to_string(id()); } private: - Link _lhs{}; - Link _rhs{}; + Link _src{}; Link _dst{}; }; } // namespace dynamic_fusion } // namespace experimental } // namespace arm_compute -#endif // ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLELEMENTWISEADDKERNELCOMPONENT_H -#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ \ No newline at end of file +#endif // ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_COMPONENTS_CLFLOORKERNELCOMPONENT_H +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClKernelComponents.h b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClKernelComponents.h index 26e50523a9..3f99dd5553 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClKernelComponents.h +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClKernelComponents.h @@ -27,7 +27,8 @@ #define ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_CLKERNELCOMPONENTS_H #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClDirectConvolutionKernelComponent.h" -#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseAddKernelComponent.h" +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClElementwiseKernelComponent.h" +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClFloorKernelComponent.h" #include "src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.h" #endif //ARM_COMPUTE_EXPERIMENTAL_DYNAMICFUSION_IMPL_CLKERNELCOMPONENTS_H diff --git a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp index 4ac27e007f..7c805d5368 100644 --- a/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp +++ b/src/core/experimental/dynamic_fusion/ClKernelBuildingImpl/components/ClStoreKernelComponents.cpp @@ -108,6 +108,10 @@ std::string ClStoreIndirectWidthSelectKernelComponent::get_component_code() cons return R"_( //------------------ START KERNEL {{meta_kernel_id}} STORE --------------------- { + // This also follows NHWC layout + // cout maps to global_id(0), i.e. the Channel dimension + // mout maps to global_id(1), i.e. Height and Width (collapsed window) + // bout maps to global_id(2), i.e. N / Batch #define _IDST_WIDTH {{dst}}_w #define _IDST_HEIGHT {{dst}}_h TILE(uint, M0, 1, dst_indirect_y); diff --git a/src/core/experimental/dynamic_fusion/OperatorGraph.cpp b/src/core/experimental/dynamic_fusion/OperatorGraph.cpp index a335e5aada..bd88afdb47 100644 --- a/src/core/experimental/dynamic_fusion/OperatorGraph.cpp +++ b/src/core/experimental/dynamic_fusion/OperatorGraph.cpp @@ -198,7 
+198,7 @@ void force_conv2d_method(OperatorGraph &graph, Operator conv2d, ConvolutionMetho node->set_method(method); } -Operator add_op_elementwise_add(OperatorGraph &graph, const AddDescriptor &desc, OpTensor lhs, OpTensor rhs, OpTensor dst) +Operator add_op_elementwise_op(OperatorGraph &graph, const ElementwiseDescriptor &desc, OpTensor lhs, OpTensor rhs, OpTensor dst) { auto id = graph.impl()->graph.add_operator({ rhs.id(), lhs.id() }, { dst.id() }); check_dependency_graph_op_success(graph, id.first); @@ -224,7 +224,36 @@ Operator add_op_elementwise_add(OperatorGraph &graph, const AddDescriptor &desc, tensors.add_const_tensor(ACL_SRC_0, graph.impl()->tensors[lhs.id()].get()); tensors.add_const_tensor(ACL_SRC_1, graph.impl()->tensors[rhs.id()].get()); tensors.add_const_tensor(ACL_DST_0, graph.impl()->tensors[dst.id()].get()); - graph.impl()->add_node(id.second, desc, tensors); + graph.impl()->add_node(id.second, desc, tensors); + check_multiple_roots(graph); + + return op_node; +} + +Operator add_op_floor(OperatorGraph &graph, const FloorDescriptor &desc, OpTensor src, OpTensor dst) +{ + auto id = graph.impl()->graph.add_operator({ src.id() }, { dst.id() }); + check_dependency_graph_op_success(graph, id.first); + + Operator op_node(id.second); + + // Infer TensorInfo + auto node_src = graph.impl()->tensors[src.id()]->get_tensor_info(); + OpTensorContent *node_dst = graph.impl()->tensors[dst.id()].get(); + + if(node_dst->get_tensor_info()->total_size() == 0) + { + auto_init_if_empty(*(node_dst->get_tensor_info()), *node_src); + } + + // Check execution space + auto dst_info = node_dst->get_tensor_info(); + check_execution_shape(graph, *dst_info); + + ITensorDescPack tensors; + tensors.add_const_tensor(ACL_SRC_0, graph.impl()->tensors[src.id()].get()); + tensors.add_const_tensor(ACL_DST_0, graph.impl()->tensors[dst.id()].get()); + graph.impl()->add_node(id.second, desc, tensors); check_multiple_roots(graph); return op_node; diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h index a9ccf908f0..f10e97e3e9 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h @@ -42,14 +42,24 @@ struct ClDirectConv2dKernelDescriptor Conv2dDescriptor conv2d{}; }; -struct ClEltwiseAddKernelDescriptor +struct ClElementwiseKernelDescriptor { - friend bool operator==(const ClEltwiseAddKernelDescriptor &desc0, const ClEltwiseAddKernelDescriptor &desc1) + friend bool operator==(const ClElementwiseKernelDescriptor &desc0, const ClElementwiseKernelDescriptor &desc1) { - return desc0.add == desc1.add; + return desc0.eltwise == desc1.eltwise; } - AddDescriptor add{}; + ElementwiseDescriptor eltwise{}; }; + +struct ClFloorKernelDescriptor +{ + friend bool operator==(const ClFloorKernelDescriptor &desc0, const ClFloorKernelDescriptor &desc1) + { + return desc0.floor == desc1.floor; + } + FloorDescriptor floor{}; +}; + struct ClActivationKernelDescriptor { friend bool operator==(const ClActivationKernelDescriptor &, const ClActivationKernelDescriptor &) diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp index de58ce70ed..cab51a2ce6 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.cpp @@ -124,7 +124,7 @@ bool 
ClDirectConv2dKernel::operator==(const ClKernel &other) const return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; } -Status ClAddKernel::generate(ClKernelBlueprint &bp) const +Status ClElementwiseKernel::generate(ClKernelBlueprint &bp) const { const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -137,11 +137,11 @@ Status ClAddKernel::generate(ClKernelBlueprint &bp) const ArgumentID dst_id; add_tensor(bp, dst->desc, dst_id, dst->id); - add_kcomp_eltwise_add(bp, desc, lhs_id, rhs_id, dst_id); + add_kcomp_eltwise_op(bp, desc, lhs_id, rhs_id, dst_id); return Status{}; } -Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) +Status ClElementwiseKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst) { // 1. Check validity ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); @@ -186,9 +186,61 @@ Status ClAddKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, con return Status{}; } -bool ClAddKernel::operator==(const ClKernel &other) const +bool ClElementwiseKernel::operator==(const ClKernel &other) const { - const auto converted = *utils::cast::polymorphic_downcast(&other); + const auto converted = *utils::cast::polymorphic_downcast(&other); + return config() == other.config() && tensors() == other.tensors() && desc == converted.desc; +} + +Status ClFloorKernel::generate(ClKernelBlueprint &bp) const +{ + const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); + const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + ArgumentID src_id; + add_tensor(bp, src->desc, src_id, src->id); + ArgumentID dst_id; + add_tensor(bp, dst->desc, dst_id, dst->id); + + add_kcomp_floor(bp, desc, src_id, dst_id); + return Status{}; +} + +Status ClFloorKernel::validate(const ITensorInfo *src, const ITensorInfo *dst) +{ + // 1. Check validity + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst); + + // Matching data type + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + + // Matching data layout + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(src, dst); + + // All tensor infos are initialized + ARM_COMPUTE_RETURN_ERROR_ON(src->tensor_shape().total_size() == 0); + ARM_COMPUTE_RETURN_ERROR_ON(dst->tensor_shape().total_size() == 0); + + // Device requirements are met + ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); + + // dst shape is correct + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(src->tensor_shape(), dst->tensor_shape(), 0), "Wrong shape for dst"); + + // 2. 
Check support level
+
+    // Data type
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32, DataType::F16);
+
+    // Data layout
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_LAYOUT_NOT_IN(src, DataLayout::NHWC);
+
+    return Status{};
+}
+
+bool ClFloorKernel::operator==(const ClKernel &other) const
+{
+    const auto converted = *utils::cast::polymorphic_downcast(&other);
     return config() == other.config() && tensors() == other.tensors() && desc == converted.desc;
 }
 
@@ -202,6 +254,7 @@ std::vector traverse(const ClKernelGraph &graph)
     }
     return kernels;
 }
+
 std::vector traverse(ClKernelGraph &graph)
 {
     std::vector kernels;
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.h b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.h
index 54e01ea850..c3580cfaca 100644
--- a/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.h
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelGraph.h
@@ -139,16 +139,16 @@ struct ClDirectConv2dKernel : public ClKernel
     ClDirectConv2dKernelDescriptor desc{};
 };
 
-struct ClAddKernel : public ClKernel
+struct ClElementwiseKernel : public ClKernel
 {
 public:
     Complexity complexity() const override
     {
         return Complexity::Simple;
     }
-    ClAddKernel() = default;
-    ~ClAddKernel() override = default;
-    ClAddKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ClEltwiseAddKernelDescriptor &desc, const ITensorDescPack tensors)
+    ClElementwiseKernel() = default;
+    ~ClElementwiseKernel() override = default;
+    ClElementwiseKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ClElementwiseKernelDescriptor &desc, const ITensorDescPack tensors)
         : ClKernel{ graph, id, config, tensors }, desc{ desc }
     {
     }
@@ -156,7 +156,27 @@ struct ClAddKernel : public ClKernel
     bool operator==(const ClKernel &other) const override;
     Status generate(ClKernelBlueprint &bp) const override;
 
-    ClEltwiseAddKernelDescriptor desc{};
+    ClElementwiseKernelDescriptor desc{};
+};
+
+struct ClFloorKernel : public ClKernel
+{
+public:
+    Complexity complexity() const override
+    {
+        return Complexity::Simple;
+    }
+    ClFloorKernel() = default;
+    ~ClFloorKernel() override = default;
+    ClFloorKernel(const ClKernelGraph *graph, Id id, const ClKernelConfig &config, const ClFloorKernelDescriptor &desc, const ITensorDescPack tensors)
+        : ClKernel{ graph, id, config, tensors }, desc{ desc }
+    {
+    }
+    static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
+    bool operator==(const ClKernel &other) const override;
+    Status generate(ClKernelBlueprint &bp) const override;
+
+    ClFloorKernelDescriptor desc{};
 };
 
 struct ClKernelGraph
diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
index f971196729..274a2517bb 100644
--- a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
+++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.cpp
@@ -113,9 +113,14 @@ bool operator==(const Conv2dDescriptor &conv2d0, const Conv2dDescriptor &conv2d1
     return std::make_tuple(conv2d0.pad, conv2d0.stride, conv2d0.dilation) == std::make_tuple(conv2d1.pad, conv2d1.stride, conv2d1.dilation);
 }
 
-bool operator==(const AddDescriptor &, const AddDescriptor &)
+bool operator==(const ElementwiseDescriptor &ed0, const ElementwiseDescriptor &ed1)
 {
-    return std::make_tuple() == std::make_tuple(); // Currently two Add ops are always the same
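// Illustrative sketch (editorial, not part of this patch): with the op member
// compared, two elementwise descriptors are equal exactly when they perform the
// same arithmetic operation. The ArithmeticOperation enumerators are assumed
// from arm_compute/core/Types.h:
//
//   ElementwiseDescriptor add0{ ArithmeticOperation::ADD }, add1{ ArithmeticOperation::ADD };
//   ElementwiseDescriptor sub{ ArithmeticOperation::SUB };
//   assert(add0 == add1);
//   assert(!(add0 == sub));
+    return ed0.op == ed1.op; // Compare Arithmetic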
Operations of two ElementwiseDescriptor objects +} + +bool operator==(const FloorDescriptor &, const FloorDescriptor &) +{ + return std::make_tuple() == std::make_tuple(); // Currently two Floor ops are always the same } bool Conv2dContent::operator==(const OperatorContent &other) const @@ -124,9 +129,15 @@ bool Conv2dContent::operator==(const OperatorContent &other) const return desc == converted.desc; } -bool AddContent::operator==(const OperatorContent &other) const +bool ElementwiseContent::operator==(const OperatorContent &other) const +{ + const auto converted = *utils::cast::polymorphic_downcast(&other); + return desc == converted.desc; +} + +bool FloorContent::operator==(const OperatorContent &other) const { - const auto converted = *utils::cast::polymorphic_downcast(&other); + const auto converted = *utils::cast::polymorphic_downcast(&other); return desc == converted.desc; } @@ -311,7 +322,7 @@ Status Conv2dContent::translate_direct_conv2d(ClKernelGraph &kernel_graph) const return Status{}; } -Status AddContent::translate(ClKernelGraph &kernel_graph) const +Status ElementwiseContent::translate(ClKernelGraph &kernel_graph) const { const auto lhs = _tensors.get_const_tensor(TensorType::ACL_SRC_0); const auto rhs = _tensors.get_const_tensor(TensorType::ACL_SRC_1); @@ -338,16 +349,46 @@ Status AddContent::translate(ClKernelGraph &kernel_graph) const DependencyGraph::Id add_id; ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect }; - st = ClAddKernel::validate(lhs->desc, rhs->desc, dst->desc); + st = ClElementwiseKernel::validate(lhs->desc, rhs->desc, dst->desc); ARM_COMPUTE_RETURN_ON_ERROR(st); - st = kernel_graph.add_kernel(config, ClEltwiseAddKernelDescriptor{ desc }, tensors, add_id); + st = kernel_graph.add_kernel(config, ClElementwiseKernelDescriptor{ desc }, tensors, add_id); ARM_COMPUTE_RETURN_ON_ERROR(st); ARM_COMPUTE_UNUSED(add_id); return Status{}; } +Status FloorContent::translate(ClKernelGraph &kernel_graph) const +{ + const auto src = _tensors.get_const_tensor(TensorType::ACL_SRC_0); + const auto dst = _tensors.get_const_tensor(TensorType::ACL_DST_0); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); + + ITensorDescPack tensors; + + DependencyGraph::Id src_id; + auto st = add_kernel_tensor(kernel_graph, *_graph, *src, src_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); + tensors.add_const_tensor(ACL_SRC_0, kernel_graph.get_tensor(src_id)); + + DependencyGraph::Id dst_id; + st = add_kernel_tensor(kernel_graph, *_graph, *dst, dst_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); + tensors.add_const_tensor(ACL_DST_0, kernel_graph.get_tensor(dst_id)); + + DependencyGraph::Id add_id; + ClKernelConfig config{ UnitWorkloadStage{ UnitWorkloadStage::Stage::Run }, TileDescriptor{}, StoreType::TStoreIndirectWidthSelect }; + + st = ClFloorKernel::validate(src->desc, dst->desc); + ARM_COMPUTE_RETURN_ON_ERROR(st); + + st = kernel_graph.add_kernel(config, ClFloorKernelDescriptor{ desc }, tensors, add_id); + ARM_COMPUTE_RETURN_ON_ERROR(st); + + return Status{}; +} + std::vector traverse(const OperatorGraph::Implementation &graph) { std::vector ops; diff --git a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h index 2786d610e1..b303cdb9fc 100644 --- a/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h +++ b/src/core/experimental/dynamic_fusion/WorkloadImpl/OperatorGraphImpl.h @@ -157,19 +157,19 @@ struct Conv2dContent : public 
OperatorContent Status translate_direct_conv2d(ClKernelGraph &kernel_graph) const; }; -class AddContent : public OperatorContent +class ElementwiseContent : public OperatorContent { public: - AddContent() = default; - AddContent(const OperatorGraph::Implementation *graph, Id id, const AddDescriptor &desc, const ITensorDescPack &tensors) + ElementwiseContent() = default; + ElementwiseContent(const OperatorGraph::Implementation *graph, Id id, const ElementwiseDescriptor &desc, const ITensorDescPack &tensors) : OperatorContent(graph, id, tensors), desc(desc) { } - ~AddContent() = default; - AddContent(const AddContent &) = default; - AddContent &operator=(const AddContent &) = default; - AddContent(AddContent &&) = default; - AddContent &operator=(AddContent &&) = default; + ~ElementwiseContent() = default; + ElementwiseContent(const ElementwiseContent &) = default; + ElementwiseContent &operator=(const ElementwiseContent &) = default; + ElementwiseContent(ElementwiseContent &&) = default; + ElementwiseContent &operator=(ElementwiseContent &&) = default; bool operator==(const OperatorContent &other) const override; OperatorComplexity complexity() const override { @@ -178,7 +178,31 @@ class AddContent : public OperatorContent Status translate(ClKernelGraph &kernel_graph) const override; private: - AddDescriptor desc{}; + ElementwiseDescriptor desc{}; +}; + +class FloorContent : public OperatorContent +{ +public: + FloorContent() = default; + FloorContent(const OperatorGraph::Implementation *graph, Id id, const FloorDescriptor &desc, const ITensorDescPack &tensors) + : OperatorContent(graph, id, tensors), desc(desc) + { + } + ~FloorContent() = default; + FloorContent(const FloorContent &) = default; + FloorContent &operator=(const FloorContent &) = default; + FloorContent(FloorContent &&) = default; + FloorContent &operator=(FloorContent &&) = default; + bool operator==(const OperatorContent &other) const override; + OperatorComplexity complexity() const override + { + return OperatorComplexity::Simple; + } + Status translate(ClKernelGraph &kernel_graph) const override; + +private: + FloorDescriptor desc{}; }; struct OperatorGraph::Implementation diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp index 1e8a2a54c9..45e7ff78be 100644 --- a/src/core/utils/AssemblyUtils.cpp +++ b/src/core/utils/AssemblyUtils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. 
 *
 * SPDX-License-Identifier: MIT
 *
 *
@@ -66,5 +66,245 @@ arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_
                                                 pad_stride_info.pad_right(), pad_stride_info.pad_bottom() };
 }
+
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
+{
+    arm_gemm::WeightFormat gemm_weight_format;
+
+    switch(weight_format)
+    {
+        case arm_compute::WeightFormat::UNSPECIFIED:
+            gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+            break;
+        case arm_compute::WeightFormat::ANY:
+            gemm_weight_format = arm_gemm::WeightFormat::ANY;
+            break;
+        case arm_compute::WeightFormat::OHWI:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWI;
+            break;
+        case arm_compute::WeightFormat::OHWIo2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo2;
+            break;
+        case arm_compute::WeightFormat::OHWIo4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4;
+            break;
+        case arm_compute::WeightFormat::OHWIo8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8;
+            break;
+        case arm_compute::WeightFormat::OHWIo16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32;
+            break;
+        case arm_compute::WeightFormat::OHWIo64:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64;
+            break;
+        case arm_compute::WeightFormat::OHWIo128:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo128;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i2:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i2_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i2_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i4:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i4_bf16:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i4_bf16;
+            break;
+        case arm_compute::WeightFormat::OHWIo2i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo2i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo4i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo4i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo8i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo8i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo16i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo16i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo32i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo32i8;
+            break;
+        case arm_compute::WeightFormat::OHWIo64i8:
+            gemm_weight_format = arm_gemm::WeightFormat::OHWIo64i8;
+            break;
+        default:
+            gemm_weight_format = arm_gemm::WeightFormat::UNSPECIFIED;
+    }
+    return gemm_weight_format;
+}
+
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
+{
+    arm_compute::WeightFormat acl_weight_format;
+
+    switch(weight_format)
+    {
+        case arm_gemm::WeightFormat::UNSPECIFIED:
+            acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+            break;
+        case arm_gemm::WeightFormat::ANY:
+            acl_weight_format = arm_compute::WeightFormat::ANY;
+            break;
+        case arm_gemm::WeightFormat::OHWI:
+            acl_weight_format = arm_compute::WeightFormat::OHWI;
+            break;
+        case arm_gemm::WeightFormat::OHWIo2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64;
+            break;
+        case arm_gemm::WeightFormat::OHWIo128:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo128;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i2:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i2;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i2_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i2_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i4:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i4;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i4_bf16:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i4_bf16;
+            break;
+        case arm_gemm::WeightFormat::OHWIo2i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo2i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo4i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo4i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo8i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo8i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo16i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo16i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo32i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo32i8;
+            break;
+        case arm_gemm::WeightFormat::OHWIo64i8:
+            acl_weight_format = arm_compute::WeightFormat::OHWIo64i8;
+            break;
+        default:
+            acl_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
+    }
+    return acl_weight_format;
+}
 } // namespace assembly_utils
 } // namespace arm_compute
diff --git a/src/core/utils/AssemblyUtils.h b/src/core/utils/AssemblyUtils.h
index b1aee64d5d..7514175ed6 100644
--- a/src/core/utils/AssemblyUtils.h
+++ b/src/core/utils/AssemblyUtils.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -47,6 +47,22 @@ arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act);
 * @return Assembly padding values.
 */
 arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_info);
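// Illustrative sketch (editorial, not part of this patch): the two mappings
// declared below are inverses over every enumerator they cover, so a round trip
// through the assembly-side enum is lossless, while anything unknown collapses
// to UNSPECIFIED. The helper name round_trips is hypothetical:
//
//   bool round_trips(arm_compute::WeightFormat wf)
//   {
//       const arm_gemm::WeightFormat gemm_wf = assembly_utils::map_to_arm_gemm_weight_format(wf);
//       return assembly_utils::map_to_arm_compute_weight_format(gemm_wf) == wf;
//   }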
+
+/** Performs a mapping from Compute Library WeightFormat to the assembly WeightFormat enum
+ *
+ * @param[in] weight_format Compute Library WeightFormat enum value
+ *
+ * @return Assembly WeightFormat
+ */
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format);
+
+/** Performs a mapping from Assembly WeightFormat to the Compute Library WeightFormat enum
+ *
+ * @param[in] weight_format Assembly WeightFormat enum value
+ *
+ * @return Compute Library WeightFormat
+ */
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format);
 } // namespace assembly
 } // namespace arm_compute
 #endif /* UTILS_CORE_ASSEMBLY_UTILS_H */
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index 74148071ae..9eaf44af51 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -45,54 +45,61 @@ namespace
 {
 static const std::vector available_kernels =
 {
+#ifdef __aarch64__
+    { // Neon LUT implementation takes precedence
+        "neon_qu8_activation_lut",
+        [](const ActivationDataTypeISASelectorData & data) { return ActivationLayerInfo::is_lut_supported(data.f, data.dt); },
+        REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_activation_lut)
+    },
+#endif // __aarch64__
     {
         "sve2_qu8_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8 && data.isa.sve2; },
         REGISTER_QASYMM8_SVE2(arm_compute::cpu::sve2_qasymm8_activation)
     },
     {
         "sve2_qs8_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED && data.isa.sve2; },
         REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::sve2_qasymm8_signed_activation)
     },
     {
         "sve2_qs16_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16 && data.isa.sve2; },
         REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)
     },
     {
         "sve_fp16_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.sve && data.isa.fp16; },
         REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation)
     },
     {
         "sve_fp32_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32 && data.isa.sve; },
         REGISTER_FP32_SVE(arm_compute::cpu::sve_fp32_activation)
     },
     {
         "neon_fp16_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F16 && data.isa.fp16; },
         REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_activation)
     },
     {
         "neon_fp32_activation",
-        [](const DataTypeISASelectorData & data) { return data.dt == DataType::F32; },
+        [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::F32; },
         REGISTER_FP32_NEON(arm_compute::cpu::neon_fp32_activation)
     },
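// Illustrative sketch (editorial, not part of this patch): kernel lookup walks
// this list in order and returns the first entry whose is_selected predicate
// accepts the query, so list position encodes priority -- which is why the
// aarch64 LUT kernel above precedes the generic QASYMM8 implementations.
// first_match is a hypothetical stand-in for the lookup performed by
// CpuActivationKernel::get_implementation():
//
//   const ActivationKernel *first_match(const std::vector<ActivationKernel> &kernels,
//                                       const ActivationDataTypeISASelectorData &data)
//   {
//       for(const auto &k : kernels)
//       {
//           if(k.is_selected(data))
//           {
//               return &k;
//           }
//       }
//       return nullptr;
//   }
    {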
"neon_qu8_activation", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8; }, REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_activation) }, { "neon_qs8_activation", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QASYMM8_SIGNED; }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::neon_qasymm8_signed_activation) }, { "neon_qs16_activation", - [](const DataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16; }, + [](const ActivationDataTypeISASelectorData & data) { return data.dt == DataType::QSYMM16; }, REGISTER_QSYMM16_NEON(arm_compute::cpu::neon_qsymm16_activation) }, }; @@ -122,7 +129,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, const ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::QSYMM16, DataType::F16, DataType::F32); - const auto *uk = CpuActivationKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + const auto *uk = CpuActivationKernel::get_implementation(ActivationDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), activation_info.activation() }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); @@ -176,14 +183,21 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac ARM_COMPUTE_ERROR_ON_NULLPTR(src); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, dst, activation_info)); - const auto uk = CpuActivationKernel::get_implementation(DataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa() }); + const auto uk = CpuActivationKernel::get_implementation(ActivationDataTypeISASelectorData{ src->data_type(), CPUInfo::get().get_isa(), activation_info.activation() }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); - _act_info = activation_info; _run_method = uk->ukernel; _name = std::string("CpuActivationKernel").append("/").append(uk->name); +#ifdef __aarch64__ + if(ActivationLayerInfo::is_lut_supported(activation_info.activation(), src->data_type())) + { + activation_info.init_lut(src->quantization_info().uniform(),(dst)?dst->quantization_info().uniform():src->quantization_info().uniform()); + } +#endif // __aarch64__ + _act_info = activation_info; + // Configure kernel window auto win_config = validate_and_configure_window(src, dst); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); diff --git a/src/cpu/kernels/CpuActivationKernel.h b/src/cpu/kernels/CpuActivationKernel.h index b0476303f0..d856a9357f 100644 --- a/src/cpu/kernels/CpuActivationKernel.h +++ b/src/cpu/kernels/CpuActivationKernel.h @@ -75,9 +75,9 @@ class CpuActivationKernel : public ICpuKernel struct ActivationKernel { - const char *name; - const DataTypeISASelectorPtr is_selected; - ActivationKernelPtr ukernel; + const char *name; + const ActivationDataTypeISASelectorDataPtr is_selected; + ActivationKernelPtr ukernel; }; static const std::vector &get_available_kernels(); diff --git a/src/cpu/kernels/CpuAddKernel.cpp b/src/cpu/kernels/CpuAddKernel.cpp index e756effea9..85ae410a94 100644 --- a/src/cpu/kernels/CpuAddKernel.cpp +++ b/src/cpu/kernels/CpuAddKernel.cpp @@ -39,82 +39,127 @@ namespace cpu { namespace kernels { +bool can_interpret_inputs_as_1d_array(const ITensorInfo &src0, const ITensorInfo 
&src1) +{ + return !src0.has_padding() && !src1.has_padding() && src0.tensor_shape() == src1.tensor_shape(); +} + namespace { static const std::vector available_kernels = { + { + "neon_fp32_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::F32) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon_as_1d_array) + }, + { + "neon_fp16_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::F16) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_FP16_NEON(arm_compute::cpu::add_fp16_neon_as_1d_array) + }, + { + "neon_u8_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::U8) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon_as_1d_array) + }, + { + "neon_s16_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::S16) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon_as_1d_array) + }, + { + "neon_s32_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::S32) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon_as_1d_array) + }, { "sve2_qu8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QASYMM8) && data.isa.sve2; + return (data.dt == DataType::QASYMM8) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QASYMM8_SVE2(arm_compute::cpu::add_qasymm8_sve2) }, { "sve2_qs8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2; + return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::add_qasymm8_signed_sve2) }, { "sve2_qs16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QSYMM16) && data.isa.sve2; + return (data.dt == DataType::QSYMM16) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QSYMM16_SVE2(arm_compute::cpu::add_qsymm16_sve2) }, { "sve_fp32_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::F32) && data.isa.sve; + return (data.dt == DataType::F32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_FP32_SVE(arm_compute::cpu::add_fp32_sve) }, { "sve_fp16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; + return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_FP16_SVE(arm_compute::cpu::add_fp16_sve) }, { "sve_u8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::U8) && data.isa.sve; + return (data.dt == DataType::U8) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_sve) }, { 
"sve_s16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::S16) && data.isa.sve; + return (data.dt == DataType::S16) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_sve) }, { "sve_s32_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::S32) && data.isa.sve; + return (data.dt == DataType::S32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_s32_sve) }, { "neon_fp32_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon) }, { "neon_fp16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; }, @@ -122,32 +167,32 @@ static const std::vector available_kernels = }, { "neon_u8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon) }, { "neon_s16_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon) }, { "neon_s32_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon) }, { "neon_qu8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon) }, { "neon_qs8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon) }, { "neon_qs16_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon) } }; @@ -177,7 +222,8 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons "Wrong shape for dst"); } - const auto *uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() }); + const auto uk = CpuAddKernel::get_implementation(CpuAddKernelDataTypeISASelectorData{ src0.data_type(), + CPUInfo::get().get_isa(), can_interpret_inputs_as_1d_array(src0, src1) }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); return Status{}; @@ -185,16 +231,25 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons std::pair validate_and_configure_window(const ITensorInfo &src0, const ITensorInfo &src1, ITensorInfo &dst) { - const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), 
src1.tensor_shape()); + if(can_interpret_inputs_as_1d_array(src0, src1)) + { + Window window; + window.set(0, Window::Dimension(0, src0.tensor_shape().total_size())); + return std::make_pair(Status{}, window); + } + else + { + const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); - // Auto initialize dst if not initialized - set_shape_if_empty(dst, out_shape); - set_data_type_if_unknown(dst, src0.data_type()); + // Auto initialize dst if not initialized + set_shape_if_empty(dst, out_shape); + set_data_type_if_unknown(dst, src0.data_type()); - Window win = calculate_max_window(out_shape, Steps()); + Window win = calculate_max_window(out_shape, Steps()); - // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped - return std::make_pair(Status{}, win); + // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped + return std::make_pair(Status{}, win); + } } } // namespace @@ -203,7 +258,9 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy)); - const auto uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() }); + _can_interpret_inputs_as_1d_array = can_interpret_inputs_as_1d_array(*src0, *src1); + const auto uk = CpuAddKernel::get_implementation(CpuAddKernelDataTypeISASelectorData{ src0->data_type(), + CPUInfo::get().get_isa(), _can_interpret_inputs_as_1d_array }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); diff --git a/src/cpu/kernels/CpuAddKernel.h b/src/cpu/kernels/CpuAddKernel.h index 6638135580..1afbc1a4d0 100644 --- a/src/cpu/kernels/CpuAddKernel.h +++ b/src/cpu/kernels/CpuAddKernel.h @@ -42,9 +42,9 @@ class CpuAddKernel : public ICpuKernel public: struct AddKernel { - const char *name; - const DataTypeISASelectorPtr is_selected; - AddKernelPtr ukernel; + const char *name; + const CpuAddKernelDataTypeISASelectorDataPtr is_selected; + AddKernelPtr ukernel; }; CpuAddKernel() = default; @@ -91,10 +91,16 @@ class CpuAddKernel : public ICpuKernel static const std::vector &get_available_kernels(); + bool get_can_interpret_inputs_as_1d_array() + { + return _can_interpret_inputs_as_1d_array; + } + private: ConvertPolicy _policy{}; AddKernelPtr _run_method{ nullptr }; std::string _name{}; + bool _can_interpret_inputs_as_1d_array{ false }; }; } // namespace kernels } // namespace cpu diff --git a/src/cpu/kernels/CpuIm2ColKernel.cpp b/src/cpu/kernels/CpuIm2ColKernel.cpp index 875d66594f..25ff6c291c 100644 --- a/src/cpu/kernels/CpuIm2ColKernel.cpp +++ b/src/cpu/kernels/CpuIm2ColKernel.cpp @@ -359,11 +359,11 @@ void CpuIm2ColKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const case DataType::F32: _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; break; -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: _func = (!conv_info.has_padding()) ? 
&CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; @@ -385,11 +385,11 @@ void CpuIm2ColKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const case DataType::F32: _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; break; -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: _func = (!conv_info.has_padding()) ? &CpuIm2ColKernel::run_im2col : &CpuIm2ColKernel::run_im2col; @@ -453,4 +453,4 @@ size_t CpuIm2ColKernel::get_mws(const CPUInfo &platform, size_t thread_count) co } } // namespace kernels } // namespace cpu -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h index afcf014ad2..19c41f9fcd 100644 --- a/src/cpu/kernels/CpuKernelSelectionTypes.h +++ b/src/cpu/kernels/CpuKernelSelectionTypes.h @@ -75,6 +75,21 @@ struct DepthwiseConv2dNativeDataTypeISASelectorData DataType source_dt; const cpuinfo::CpuIsaInfo &isa; }; + +struct ActivationDataTypeISASelectorData +{ + DataType dt; + const cpuinfo::CpuIsaInfo &isa; + ActivationLayerInfo::ActivationFunction f; +}; + +struct CpuAddKernelDataTypeISASelectorData +{ + DataType dt; + cpuinfo::CpuIsaInfo isa; + bool can_interpret_inputs_as_1d_array; +}; + // Selector pointer types using DataTypeISASelectorPtr = std::add_pointer::type; using DataTypeDataLayoutSelectorPtr = std::add_pointer::type; @@ -82,9 +97,11 @@ using PoolDataTypeISASelectorPtr = std::add_pointer::type; using DepthwiseConv2dNativeDataTypeISASelectorPtr = std::add_pointer::type; using CastDataTypeISASelectorDataPtr = std::add_pointer::type; +using ActivationDataTypeISASelectorDataPtr = std::add_pointer::type; +using CpuAddKernelDataTypeISASelectorDataPtr = std::add_pointer::type; } // namespace kernels } // namespace cpu } // namespace arm_compute -#endif // ARM_COMPUTE_CPU_KERNEL_SELECTION_TYPES_H \ No newline at end of file +#endif // ARM_COMPUTE_CPU_KERNEL_SELECTION_TYPES_H diff --git a/src/cpu/kernels/CpuWinogradConv2dKernel.cpp b/src/cpu/kernels/CpuWinogradConv2dKernel.cpp index 803af09a67..818d878119 100644 --- a/src/cpu/kernels/CpuWinogradConv2dKernel.cpp +++ b/src/cpu/kernels/CpuWinogradConv2dKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,531 +21,95 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
*/ -#include "src/cpu/kernels/CpuWinogradConv2dKernel.h" - -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Validate.h" -#include "arm_compute/core/Window.h" -#include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "src/core/NEON/kernels/convolution/common/utils.hpp" -#include "src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp" -#include "src/core/helpers/AutoConfiguration.h" -#include "src/core/helpers/WindowHelpers.h" -#include +#include "src/cpu/kernels/CpuWinogradConv2dKernel.h" namespace arm_compute { namespace cpu { -//Batched Gemms - -namespace -{ -inline bool is_kernel_size_supported(DataType data_type, Size2D size) +CpuWinogradConv2dTransformInputKernel::CpuWinogradConv2dTransformInputKernel(arm_conv::winograd::WinogradImpl &w_impl, arm_conv::ConvolutionArgs &_c_args, uint32_t nthreads) + : _winograd_impl{ w_impl }, _conv_args{ _c_args }, _nthreads{ nthreads } { - const std::array f32_support = { { Size2D(1, 3), Size2D(3, 1), Size2D(5, 5), Size2D(3, 3), Size2D(1, 5), Size2D(5, 1), Size2D(7, 1), Size2D(1, 7) } }; - const std::array f16_support = { { Size2D(3, 3) } }; - - switch(data_type) - { - case DataType::F16: - return std::end(f16_support) != std::find(std::begin(f16_support), std::end(f16_support), size); - case DataType::F32: - return std::end(f32_support) != std::find(std::begin(f32_support), std::end(f32_support), size); - default: - return false; - } } -Status validate_arguments_winograd_weight_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info) +void CpuWinogradConv2dTransformInputKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_UNUSED(window); + const ITensor *input_nhwc = tensors.get_const_tensor(TensorType::ACL_SRC); + const ITensor *winograd_input_transform = tensors.get_const_tensor(TensorType::ACL_DST); + const ITensor *workspace = tensors.get_const_tensor(TensorType::ACL_INT); - const size_t idx_width = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH); - const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); - const auto input_width = input->dimension(idx_width); - const auto input_height = input->dimension(idx_height); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(input_width, input_height)), - "Only 1x3, 3x1, 1x5, 5x1, 7x1, 1x7, 3x3 and 5x5 kernels are supported"); - ARM_COMPUTE_RETURN_ERROR_ON(input->num_dimensions() > 4); - const Size2D &output_tile = winograd_info.output_tile_size; - const std::array supported_tile_sizes = { { Size2D(2U, 2U), Size2D(4U, 4U), Size2D(1U, 6U), Size2D(6U, 1U), Size2D(4, 1), Size2D(1, 4), Size2D(2, 1), Size2D(1, 2) } }; - ARM_COMPUTE_RETURN_ERROR_ON(std::end(supported_tile_sizes) == std::find(std::begin(supported_tile_sizes), std::end(supported_tile_sizes), output_tile)); + const unsigned int width_idx = 1; + const unsigned int height_idx = 2; + const unsigned int batch_idx = 3; + int element_size_in_bytes = input_nhwc->info()->element_size(); + const auto src_strides = input_nhwc->info()->strides_in_bytes(); - // Checks performed when output is configured - 
if(output->total_size() != 0) - { - const TensorInfo tensor_info_output = input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_filter_transform_shape(*input, winograd_info)); + const size_t input_row_stride = src_strides[height_idx] / element_size_in_bytes; + const size_t input_col_stride = src_strides[width_idx] / element_size_in_bytes; + const size_t input_batch_stride = src_strides[batch_idx] / element_size_in_bytes; + const auto input_nhwc_ptr = reinterpret_cast(input_nhwc->buffer() + input_nhwc->info()->offset_first_element_in_bytes()); + auto win_transf_ptr = reinterpret_cast(winograd_input_transform->buffer() + winograd_input_transform->info()->offset_first_element_in_bytes()); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - } - - return Status{}; + _winograd_impl.input_transform->execute( + _conv_args, + input_nhwc_ptr, + input_batch_stride, + input_row_stride, + input_col_stride, + win_transf_ptr, + _winograd_impl.winograd_spec, + workspace->buffer(), + info.thread_id, + _nthreads); } -std::pair validate_and_configure_window_winograd_weight_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info) +CpuWinogradConv2dTransformOutputKernel::CpuWinogradConv2dTransformOutputKernel(arm_conv::winograd::WinogradImpl &w_impl, arm_conv::ConvolutionArgs &_c_args, uint32_t nthreads) + : _winograd_impl{ w_impl }, _conv_args{ _c_args }, _nthreads{ nthreads } { - // Output tensor auto inizialitation if not yet initialized - auto_init_if_empty(*output, input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_filter_transform_shape(*input, winograd_info))); - const Window win = calculate_max_window(*input, Steps(), true /* skip border*/); - return std::make_pair(Status{}, win); } -Status validate_arguments_winograd_input_trans(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info) +// Inherited methods overridden: +void CpuWinogradConv2dTransformOutputKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { - const Size2D &kernel_dims = winograd_info.kernel_size; - const PadStrideInfo &conv_info = winograd_info.convolution_info; - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd input transform only supports unit strides"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)), - "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported"); - - // Validate configured output - if(output->total_size() != 0) - { - const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, winograd_info); + ARM_COMPUTE_UNUSED(window); + const ITensor *dst_nhwc = tensors.get_const_tensor(TensorType::ACL_DST); + const ITensor *winograd_output_transform = tensors.get_const_tensor(TensorType::ACL_SRC_0); + const ITensor *biases = tensors.get_const_tensor(TensorType::ACL_SRC_1); + const ITensor *workspace = tensors.get_tensor(TensorType::ACL_INT); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - } - - return 
Status{}; -} + const unsigned int width_idx = 1; + const unsigned int height_idx = 2; + const unsigned int batch_idx = 3; + const int element_size_in_bytes = dst_nhwc->info()->element_size(); + const auto dst_strides = dst_nhwc->info()->strides_in_bytes(); -std::pair validate_and_configure_window_winograd_input_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info) -{ - const TensorShape output_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*input, winograd_info); - // Output auto inizialitation if not yet initialized - auto_init_if_empty(*output, input->clone()->set_tensor_shape(output_shape)); - return std::make_pair(Status{}, calculate_max_window(*input, Steps(), true)); -} - -Status validate_arguments_winograd_output_trans(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info) -{ - const PadStrideInfo &conv_info = winograd_info.convolution_info; - const Size2D kernel_dims = winograd_info.kernel_size; - - // Number of tiles along the X and Y direction - const unsigned int num_tiles_x = std::ceil((winograd_info.input_dimensions.x() - (kernel_dims.width - 1) + conv_info.pad_left() + conv_info.pad_right()) / static_cast - (winograd_info.output_tile_size.width)); - const unsigned int num_tiles_y = std::ceil((winograd_info.input_dimensions.y() - (kernel_dims.height - 1) + conv_info.pad_top() + conv_info.pad_bottom()) / static_cast - (winograd_info.output_tile_size.height)); - const Size2D num_tiles = Size2D(num_tiles_x, num_tiles_y); - - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != num_tiles.area()); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(kernel_dims.width, kernel_dims.height)), - "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported"); - - const std::array supported_gemm_sizes = { { 8U, 16U, 36U } }; - ARM_COMPUTE_RETURN_ERROR_ON(std::end(supported_gemm_sizes) == std::find(std::begin(supported_gemm_sizes), std::end(supported_gemm_sizes), input->dimension(2))); - ARM_COMPUTE_UNUSED(kernel_dims); - if(bias != nullptr) - { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias); - ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != bias->dimension(0)); - ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() != size_t(1)); - } - - // Checks performed when output is configured - if(output->total_size() != 0) + const size_t out_row_stride = dst_strides[height_idx] / element_size_in_bytes; + const size_t out_col_stride = dst_strides[width_idx] / element_size_in_bytes; + const size_t out_batch_stride = dst_strides[batch_idx] / element_size_in_bytes; + const auto wout_transf_ptr = reinterpret_cast(winograd_output_transform->buffer() + winograd_output_transform->info()->offset_first_element_in_bytes()); + auto dst_nhwc_ptr = reinterpret_cast(dst_nhwc->buffer() + dst_nhwc->info()->offset_first_element_in_bytes()); + void *biases_data_ptr = nullptr; + if(biases != nullptr) { - const TensorInfo tensor_info_output = input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_output_transform_shape(*input, winograd_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + biases_data_ptr = reinterpret_cast(biases->buffer() + 
biases->info()->offset_first_element_in_bytes()); } - return Status{}; -} - -std::pair validate_and_configure_window_winograd_output_trans(ITensorInfo *input, ITensorInfo *output, const WinogradInfo &winograd_info) -{ - // Output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, input->clone()->set_tensor_shape(arm_compute::misc::shape_calculator::compute_winograd_output_transform_shape(*input, winograd_info))); - - return std::make_pair(Status{}, calculate_max_window(*input, Steps(), true)); -} -} // namespace - -Status ICpuWinogradConv2dTransformWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *weights) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); - const DataLayout data_layout = input->data_layout(); - const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); - const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_kernel_size_supported(input->data_type(), Size2D(weights->dimension(width_idx), weights->dimension(height_idx))), - "Only 1x3, 3x1, 3x3 and 5x5 kernels are supported"); - ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); - return Status{}; -} - -template -unsigned int CpuWinogradConv2dTransformWeightsKernel::get_weight_storage_size(int num_output_channels, int num_input_channels) const -{ - const KernelShape shape(num_output_channels, KernelRows, KernelCols, num_input_channels); - // WinogradConv returns the size in bytes, we divide by `sizeof(T)` to express that in units of T - return static_cast(WinogradConv::get_kernel_storage_size(num_input_channels, num_output_channels) / sizeof(T)); -} - -template -CpuWinogradConv2dTransformWeightsKernel::CpuWinogradConv2dTransformWeightsKernel() - : _transform(nullptr), _num_output_channels(0), _matrix_stride(0) -{ -} - -template -int CpuWinogradConv2dTransformWeightsKernel::get_matrix_stride(int num_output_channels, int num_input_channels) const -{ - return WinogradConv::get_kernel_matrix_stride(num_input_channels, num_output_channels); -} - -#ifndef DOXYGEN_SKIP_THIS -template -void CpuWinogradConv2dTransformWeightsKernel::configure( - const ITensorInfo *weights_hwio, - ITensorInfo *output, - const int matrix_stride, /** Stride across matrices in the output. */ - const int num_output_channels, /** Number of filters. */ - const int num_input_channels) /** Number of channels in each filter. 
*/ -{ - ARM_COMPUTE_UNUSED(weights_hwio, output); - - _transform = std::make_unique(num_output_channels, num_input_channels); - _num_output_channels = num_output_channels; - _matrix_stride = matrix_stride; - - Window win; - auto win_last = _transform->get_window(); - win.set(Window::DimX, Window::Dimension(0, win_last, 1)); - ICpuKernel::configure(win); -} -#endif /* DOXYGEN_SKIP_THIS */ - -template -void CpuWinogradConv2dTransformWeightsKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON(tensors.empty()); - - const size_t fst = window.x().start(); - const size_t lst = window.x().end(); - - const ITensor *weights_hwio = tensors.get_const_tensor(TensorType::ACL_SRC); - ITensor *output = tensors.get_tensor(TensorType::ACL_DST); - - _transform->set_weight_tensor(weights_hwio->buffer()); - const int matrix_row_stride = roundup(_num_output_channels, WinogradConv::N_BLOCK); - _transform->set_output_matrices(output->buffer(), _matrix_stride, matrix_row_stride); - _transform->set_working_space(output->buffer()); - - _transform->run(fst, lst); -} - -template -bool CpuWinogradConv2dTransformWeightsKernel::is_parallelisable() const -{ - return false; -} - -template -Status CpuWinogradConv2dTransformWeightsKernel::validate(const ITensorInfo *input, const ITensorInfo *output, - const WinogradInfo &winograd_info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_weight_trans(input, output, winograd_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_weight_trans(input->clone().get(), output->clone().get(), winograd_info).first); - return Status{}; -} - -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; - -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; -template class CpuWinogradConv2dTransformWeightsKernel; - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template class CpuWinogradConv2dTransformWeightsKernel<__fp16, 4, 4, 3, 3>; -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -// Input transform - -template -unsigned int CpuWinogradConv2dTransformInputKernel::get_input_storage_size( - int num_batches, /* Number of batches in the input tensor. */ - int num_channels, /* Number of feature maps in the input tensor. */ - int num_rows, /* Number of rows in each feature map. */ - int num_cols, /* Number of columns in each feature map. */ - bool same_padding /* Use "SAME" padding, otherwise use "VALID". */ -) const -{ - // Construct shapes for the input and kernel tensors. 
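The removed configure()/run_op() pair above shows the windowing contract shared by all three of the old transform kernels: the transform object reports a flat work count through get_window(), configure() wraps it in a one-dimensional Window over DimX, and run_op() replays each thread's sub-range [window.x().start(), window.x().end()) into transform->run(). A minimal sketch of that contract, with FlatTransform as a hypothetical stand-in for the winograd transform types (only the get_window()/run() shape is assumed from the code above):

#include <cstddef>

struct FlatTransform
{
    unsigned int get_window() const { return 64; } // total number of work items
    void run(std::size_t start, std::size_t stop)  // process items in [start, stop)
    {
        for(std::size_t i = start; i < stop; ++i) { /* transform work item i */ }
    }
};

// configure() sizes a 1-D window from the transform's work count; at run time
// the scheduler hands each thread a sub-window and run_op() forwards it.
void run_threaded_slice(FlatTransform &t, std::size_t fst, std::size_t lst)
{
    t.run(fst, lst); // fst/lst come from window.x().start()/end()
}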
- const Tensor4DShape input_shape(num_batches, num_rows, num_cols, num_channels); - const KernelShape kern_shape(1, KernelRows, KernelCols, num_channels); - // Return the size, converted into units of TIn - return static_cast(WinogradConv::get_input_storage_size(num_batches, num_rows, num_cols, num_channels, same_padding) / sizeof(T)); -} - -template -unsigned int CpuWinogradConv2dTransformInputKernel::get_working_space_size(unsigned int num_threads) const -{ - return _transform->get_working_space_size(num_threads); -} - -template -int CpuWinogradConv2dTransformInputKernel::get_matrix_stride( - int num_batches, /* Number of batches in the input tensor. */ - int num_channels, /* Number of feature maps in the input tensor. */ - int num_rows, /* Number of rows in each feature map. */ - int num_cols, /* Number of columns in each feature map. */ - bool same_padding /* Use "SAME" padding, otherwise use "VALID". */) const -{ - return WinogradConv::get_input_matrix_stride(num_batches, num_rows, num_cols, num_channels, same_padding); -} - -template -CpuWinogradConv2dTransformInputKernel::CpuWinogradConv2dTransformInputKernel() - : _transform(nullptr), _num_channels(0), _matrix_stride(0) -{ -} - -template -void CpuWinogradConv2dTransformInputKernel::configure( - const ITensorInfo *input_nhwc, - const int num_batches, /* Number of batches in input tensor. */ - const int num_rows, /* Number of rows in input tensor. */ - const int num_cols, /* Number of columns in input tensor. */ - const int num_channels, /* Number of channels in input tensor. */ - const PaddingType padding, /* Padding type. */ - ITensorInfo *output, /* Base of output matrices. */ - const int matrix_stride, /* Stride between output matrices. */ - ITensorInfo *workspace) -{ - ARM_COMPUTE_UNUSED(input_nhwc, output, matrix_stride, workspace); - - _num_channels = num_channels; - _matrix_stride = matrix_stride; - - const int padding_top = (padding == PADDING_SAME) ? (KernelRows - 1) / 2 : 0; - const int padding_left = (padding == PADDING_SAME) ? (KernelCols - 1) / 2 : 0; - const int padding_bottom = (padding == PADDING_SAME) ? iceildiv(KernelRows - 1, 2) : 0; - const int padding_right = (padding == PADDING_SAME) ? iceildiv(KernelCols - 1, 2) : 0; - - _transform = std::make_unique( - KernelRows, - KernelCols, - num_batches, - num_rows, - num_cols, - num_channels, - padding_top, /**< Padding to apply to the top of the image. */ - padding_left, /**< Padding to apply to the left of the image. */ - padding_bottom, /**< Padding to apply to the bottom of the image. */ - padding_right /**< Padding to apply to the right of the image. 
*/ - ); - - Window win; - auto win_last = _transform->get_window(); - win.set(Window::DimX, Window::Dimension(0, win_last, 1)); - ICpuKernel::configure(win); -} - -template -void CpuWinogradConv2dTransformInputKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_UNUSED(info); - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON(tensors.empty()); - - const ITensor *input_nhwc = tensors.get_const_tensor(TensorType::ACL_SRC); - const ITensor *workspace = tensors.get_const_tensor(TensorType::ACL_INT); - ITensor *output = tensors.get_tensor(TensorType::ACL_DST); - - const int element_size_in_bytes = input_nhwc->info()->element_size(); - const int input_col_stride = input_nhwc->info()->strides_in_bytes().y() / element_size_in_bytes; - const int input_row_stride = input_nhwc->info()->strides_in_bytes().z() / element_size_in_bytes; - const int input_batch_stride = input_nhwc->info()->strides_in_bytes()[3] / element_size_in_bytes; - const auto input_nhwc_ptr = reinterpret_cast(input_nhwc->buffer() + input_nhwc->info()->offset_first_element_in_bytes()); - auto output_ptr = reinterpret_cast(output->buffer() + output->info()->offset_first_element_in_bytes()); - ARM_COMPUTE_ERROR_ON_NULLPTR(output_ptr); - - _transform->set_input_tensor(input_nhwc_ptr, input_batch_stride, input_row_stride, input_col_stride); - _transform->set_output_matrices(output_ptr, _matrix_stride, _num_channels); - - _transform->set_working_space(workspace->buffer()); - - // The code below cannot be moved to configure because biases hasn't been allocated at that point - const size_t fst = window.x().start(); - const size_t lst = window.x().end(); - _transform->run(fst, lst, info.thread_id); -} - -template -Status CpuWinogradConv2dTransformInputKernel::validate(const ITensorInfo *input, const ITensorInfo *output, - const WinogradInfo &winograd_info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_input_trans(input, output, winograd_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_input_trans(input->clone().get(), output->clone().get(), winograd_info).first); - - return Status{}; -} - -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; - -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; -template class CpuWinogradConv2dTransformInputKernel; - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template class CpuWinogradConv2dTransformInputKernel<__fp16, 4, 4, 3, 3>; -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -// Output transform - -template -unsigned int CpuWinogradConv2dTransformOutputKernel::get_output_storage_size( - int num_batches, /* Number of batches in the output tensor. */ - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - int num_output_channels /* Number of feature maps in the output tensor. */ -) const -{ - // Construct shapes for the input and kernel tensors. 
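Both the removed input-transform run_op() just below and the new output-transform code at the head of this file perform the same conversion before handing raw pointers to the arm_conv transforms: ITensorInfo stores strides in bytes, so each stride is divided by the element size to obtain element strides. A self-contained sketch of that conversion, assuming the NHWC index convention (W=1, H=2, N=3) used in the kernel; the structs are simplified stand-ins for ITensorInfo:

#include <cstddef>

struct NhwcByteStrides { std::size_t stride[4]; }; // as from strides_in_bytes()
struct NhwcElemStrides { std::size_t col, row, batch; };

NhwcElemStrides to_element_strides(const NhwcByteStrides &s, std::size_t element_size)
{
    return {
        s.stride[1] / element_size, // width  (column) stride
        s.stride[2] / element_size, // height (row)    stride
        s.stride[3] / element_size  // batch           stride
    };
}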
- const Tensor4DShape input_shape(num_batches, num_rows, num_cols, 1); - const KernelShape kern_shape(num_output_channels, KernelRows, KernelCols, 1); - // Return the size, converted into units of TOut - return static_cast( - WinogradConv::get_output_storage_size(num_batches, num_rows, num_cols, num_output_channels) / sizeof(T)); -} - -template -CpuWinogradConv2dTransformOutputKernel::CpuWinogradConv2dTransformOutputKernel() - : _transform(nullptr), _matrix_stride(0), _matrix_row_stride(0) -{ + // Output transform + _winograd_impl.output_transform->execute( + _conv_args, + wout_transf_ptr, + _winograd_impl.winograd_spec, + biases_data_ptr, + dst_nhwc_ptr, + out_batch_stride, + out_row_stride, + out_col_stride, + workspace->buffer(), + info.thread_id, + _nthreads); } -template -unsigned int CpuWinogradConv2dTransformOutputKernel::get_working_space_size(unsigned int num_threads) const -{ - return _transform->get_working_space_size(num_threads); -} - -template -int CpuWinogradConv2dTransformOutputKernel::get_matrix_stride( - int num_batches, /* Number of batches in the output tensor. */ - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - int num_output_channels /* Number of feature maps in the output tensor. */ -) const -{ - return WinogradConv::get_output_matrix_stride(num_batches, num_rows, num_cols, num_output_channels); -} - -template -std::pair CpuWinogradConv2dTransformOutputKernel::get_output_shape( - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - bool padding_same) const -{ - return WinogradConv::get_output_shape(std::make_pair(num_rows, num_cols), padding_same); -} - -template -void CpuWinogradConv2dTransformOutputKernel::configure( - const ITensorInfo *biases, - const ITensorInfo *transformed_output, - const int matrix_stride, - ITensorInfo *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensorInfo *workspace, - const arm_gemm::Activation &activation) -{ - ARM_COMPUTE_UNUSED(biases, transformed_output, output_nhwc, num_batches, num_rows, num_cols, workspace, activation); - - _matrix_stride = matrix_stride; - _matrix_row_stride = roundup(num_channels, WinogradConv::N_BLOCK); - - // We don't have the biases buffer at this stage as it hasn't been allocated, we pass in nullptr OutputTransform is only used here to compute the window - _transform = std::make_unique(num_batches, num_rows, num_cols, num_channels, activation); - Window win; - auto win_last = _transform->get_window(); - win.set(Window::DimX, Window::Dimension(0, win_last, 1)); - - ICpuKernel::configure(win); -} - -template -void CpuWinogradConv2dTransformOutputKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) -{ - ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); - ARM_COMPUTE_ERROR_ON(tensors.empty()); - - const ITensor *biases = tensors.get_const_tensor(TensorType::ACL_SRC_0); - const ITensor *transformed_output = tensors.get_const_tensor(TensorType::ACL_SRC_1); - ITensor *workspace = tensors.get_tensor(TensorType::ACL_INT); - ITensor *dst_nhwc = tensors.get_tensor(TensorType::ACL_DST); - - const int out_batch_stride = dst_nhwc->info()->strides_in_bytes()[3] / sizeof(T); - const int out_row_stride = dst_nhwc->info()->strides_in_bytes()[2] / sizeof(T); - const int out_col_stride = 
dst_nhwc->info()->strides_in_bytes()[1] / sizeof(T); - - _transform->set_input_matrices(transformed_output->buffer(), _matrix_stride, _matrix_row_stride); - _transform->set_bias((biases ? reinterpret_cast(biases->buffer() + biases->info()->offset_first_element_in_bytes()) : nullptr)); - _transform->set_output_tensor(dst_nhwc->buffer() + dst_nhwc->info()->offset_first_element_in_bytes(), out_batch_stride, out_row_stride, out_col_stride); - _transform->set_working_space(workspace->buffer()); - - // The code below cannot be moved to configure because biases hasn't been allocated at that point - const size_t fst = window.x().start(); - const size_t lst = window.x().end(); - _transform->run(fst, lst, info.thread_id); -} - -template -Status CpuWinogradConv2dTransformOutputKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, - const WinogradInfo &winograd_info) -{ - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_winograd_output_trans(input, (bias != nullptr ? bias->clone().get() : nullptr), output, winograd_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_winograd_output_trans(input->clone().get(), output->clone().get(), winograd_info).first); - - return Status{}; -} - -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; - -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; -template class CpuWinogradConv2dTransformOutputKernel; - -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -template class CpuWinogradConv2dTransformOutputKernel<__fp16, 4, 4, 3, 3>; -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC } // namespace cpu -} // namespace arm_compute +} // namespace arm_compute \ No newline at end of file diff --git a/src/cpu/kernels/CpuWinogradConv2dKernel.h b/src/cpu/kernels/CpuWinogradConv2dKernel.h index 6909216d94..0170dcae22 100644 --- a/src/cpu/kernels/CpuWinogradConv2dKernel.h +++ b/src/cpu/kernels/CpuWinogradConv2dKernel.h @@ -24,550 +24,79 @@ #ifndef ARM_COMPUTE_CPUWINOGRADCONV2DKERNEL_H #define ARM_COMPUTE_CPUWINOGRADCONV2DKERNEL_H -#include "src/core/NEON/kernels/convolution/common/convolution.hpp" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/ITensorPack.h" +#include "arm_compute/core/Steps.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/Tensor.h" +#include "src/core/NEON/kernels/assembly/winograd.hpp" #include "src/core/NEON/kernels/convolution/common/tensor.hpp" #include "src/cpu/ICpuKernel.h" -#include "src/core/NEON/kernels/convolution/winograd/winograd_layer.hpp" - namespace arm_compute { namespace cpu { -/** Interface for the kernel to perform Winograd input transform. */ -class ICpuWinogradConv2dTransformInputKernel : public ICpuKernel -{ -public: - /** Get the working space required to perform the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param num_threads The greatest number of threads that will be used to execute the transform. - * @return Size of working space required in bytes. 
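The reworked header above replaces the templated per-tile-size kernels with thin, non-templated wrappers: each kernel keeps references to a runtime-selected arm_conv::winograd::WinogradImpl and a shared arm_conv::ConvolutionArgs, which is also why copying and moving are deleted further down. A hedged sketch of how the owning operator might wire them up; the constructor signatures are taken from this header, while the implementation selection and argument setup are assumed to happen elsewhere in the operator:

#include <cstdint>
#include "src/cpu/kernels/CpuWinogradConv2dKernel.h"

// Sketch only: WinogradImpl/ConvolutionArgs are assumed to be configured by
// the operator (implementation selection, shapes, padding, activation).
void wire_winograd_kernels(arm_conv::winograd::WinogradImpl &impl,
                           arm_conv::ConvolutionArgs        &args,
                           uint32_t                          nthreads)
{
    using namespace arm_compute::cpu;
    CpuWinogradConv2dTransformInputKernel  input_kernel(impl, args, nthreads);
    CpuWinogradConv2dTransformOutputKernel output_kernel(impl, args, nthreads);
    // Both kernels remain valid only while impl/args outlive them, which the
    // owning operator must guarantee.
    (void)input_kernel;
    (void)output_kernel;
}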
- */ - virtual unsigned int get_working_space_size(unsigned int num_threads) const = 0; - - /** Determine how much memory (in units of TIn) to allocate for the - * transformed input. - * - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_channels Number of feature maps in the input tensor. - * @param[in] num_rows Number of rows in each feature map. - * @param[in] num_cols Number of columns in each feature map. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". - * - * @return Storage size (in units of TIn) required. - */ - virtual unsigned int get_input_storage_size(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0; - - /** Gets the stride between matrices in the input worspace - * - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_channels Number of feature maps in the input tensor. - * @param[in] num_rows Number of rows in each feature map. - * @param[in] num_cols Number of columns in each feature map. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". - * - * @return Stride expressed in bytes. - */ - virtual int get_matrix_stride(int num_batches, int num_channels, int num_rows, int num_cols, bool same_padding) const = 0; - - /** Configure the output transform kernel. - * - * @param[in] input_nhwc Input tensor in NHWC data layout format. - * @param[in] num_batches Number of batches in input tensor. - * @param[in] num_rows Number of rows in input tensor. - * @param[in] num_cols Number of columns in input tensor. - * @param[in] num_channels Number of channels in input tensor. - * @param[in] padding Padding type. - * @param[out] output Base of output matrices. - * @param[in] matrix_stride Stride between output matrices. - * @param[in] workspace Tensor to be used as the working space during the computation. - */ - virtual void configure(const ITensorInfo *input_nhwc, const int num_batches, const int num_rows, const int num_cols, const int num_channels, - const PaddingType padding, ITensorInfo *output, const int matrix_stride, ITensorInfo *workspace) = 0; - - /** Destructor */ - virtual ~ICpuWinogradConv2dTransformInputKernel() - { - } -}; - -/** Kernel to perform Winograd input transform. */ -template -class CpuWinogradConv2dTransformInputKernel : public ICpuWinogradConv2dTransformInputKernel +class CpuWinogradConv2dTransformInputKernel final : public ICpuKernel { public: /** Prevent instances of this class from being copied (As this class contains pointers) */ CpuWinogradConv2dTransformInputKernel(const CpuWinogradConv2dTransformInputKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ CpuWinogradConv2dTransformInputKernel &operator=(const CpuWinogradConv2dTransformInputKernel &) = delete; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformInputKernel(CpuWinogradConv2dTransformInputKernel &&) = default; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformInputKernel &operator=(CpuWinogradConv2dTransformInputKernel &&) = default; - /** Default destructor */ - ~CpuWinogradConv2dTransformInputKernel() = default; - /** Determine how much memory (in units of TIn) to allocate for the - * transformed input. - * - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_channels Number of feature maps in the input tensor. - * @param[in] num_rows Number of rows in each feature map. 
- * @param[in] num_cols Number of columns in each feature map. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". - * - * @return Storage size (in units of TIn) required. - */ - unsigned int get_input_storage_size( - int num_batches, - int num_channels, - int num_rows, - int num_cols, - bool same_padding) const override; + /** Prevent instances of this class from being moved it contains references.*/ + CpuWinogradConv2dTransformInputKernel(CpuWinogradConv2dTransformInputKernel &&) = delete; - /** Get the working space required to perform the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param[in] num_threads The greatest number of threads that will be used to execute the transform. - * - * @return Size of working space required in bytes. - */ - unsigned int get_working_space_size(unsigned int num_threads) const override; + /** Prevent instances of this class from being moved it contains references.*/ + CpuWinogradConv2dTransformInputKernel &operator=(CpuWinogradConv2dTransformInputKernel &&) = delete; - /** Gets the stride between matrices in the input worspace - * - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_channels Number of feature maps in the input tensor. - * @param[in] num_rows Number of rows in each feature map. - * @param[in] num_cols Number of columns in each feature map. - * @param[in] same_padding Use "SAME" padding, otherwise use "VALID". - * - * @return Stride expressed in bytes. - */ - int get_matrix_stride( - int num_batches, - int num_channels, - int num_rows, - int num_cols, - bool same_padding) const override; + CpuWinogradConv2dTransformInputKernel(arm_conv::winograd::WinogradImpl &w_impl, arm_conv::ConvolutionArgs &_c_args, uint32_t nthreads); - /** Default constructor */ - CpuWinogradConv2dTransformInputKernel(); + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override { return "CpuWinogradConv2dTransformInputKernel"; } - /** Configure the output transform kernel. - * - * @param[in] input_nhwc Input tensor. Data types supported: F16/F32. Layout supported NHWC. - * @param[in] num_batches Number of batches in input tensor. - * @param[in] num_rows Number of rows in input tensor. - * @param[in] num_cols Number of columns in input tensor. - * @param[in] num_channels Number of channels in input tensor. - * @param[in] padding Padding type. - * @param[out] output Base of output matrices. - * @param[in] matrix_stride Stride between output matrices. - * @param[in] workspace Tensor to be used as the working space during the computation. - */ - void configure( - const ITensorInfo *input_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - const PaddingType padding, - ITensorInfo *output, - const int matrix_stride, - ITensorInfo *workspace) override; - - // Inherited methods overridden: - void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; - - /** Winograd base kernel */ - using WinogradBase = winograd::WinogradGEMM; - /** Winograd convolution kernel */ - using WinogradConv = typename WinogradBase::template Convolution; - - /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformInputKernel - * - * @param[in] input First tensor input info. 
Data types supported: F16/F32. - * @param[in] output Output tensor info. Data types supported: same as @p input. - * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info); - private: - using InputTransform = typename WinogradBase::template InputTransform; - - std::unique_ptr _transform{ nullptr }; - int _num_channels; /**< Number of channels in input tensor. */ - int _matrix_stride; /**< Stride between output matrices. */ + arm_conv::winograd::WinogradImpl &_winograd_impl; + arm_conv::ConvolutionArgs &_conv_args; + uint32_t _nthreads; }; - -/** Interface for the kernel to perform Winograd output transform. */ -class ICpuWinogradConv2dTransformOutputKernel : public ICpuKernel -{ -public: - /** Get the working space required to perform the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param[in] num_threads The greatest number of threads that will be used to execute the transform. - * - * @return Size of working space required in bytes. - */ - virtual unsigned int get_working_space_size(unsigned int num_threads) const = 0; - - /** Determine how much memory (in units of TOut) to allocate for the - * (Winograd domain) output. - * - * @param[in] num_batches Number of batches in the output tensor. - * @param[in] num_rows Number of rows in each feature map of the input tensor. - * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] num_output_channels Number of feature maps in the output tensor. - * - * @return Storage size (in units of TOut) required. - */ - virtual unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0; - - /** Gets the stride between matrices in the output worspace - * - * @param[in] num_batches Number of batches in the output tensor. - * @param[in] num_rows Number of rows in each feature map of the input tensor. - * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] num_output_channels Number of feature maps in the output tensor. - * - * @return Stride expressed in bytes. - */ - virtual int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const = 0; - - /** Get the output shape of a convolution. - * - * @param[in] num_rows Number of rows in each feature map of the input tensor. - * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] padding_same True if padding is SAME, false otherwise - * - * @return Shape of the output tensor - */ - virtual std::pair get_output_shape( - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - bool padding_same /* True if padding is SAME, false otherwise */ - ) const = 0; - - /** Configure the output transform kernel. - * - * @param[in] biases Pointer to the biases tensor. - * @param[in] transformed_output Pointer to working space for the output tensor in the Winograd domain. 
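As a cross-check on the validate routine removed at the start of this patch: the winograd-domain depths it accepts ({8, 16, 36}) and the tile counts it computes both follow from the tile geometry. An m x m output tile produced from a k x k kernel consumes an (m + k - 1) x (m + k - 1) input tile, one winograd-domain matrix per input-tile element:

    depth       = (m + k - 1)^2
    num_tiles_x = ceil((W - (k_w - 1) + pad_left + pad_right) / m_w)

so F(2x2, 3x3) gives 4^2 = 16, F(4x4, 3x3) and F(2x2, 5x5) give 6^2 = 36, and the one-dimensional variants such as F(1x6, 1x3) give m + k - 1 = 8.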
- * @param[in] matrix_stride Output matrix stride, can be computed with winograd::WinogradGEMM<2, 2, 3, 3>::Convolution::get_output_matrix_stride() - * @param[out] output_nhwc Pointer to a tensor in NHWC data layout ordered output tensor, in the spatial domain. - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_rows Number of rows in output tensor. - * @param[in] num_cols Number of columns in output tensor. - * @param[in] num_channels Number of feature maps in the output tensor. - * @param[in] workspace Tensor to be used as the working space during the computation. - * @param[in] activation Activation to be used - */ - virtual void configure( - const ITensorInfo *biases, - const ITensorInfo *transformed_output, - const int matrix_stride, - ITensorInfo *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensorInfo *workspace, - const arm_gemm::Activation &activation) = 0; - - virtual ~ICpuWinogradConv2dTransformOutputKernel() - { - } -}; - -/** Kernel to perform Winograd output transform. */ -template -class CpuWinogradConv2dTransformOutputKernel : public ICpuWinogradConv2dTransformOutputKernel +class CpuWinogradConv2dTransformOutputKernel : public ICpuKernel { public: - const char *name() const override - { - return "CpuWinogradConv2dTransformOutputKernel"; - } - /** Constructor */ - CpuWinogradConv2dTransformOutputKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ CpuWinogradConv2dTransformOutputKernel(const CpuWinogradConv2dTransformOutputKernel &) = delete; + /** Prevent instances of this class from being copied (As this class contains pointers) */ CpuWinogradConv2dTransformOutputKernel &operator=(const CpuWinogradConv2dTransformOutputKernel &) = delete; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformOutputKernel(CpuWinogradConv2dTransformOutputKernel &&) = default; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformOutputKernel &operator=(CpuWinogradConv2dTransformOutputKernel &&) = default; - /** Default destructor */ - ~CpuWinogradConv2dTransformOutputKernel() = default; - - // Inherited methods overridden: - /** Determine how much memory (in units of TOut) to allocate for the - * (Winograd domain) output. - * - * @param[in] num_batches Number of batches in the output tensor. - * @param[in] num_rows Number of rows in each feature map of the input tensor. - * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] num_output_channels Number of feature maps in the output tensor. - * - * @return Storage size (in units of TOut) required. - */ - unsigned int get_output_storage_size(int num_batches, int num_rows, int num_cols, int num_output_channels) const override; - /** Gets the stride between matrices in the output worspace - * - * @param[in] num_batches Number of batches in the output tensor. - * @param[in] num_rows Number of rows in each feature map of the input tensor. - * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] num_output_channels Number of feature maps in the output tensor. - * - * @return Stride expressed in bytes. - */ - int get_matrix_stride(int num_batches, int num_rows, int num_cols, int num_output_channels) const override; - /** Get the output shape of a convolution. - * - * @param[in] num_rows Number of rows in each feature map of the input tensor. 
- * @param[in] num_cols Number of columns in each feature map of the input tensor. - * @param[in] padding_same True if padding is SAME, false otherwise - * - * @return Shape of the output tensor - */ - std::pair get_output_shape( - int num_rows, /* Number of rows in each feature map of the input tensor. */ - int num_cols, /* Number of columns in each feature map of the input tensor. */ - bool padding_same) const override; + /** Prevent instances of this class from being moved it contains references.*/ + CpuWinogradConv2dTransformOutputKernel(CpuWinogradConv2dTransformOutputKernel &&) = delete; - /** Get the working space required to perform the transformation. - * - * Note, the working space is only required when performing the - * transformation - hence it can be reused whenever the transformation is - * not running. - * - * @param[in] num_threads The greatest number of threads that will be used to execute the transform. - * - * @return Size of working space required in bytes. - */ - unsigned int get_working_space_size(unsigned int num_threads) const override; + /** Prevent instances of this class from being moved it contains references.*/ + CpuWinogradConv2dTransformOutputKernel &operator=(CpuWinogradConv2dTransformOutputKernel &&) = delete; - /** Configure the output transform kernel. - * - * @param[in] biases Pointer to the biases tensor. - * @param[in] transformed_output Pointer to working space for the output tensor in the Winograd domain. - * @param[in] matrix_stride Output matrix stride, can be computed with winograd::WinogradGEMM<2, 2, 3, 3>::Convolution::get_output_matrix_stride() - * @param[out] output_nhwc Pointer to a tensor with NHWC data layout, in the spatial domain. - * @param[in] num_batches Number of batches in the input tensor. - * @param[in] num_rows Number of rows in output tensor. - * @param[in] num_cols Number of columns in output tensor. - * @param[in] num_channels Number of feature maps in the output tensor. - * @param[in] workspace Tensor to be used as the working space during the computation. - * @param[in] activation Activation to be used - */ - void configure( - const ITensorInfo *biases, - const ITensorInfo *transformed_output, - const int matrix_stride, - ITensorInfo *output_nhwc, - const int num_batches, - const int num_rows, - const int num_cols, - const int num_channels, - ITensorInfo *workspace, - const arm_gemm::Activation &activation) override; + CpuWinogradConv2dTransformOutputKernel(arm_conv::winograd::WinogradImpl &w_impl, arm_conv::ConvolutionArgs &_c_args, uint32_t nthreads); + // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; - /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformOutputKernel - * - * @param[in] input Source tensor info with shape [C, N, 16, batches] or [C, N, 36, batches]. Data types supported: F16/F32. - * @param[in] bias Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. It can be a nullptr. Data type supported: as @p input - * @param[in] output Destination tensor info with shape [output_convolved_dims.width, output_convolved_dims.height, C, batches]. 
Data type supported: same as @p input - * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, const WinogradInfo &winograd_info); - -private: - using WinogradBase = winograd::WinogradGEMM; - using WinogradConv = typename WinogradBase::template Convolution; - using OutputTransform = typename WinogradBase::template OutputTransform; - - std::unique_ptr _transform{ nullptr }; - int _matrix_stride; - int _matrix_row_stride; -}; - -/** Interface for the kernel to perform Winograd weights transform. */ -class ICpuWinogradConv2dTransformWeightsKernel : public ICpuKernel -{ -public: - /** Prevent instances of this class from being copied (As this class contains pointers) */ - ICpuWinogradConv2dTransformWeightsKernel(const ICpuWinogradConv2dTransformWeightsKernel &) = default; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - ICpuWinogradConv2dTransformWeightsKernel &operator=(const ICpuWinogradConv2dTransformWeightsKernel &) = default; - /** Allow instances of this class to be moved */ - ICpuWinogradConv2dTransformWeightsKernel(ICpuWinogradConv2dTransformWeightsKernel &&) = default; - /** Allow instances of this class to be moved */ - ICpuWinogradConv2dTransformWeightsKernel &operator=(ICpuWinogradConv2dTransformWeightsKernel &&) = default; - - ICpuWinogradConv2dTransformWeightsKernel() - { - } - virtual ~ICpuWinogradConv2dTransformWeightsKernel() - { - } - /** Determine how much memory (in units of T) to allocate for the - * transformed weights. - * - * @param[in] num_output_channels Number of output feature maps. - * @param[in] num_input_channels Number of input feature maps. - * - * @return Storage size (in units of T) required. - */ - virtual unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const = 0; - /** Gets the stride between matrices in the kernel worspace - * - * @param[in] num_output_channels Number of output feature maps. - * @param[in] num_input_channels Number of input feature maps. - * - * @return Stride expressed in bytes. - */ - virtual int get_matrix_stride(int num_output_channels, int num_input_channels) const = 0; - - /** Configure the weights transform kernel. - * - * @param[in] weights_hwio Pointer to the weights tensor info - * @param[out] output Pointer to working space for the output tensor in the Winograd domain. - * @param[in] matrix_stride Stride across matrices in the output workspace. - * @param[in] num_output_channels Number of filters. - * @param[in] num_input_channels Number of channels in each filter. - */ - - virtual void configure(const ITensorInfo *weights_hwio, ITensorInfo *output, const int matrix_stride, const int num_output_channels, const int num_input_channels) = 0; - - /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformWeightsKernel - * - * @param[in] input First tensor input info. Data types supported: F16/F32. - * @param[in] weights Weights tensor info. Data types supported: same as @p input. - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights); -}; - -/** Kernel to perform Winograd weights transform. 
*/ -template -class CpuWinogradConv2dTransformWeightsKernel final : public ICpuWinogradConv2dTransformWeightsKernel -{ -public: - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CpuWinogradConv2dTransformWeightsKernel(const CpuWinogradConv2dTransformWeightsKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - CpuWinogradConv2dTransformWeightsKernel &operator=(const CpuWinogradConv2dTransformWeightsKernel &) = delete; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformWeightsKernel(CpuWinogradConv2dTransformWeightsKernel &&) = default; - /** Allow instances of this class to be moved */ - CpuWinogradConv2dTransformWeightsKernel &operator=(CpuWinogradConv2dTransformWeightsKernel &&) = default; - /** Default destructor */ - ~CpuWinogradConv2dTransformWeightsKernel() = default; - - /** Default constructor. */ - CpuWinogradConv2dTransformWeightsKernel(); const char *name() const override { - return "CpuWinogradConv2dTransformWeightsKernel"; + return "CpuWinogradConv2dTransformOutputKernel"; } - /** Static function to check if given info will lead to a valid configuration of @ref CpuWinogradConv2dTransformWeightsKernel - * - * @param[in] input Source tensor info. The input is a 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM] (NCHW data layout). - * kernel_x must be 3 and equal to kernel_y. Data types supported: F16/F32. - * @param[in] output Destination tensor info. The output is a 3D tensor with dimensions [OFM, IFM, 16] or [OFM, IFM, 36]. Data type supported: same as @p input - * @param[in] winograd_info Contains Winograd's information described in @ref WinogradInfo - * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *output, const WinogradInfo &winograd_info); - - // Inherited methods overridden: - -#ifndef DOXYGEN_SKIP_THIS - /** Configure the weights transform kernel. - * - * @param[in] weights_hwio Pointer to the weights tensor info - * @param[out] output Pointer to working space for the output tensor in the Winograd domain. - * @param[in] matrix_stride Stride across matrices in the output workspace. - * @param[in] num_output_channels Number of filters. - * @param[in] num_input_channels Number of channels in each filter. - */ - void configure(const ITensorInfo *weights_hwio, ITensorInfo *output, const int matrix_stride, const int num_output_channels, const int num_input_channels) override; -#endif /* DOXYGEN_SKIP_THIS */ - - /** Determine how much memory (in units of T) to allocate for the - * transformed weights. - * - * @param[in] num_output_channels Number of output feature maps. - * @param[in] num_input_channels Number of input feature maps. - * - * @return Storage size (in units of T) required. - */ - unsigned int get_weight_storage_size(int num_output_channels, int num_input_channels) const override; - - /** Gets the stride between matrices in the input worspace - * - * @param[in] num_output_channels Number of output feature maps. - * @param[in] num_input_channels Number of input feature maps. - * - * @return Stride expressed in bytes. 
- */ - int get_matrix_stride(int num_output_channels, int num_input_channels) const override; - void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; - bool is_parallelisable() const override; - private: - using WinogradBase = winograd::WinogradGEMM; - using WinogradConv = typename WinogradBase::template Convolution; - using WeightsTransform = typename WinogradBase::template WeightsTransform; - - std::unique_ptr _transform{ nullptr }; - int _num_output_channels; - int _matrix_stride; -}; - -/** Kernel to perform Winograd. */ -template -class CpuWinogradConv2dConfiguration -{ -public: - /** Winograd base kernel */ - using WinogradBase = winograd::WinogradGEMM; - /** Winograd convolution kernel */ - - using WinogradConv = typename WinogradBase::template Convolution; - - using TransformInputKernel = CpuWinogradConv2dTransformInputKernel; - using TransformWeightsKernel = CpuWinogradConv2dTransformWeightsKernel; - using TransformOutputKernel = CpuWinogradConv2dTransformOutputKernel; + arm_conv::winograd::WinogradImpl &_winograd_impl; + const arm_conv::ConvolutionArgs &_conv_args; + uint32_t _nthreads; }; } // namespace cpu diff --git a/src/cpu/kernels/activation/generic/neon/qasymm8.cpp b/src/cpu/kernels/activation/generic/neon/qasymm8.cpp index 62e329e691..5095ecf5bd 100644 --- a/src/cpu/kernels/activation/generic/neon/qasymm8.cpp +++ b/src/cpu/kernels/activation/generic/neon/qasymm8.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -31,11 +31,417 @@ #include #include #include +#include namespace arm_compute { namespace cpu { +namespace +{ +#ifdef __aarch64__ + +void substitute_bytes_neon( + const uint8_t *table, + size_t num_strings, + size_t string_length, + const uint8_t *const *input, + uint8_t *const *output) +{ + __asm__ __volatile__( + "ldr q16, [%x[table], #0x0]\n" + "ldr q17, [%x[table], #0x10]\n" + "mov x22, #0x0\n" + "ldr q18, [%x[table], #0x20]\n" + "ldr q19, [%x[table], #0x30]\n" + "ldr q20, [%x[table], #0x40]\n" + "ldr q21, [%x[table], #0x50]\n" + "ldr q22, [%x[table], #0x60]\n" + "ldr q23, [%x[table], #0x70]\n" + "ldr q24, [%x[table], #0x80]\n" + "ldr q25, [%x[table], #0x90]\n" + "ldr q26, [%x[table], #0xa0]\n" + "ldr q27, [%x[table], #0xb0]\n" + "ldr q28, [%x[table], #0xc0]\n" + "ldr q29, [%x[table], #0xd0]\n" + "ldr q30, [%x[table], #0xe0]\n" + "ldr q31, [%x[table], #0xf0]\n" + "1:" // string loop + "ldr x21, [%x[input], x22, LSL #0x3]\n" + "ldr x20, [%x[output], x22, LSL #0x3]\n" + "movi v12.16b, #0x40\n" + "movi v11.16b, #0x80\n" + "movi v10.16b, #0xc0\n" + "mov x19, %x[string_length]\n" + "2:" // 4 rounds: width loop + "cmp x19, #0x30\n" + "bge 27f\n" + "tbz x19, #5, 10f\n" + "ld1 { v9.16b }, [x21], #0x10\n" + "ld1 { v13.16b }, [x21], #0x10\n" + "tbz x19, #3, 6f\n" + "ldr d14, [x21], #0x8\n" + "tbz x19, #2, 4f\n" + "ld1 { v14.s }[2], [x21], #0x4\n" + "tbz x19, #1, 3f\n" + "ld1 { v14.h }[6], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[14], [x21]\n" + "b 26f\n" + "3:" // 4 rounds: Partial load: partial_1_44 + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[12], [x21]\n" + "b 26f\n" + "4:" // 4 rounds: Partial load: partial_2_40 + "tbz x19, #1, 5f\n" + "ld1 { v14.h }[4], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[10], [x21]\n" + "b 26f\n" + "5:" // 4 rounds: Partial load: partial_1_40 + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[8], [x21]\n" + "b 26f\n" + "6:" // 4 rounds: Partial load: partial_4_32 + "tbz x19, #2, 8f\n" + "ldr s14, [x21], 
#0x4\n" + "tbz x19, #1, 7f\n" + "ld1 { v14.h }[2], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[6], [x21]\n" + "b 26f\n" + "7:" // 4 rounds: Partial load: partial_1_36 + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[4], [x21]\n" + "b 26f\n" + "8:" // 4 rounds: Partial load: partial_2_32 + "tbz x19, #1, 9f\n" + "ldr h14, [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v14.b }[2], [x21]\n" + "b 26f\n" + "9:" // 4 rounds: Partial load: partial_1_32 + "tbz x19, #0, 26f\n" + "ldr b14, [x21, #0x0]\n" + "b 26f\n" + "10:" // 4 rounds: Partial load: partial_16_0 + "tbz x19, #4, 18f\n" + "ld1 { v9.16b }, [x21], #0x10\n" + "tbz x19, #3, 14f\n" + "ldr d13, [x21], #0x8\n" + "tbz x19, #2, 12f\n" + "ld1 { v13.s }[2], [x21], #0x4\n" + "tbz x19, #1, 11f\n" + "ld1 { v13.h }[6], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[14], [x21]\n" + "b 26f\n" + "11:" // 4 rounds: Partial load: partial_1_28 + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[12], [x21]\n" + "b 26f\n" + "12:" // 4 rounds: Partial load: partial_2_24 + "tbz x19, #1, 13f\n" + "ld1 { v13.h }[4], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[10], [x21]\n" + "b 26f\n" + "13:" // 4 rounds: Partial load: partial_1_24 + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[8], [x21]\n" + "b 26f\n" + "14:" // 4 rounds: Partial load: partial_4_16 + "tbz x19, #2, 16f\n" + "ldr s13, [x21], #0x4\n" + "tbz x19, #1, 15f\n" + "ld1 { v13.h }[2], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[6], [x21]\n" + "b 26f\n" + "15:" // 4 rounds: Partial load: partial_1_20 + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[4], [x21]\n" + "b 26f\n" + "16:" // 4 rounds: Partial load: partial_2_16 + "tbz x19, #1, 17f\n" + "ldr h13, [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v13.b }[2], [x21]\n" + "b 26f\n" + "17:" // 4 rounds: Partial load: partial_1_16 + "tbz x19, #0, 26f\n" + "ldr b13, [x21, #0x0]\n" + "b 26f\n" + "18:" // 4 rounds: Partial load: partial_8_0 + "tbz x19, #3, 22f\n" + "ldr d9, [x21], #0x8\n" + "tbz x19, #2, 20f\n" + "ld1 { v9.s }[2], [x21], #0x4\n" + "tbz x19, #1, 19f\n" + "ld1 { v9.h }[6], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[14], [x21]\n" + "b 26f\n" + "19:" // 4 rounds: Partial load: partial_1_12 + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[12], [x21]\n" + "b 26f\n" + "20:" // 4 rounds: Partial load: partial_2_8 + "tbz x19, #1, 21f\n" + "ld1 { v9.h }[4], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[10], [x21]\n" + "b 26f\n" + "21:" // 4 rounds: Partial load: partial_1_8 + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[8], [x21]\n" + "b 26f\n" + "22:" // 4 rounds: Partial load: partial_4_0 + "tbz x19, #2, 24f\n" + "ldr s9, [x21], #0x4\n" + "tbz x19, #1, 23f\n" + "ld1 { v9.h }[2], [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[6], [x21]\n" + "b 26f\n" + "23:" // 4 rounds: Partial load: partial_1_4 + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[4], [x21]\n" + "b 26f\n" + "24:" // 4 rounds: Partial load: partial_2_0 + "tbz x19, #1, 25f\n" + "ldr h9, [x21], #0x2\n" + "tbz x19, #0, 26f\n" + "ld1 { v9.b }[2], [x21]\n" + "b 26f\n" + "25:" // 4 rounds: Partial load: partial_1_0 + "ldr b9, [x21, #0x0]\n" + "26:" // 4 rounds: Partial load: Done + "b 28f\n" + "27:" // 4 rounds: Full load + "ldr q9, [x21, #0x0]\n" + "ldr q13, [x21, #0x10]\n" + "ldr q14, [x21, #0x20]\n" + "add x21, x21, #0x30\n" + "28:" // 4 rounds: Load done + "sub v8.16b, v9.16b, v12.16b\n" + "sub v7.16b, v9.16b, v11.16b\n" + "tbl v8.16b, { v20.16b, v21.16b, v22.16b, v23.16b }, v8.16b\n" + "sub v6.16b, v9.16b, v10.16b\n" + "sub v5.16b, v13.16b, v12.16b\n" + "tbl v9.16b, { v16.16b, v17.16b, 
v18.16b, v19.16b }, v9.16b\n" + "sub v4.16b, v13.16b, v11.16b\n" + "sub v3.16b, v13.16b, v10.16b\n" + "tbl v7.16b, { v24.16b, v25.16b, v26.16b, v27.16b }, v7.16b\n" + "sub v2.16b, v14.16b, v12.16b\n" + "sub v1.16b, v14.16b, v11.16b\n" + "tbl v6.16b, { v28.16b, v29.16b, v30.16b, v31.16b }, v6.16b\n" + "sub v0.16b, v14.16b, v10.16b\n" + "tbl v13.16b, { v16.16b, v17.16b, v18.16b, v19.16b }, v13.16b\n" + "tbl v5.16b, { v20.16b, v21.16b, v22.16b, v23.16b }, v5.16b\n" + "tbl v4.16b, { v24.16b, v25.16b, v26.16b, v27.16b }, v4.16b\n" + "tbl v3.16b, { v28.16b, v29.16b, v30.16b, v31.16b }, v3.16b\n" + "orr v9.16b, v9.16b, v8.16b\n" + "tbl v14.16b, { v16.16b, v17.16b, v18.16b, v19.16b }, v14.16b\n" + "tbl v2.16b, { v20.16b, v21.16b, v22.16b, v23.16b }, v2.16b\n" + "orr v7.16b, v7.16b, v6.16b\n" + "tbl v1.16b, { v24.16b, v25.16b, v26.16b, v27.16b }, v1.16b\n" + "tbl v0.16b, { v28.16b, v29.16b, v30.16b, v31.16b }, v0.16b\n" + "orr v13.16b, v13.16b, v5.16b\n" + "orr v4.16b, v4.16b, v3.16b\n" + "orr v14.16b, v14.16b, v2.16b\n" + "cmp x19, #0x30\n" + "orr v1.16b, v1.16b, v0.16b\n" + "orr v9.16b, v9.16b, v7.16b\n" + "orr v13.16b, v13.16b, v4.16b\n" + "orr v14.16b, v14.16b, v1.16b\n" + "bge 53f\n" + "tbz x19, #5, 36f\n" + "st1 { v9.16b }, [x20], #0x10\n" + "st1 { v13.16b }, [x20], #0x10\n" + "tbz x19, #3, 32f\n" + "str d14, [x20], #0x8\n" + "tbz x19, #2, 30f\n" + "st1 { v14.s }[2], [x20], #0x4\n" + "tbz x19, #1, 29f\n" + "st1 { v14.h }[6], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v14.b }[14], [x20]\n" + "b 52f\n" + "29:" // 4 rounds: Partial writeback: partial_1_44 + "tbz x19, #0, 52f\n" + "st1 { v14.b }[12], [x20]\n" + "b 52f\n" + "30:" // 4 rounds: Partial writeback: partial_2_40 + "tbz x19, #1, 31f\n" + "st1 { v14.h }[4], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v14.b }[10], [x20]\n" + "b 52f\n" + "31:" // 4 rounds: Partial writeback: partial_1_40 + "tbz x19, #0, 52f\n" + "st1 { v14.b }[8], [x20]\n" + "b 52f\n" + "32:" // 4 rounds: Partial writeback: partial_4_32 + "tbz x19, #2, 34f\n" + "str s14, [x20], #0x4\n" + "tbz x19, #1, 33f\n" + "st1 { v14.h }[2], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v14.b }[6], [x20]\n" + "b 52f\n" + "33:" // 4 rounds: Partial writeback: partial_1_36 + "tbz x19, #0, 52f\n" + "st1 { v14.b }[4], [x20]\n" + "b 52f\n" + "34:" // 4 rounds: Partial writeback: partial_2_32 + "tbz x19, #1, 35f\n" + "str h14, [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v14.b }[2], [x20]\n" + "b 52f\n" + "35:" // 4 rounds: Partial writeback: partial_1_32 + "tbz x19, #0, 52f\n" + "str b14, [x20, #0x0]\n" + "b 52f\n" + "36:" // 4 rounds: Partial writeback: partial_16_0 + "tbz x19, #4, 44f\n" + "st1 { v9.16b }, [x20], #0x10\n" + "tbz x19, #3, 40f\n" + "str d13, [x20], #0x8\n" + "tbz x19, #2, 38f\n" + "st1 { v13.s }[2], [x20], #0x4\n" + "tbz x19, #1, 37f\n" + "st1 { v13.h }[6], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v13.b }[14], [x20]\n" + "b 52f\n" + "37:" // 4 rounds: Partial writeback: partial_1_28 + "tbz x19, #0, 52f\n" + "st1 { v13.b }[12], [x20]\n" + "b 52f\n" + "38:" // 4 rounds: Partial writeback: partial_2_24 + "tbz x19, #1, 39f\n" + "st1 { v13.h }[4], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v13.b }[10], [x20]\n" + "b 52f\n" + "39:" // 4 rounds: Partial writeback: partial_1_24 + "tbz x19, #0, 52f\n" + "st1 { v13.b }[8], [x20]\n" + "b 52f\n" + "40:" // 4 rounds: Partial writeback: partial_4_16 + "tbz x19, #2, 42f\n" + "str s13, [x20], #0x4\n" + "tbz x19, #1, 41f\n" + "st1 { v13.h }[2], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v13.b }[6], [x20]\n" + "b 
52f\n" + "41:" // 4 rounds: Partial writeback: partial_1_20 + "tbz x19, #0, 52f\n" + "st1 { v13.b }[4], [x20]\n" + "b 52f\n" + "42:" // 4 rounds: Partial writeback: partial_2_16 + "tbz x19, #1, 43f\n" + "str h13, [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v13.b }[2], [x20]\n" + "b 52f\n" + "43:" // 4 rounds: Partial writeback: partial_1_16 + "tbz x19, #0, 52f\n" + "str b13, [x20, #0x0]\n" + "b 52f\n" + "44:" // 4 rounds: Partial writeback: partial_8_0 + "tbz x19, #3, 48f\n" + "str d9, [x20], #0x8\n" + "tbz x19, #2, 46f\n" + "st1 { v9.s }[2], [x20], #0x4\n" + "tbz x19, #1, 45f\n" + "st1 { v9.h }[6], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v9.b }[14], [x20]\n" + "b 52f\n" + "45:" // 4 rounds: Partial writeback: partial_1_12 + "tbz x19, #0, 52f\n" + "st1 { v9.b }[12], [x20]\n" + "b 52f\n" + "46:" // 4 rounds: Partial writeback: partial_2_8 + "tbz x19, #1, 47f\n" + "st1 { v9.h }[4], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v9.b }[10], [x20]\n" + "b 52f\n" + "47:" // 4 rounds: Partial writeback: partial_1_8 + "tbz x19, #0, 52f\n" + "st1 { v9.b }[8], [x20]\n" + "b 52f\n" + "48:" // 4 rounds: Partial writeback: partial_4_0 + "tbz x19, #2, 50f\n" + "str s9, [x20], #0x4\n" + "tbz x19, #1, 49f\n" + "st1 { v9.h }[2], [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v9.b }[6], [x20]\n" + "b 52f\n" + "49:" // 4 rounds: Partial writeback: partial_1_4 + "tbz x19, #0, 52f\n" + "st1 { v9.b }[4], [x20]\n" + "b 52f\n" + "50:" // 4 rounds: Partial writeback: partial_2_0 + "tbz x19, #1, 51f\n" + "str h9, [x20], #0x2\n" + "tbz x19, #0, 52f\n" + "st1 { v9.b }[2], [x20]\n" + "b 52f\n" + "51:" // 4 rounds: Partial writeback: partial_1_0 + "str b9, [x20, #0x0]\n" + "52:" // 4 rounds: Partial writeback: Done + "b 54f\n" + "53:" // 4 rounds: Full writeback + "str q9, [x20, #0x0]\n" + "str q13, [x20, #0x10]\n" + "str q14, [x20, #0x20]\n" + "add x20, x20, #0x30\n" + "54:" // 4 rounds: Writeback done + "subs x19, x19, #0x30\n" + "bgt 2b\n" + "add x22, x22, #0x1\n" + "cmp x22, %x[num_strings]\n" + "bne 1b\n" + : + : [input] "r"(input), [num_strings] "r"(num_strings), [output] "r"(output), [string_length] "r"(string_length), [table] "r"(table) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x19", "x20", "x21", "x22"); +} + +#endif // __aarch64__ +} // namespace + +void neon_qasymm8_activation_lut(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window) +{ + ARM_COMPUTE_ERROR_ON(!ActivationLayerInfo::is_lut_supported(act_info.activation(), src->info()->data_type())); +#ifdef __aarch64__ + const int window_step_x = src->info()->tensor_shape().x(); + Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); + win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); + execute_window_loop(win_collapsed, [&](const Coordinates &) + { + const auto input_ptr = reinterpret_cast(input.ptr()); + auto output_ptr = reinterpret_cast(output.ptr()); + substitute_bytes_neon(act_info.lut().data(), 1u, window_step_x, &input_ptr, &output_ptr); + }, + input, output); +#else // #ifdef __aarch64__ + ARM_COMPUTE_UNUSED(src); + ARM_COMPUTE_UNUSED(dst); + ARM_COMPUTE_UNUSED(act_info); + ARM_COMPUTE_UNUSED(window); + ARM_COMPUTE_ERROR("LUT Only supported in aarch64."); +#endif // __aarch64__ +} + void 
neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window) { constexpr int window_step_x = 16; @@ -61,14 +467,17 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL #ifndef __aarch64__ const auto vconst_0_f32 = vdupq_n_f32(0); #endif // __aarch64__ - const float32x4_t va_f32 = vdupq_n_f32(act_info.a()); - const float32x4_t vb_f32 = vdupq_n_f32(act_info.b()); - const float a_f32 = act_info.a(); - const float b_f32 = act_info.b(); - const auto const_6_f32 = vdupq_n_f32(6.f); - const auto const_0_f32 = vdupq_n_f32(0.f); - const auto const_3_f32 = vdupq_n_f32(3.f); - const auto const_inv_6_f32 = vdupq_n_f32(0.166666667f); + const float32x4_t va_f32 = vdupq_n_f32(act_info.a()); + const float32x4_t vb_f32 = vdupq_n_f32(act_info.b()); + const float a_f32 = act_info.a(); + const float b_f32 = act_info.b(); + +#ifndef __aarch64__ + const auto const_6_f32 = vdupq_n_f32(6.f); + const auto const_0_f32 = vdupq_n_f32(0.f); + const auto const_3_f32 = vdupq_n_f32(3.f); + const auto const_inv_6_f32 = vdupq_n_f32(0.166666667f); +#endif // __aarch64__ // Initialise scale/offset for re-quantization float s = qi_in.scale / qi_out.scale; @@ -143,6 +552,7 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL // Re-quantize to new output space tmp = vquantize(tmp_dep, qi_out); } +#ifndef __aarch64__ // LUT-based implementation is used for aarch64 instead. else if(act == ActivationLayerInfo::ActivationFunction::HARD_SWISH) { // De-quantize @@ -164,17 +574,6 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL { const auto vin_deq = vdequantize(vin, qi_in); -#ifdef __aarch64__ - const uint32x4x4_t pos_mask = - { - { - wrapper::vcgtz(vin_deq.val[0]), - wrapper::vcgtz(vin_deq.val[1]), - wrapper::vcgtz(vin_deq.val[2]), - wrapper::vcgtz(vin_deq.val[3]), - } - }; -#else // __aarch64__ const uint32x4x4_t pos_mask = { { @@ -184,7 +583,6 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL wrapper::vcgt(vin_deq.val[3], vconst_0_f32), } }; -#endif // __aarch64__ const float32x4x4_t tmp_dep = { @@ -198,6 +596,7 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL tmp = vquantize(tmp_dep, qi_out); } +#endif // __aarch64__ else { ARM_COMPUTE_ERROR("Unsupported activation function"); @@ -237,6 +636,7 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL tmp_f = a_f32 * std::tanh(b_f32 * tmp_f); tmp = quantize_qasymm8(tmp_f, qi_out); } +#ifndef __aarch64__ // LUT-based implementation is used for aarch64 instead. else if(act == ActivationLayerInfo::ActivationFunction::HARD_SWISH) { float tmp_f = dequantize_qasymm8(in, qi_in); @@ -249,6 +649,7 @@ void neon_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL tmp_f = tmp_f > 0 ? 
tmp_f : tmp_f * a_f32; tmp = quantize_qasymm8(tmp_f, qi_out); } +#endif // __aarch64__ else { ARM_COMPUTE_ERROR("Unsupported activation function"); diff --git a/src/cpu/kernels/activation/generic/sve2/qasymm8.cpp b/src/cpu/kernels/activation/generic/sve2/qasymm8.cpp index 2fa8dee5f1..928a414fb0 100644 --- a/src/cpu/kernels/activation/generic/sve2/qasymm8.cpp +++ b/src/cpu/kernels/activation/generic/sve2/qasymm8.cpp @@ -57,10 +57,7 @@ void sve2_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL const auto vconst_1 = svdup_n_f32(1.f); const auto va_f32 = svdup_n_f32(act_info.a()); const auto vb_f32 = svdup_n_f32(act_info.b()); - const auto const_6_f32 = svdup_n_f32(6.f); - const auto const_0_f32 = svdup_n_f32(0.f); - const auto const_3_f32 = svdup_n_f32(3.f); - const auto const_inv_6_f32 = svdup_n_f32(0.166666667f); + // Initialise scale/offset for re-quantization bool requant = true; @@ -146,19 +143,6 @@ void sve2_qasymm8_activation(const ITensor *src, ITensor *dst, const ActivationL // Re-quantize to new output space tmp = svquantize_z(pg, tmp_dep, qi_out); } - else if(act == ActivationLayerInfo::ActivationFunction::HARD_SWISH) - { - // De-quantize - const auto vin_deq = svdequantize_z(pg, vin, qi_in); - // Perform activation - const svfloat32x4_t tmp_dep = svcreate4_f32(svmul_f32_z(pg, svget4_f32(vin_deq, 0), svmul_f32_z(pg, const_inv_6_f32, svmin_f32_z(pg, const_6_f32, svmax_f32_z(pg, const_0_f32, svadd_f32_z(pg, - svget4_f32(vin_deq, 0), const_3_f32))))), - svmul_f32_z(pg, svget4_f32(vin_deq, 1), svmul_f32_z(pg, const_inv_6_f32, svmin_f32_z(pg, const_6_f32, svmax_f32_z(pg, const_0_f32, svadd_f32_z(pg, svget4_f32(vin_deq, 1), const_3_f32))))), - svmul_f32_z(pg, svget4_f32(vin_deq, 2), svmul_f32_z(pg, const_inv_6_f32, svmin_f32_z(pg, const_6_f32, svmax_f32_z(pg, const_0_f32, svadd_f32_z(pg, svget4_f32(vin_deq, 2), const_3_f32))))), - svmul_f32_z(pg, svget4_f32(vin_deq, 3), svmul_f32_z(pg, const_inv_6_f32, svmin_f32_z(pg, const_6_f32, svmax_f32_z(pg, const_0_f32, svadd_f32_z(pg, svget4_f32(vin_deq, 3), const_3_f32)))))); - // Re-quantize to new output space - tmp = svquantize_z(pg, tmp_dep, qi_out); - } else if(act == ActivationLayerInfo::ActivationFunction::LEAKY_RELU) { svbool_t p0, p1, p2, p3; diff --git a/src/cpu/kernels/activation/list.h b/src/cpu/kernels/activation/list.h index bf9aa0f373..b2322a6477 100644 --- a/src/cpu/kernels/activation/list.h +++ b/src/cpu/kernels/activation/list.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,6 +32,7 @@ namespace cpu void func_name(const ITensor *src, ITensor *dst, const ActivationLayerInfo &act_info, const Window &window) DECLARE_ACTIVATION_KERNEL(neon_qasymm8_activation); +DECLARE_ACTIVATION_KERNEL(neon_qasymm8_activation_lut); DECLARE_ACTIVATION_KERNEL(sve2_qasymm8_activation); DECLARE_ACTIVATION_KERNEL(neon_qasymm8_signed_activation); DECLARE_ACTIVATION_KERNEL(sve2_qasymm8_signed_activation); diff --git a/src/cpu/kernels/add/generic/neon/fp16.cpp b/src/cpu/kernels/add/generic/neon/fp16.cpp index 12d4a467b7..bb6636af1e 100644 --- a/src/cpu/kernels/add/generic/neon/fp16.cpp +++ b/src/cpu/kernels/add/generic/neon/fp16.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. 
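The assembly above implements a full 256-entry byte substitution: a single AArch64 TBL instruction can index at most four vector registers (64 table bytes), so the code rebases the indices by 0x40, 0x80 and 0xC0, performs four 64-byte lookups (out-of-range indices yield zero) and ORs the partial results together. The idea behind the whole LUT path is that an 8-bit input has only 256 possible values, so an activation such as HARD_SWISH or LEAKY_RELU can be precomputed once into a table and applied as a lookup instead of per-element dequantize, float math and requantize, which is why the float branches are compiled out for aarch64 in the surrounding hunks. A hedged scalar sketch of building and applying such a table; the quantization helpers are simplified stand-ins, not the library's:

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Build a 256-entry table for quantized hard-swish: dequantize each possible
// input byte, apply the activation in float, requantize into the output space.
std::array<uint8_t, 256> make_hard_swish_lut(float scale_in, int32_t offset_in,
                                             float scale_out, int32_t offset_out)
{
    std::array<uint8_t, 256> lut{};
    for(int i = 0; i < 256; ++i)
    {
        const float x = scale_in * static_cast<float>(i - offset_in);    // dequantize
        const float y = x * std::min(6.f, std::max(0.f, x + 3.f)) / 6.f; // hard-swish
        const long  q = std::lround(y / scale_out) + offset_out;         // requantize
        lut[static_cast<std::size_t>(i)] = static_cast<uint8_t>(std::clamp<long>(q, 0, 255));
    }
    return lut;
}

// Scalar equivalent of substitute_bytes_neon for a single string of bytes.
void apply_lut(const std::array<uint8_t, 256> &lut, const uint8_t *src, uint8_t *dst, std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        dst[i] = lut[src[i]];
    }
}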
 *
 * SPDX-License-Identifier: MIT
 *
@@ -33,6 +33,11 @@ void add_fp16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const
 {
     return add_same_neon<float16_t>(src0, src1, dst, policy, window);
 }
+
+void add_fp16_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
+{
+    return add_same_neon_as_1d_array<float16_t>(src0, src1, dst, policy, window);
+}
 }
 } // namespace arm_compute
 #endif /* (__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
diff --git a/src/cpu/kernels/add/generic/neon/fp32.cpp b/src/cpu/kernels/add/generic/neon/fp32.cpp
index 3563162fce..1d313a191d 100644
--- a/src/cpu/kernels/add/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/add/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -32,5 +32,10 @@ void add_fp32_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const
 {
     return add_same_neon<float>(src0, src1, dst, policy, window);
 }
+
+void add_fp32_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
+{
+    return add_same_neon_as_1d_array<float>(src0, src1, dst, policy, window);
+}
 }
 } // namespace arm_compute
diff --git a/src/cpu/kernels/add/generic/neon/impl.cpp b/src/cpu/kernels/add/generic/neon/impl.cpp
index ad3e445ab0..67985c985e 100644
--- a/src/cpu/kernels/add/generic/neon/impl.cpp
+++ b/src/cpu/kernels/add/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021 Arm Limited.
+ * Copyright (c) 2020-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -128,6 +128,35 @@
     }
 }
 
+template <typename ScalarType>
+void add_same_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
+{
+    const ScalarType *src0_ptr = reinterpret_cast<const ScalarType *>(src0->buffer());
+    const ScalarType *src1_ptr = reinterpret_cast<const ScalarType *>(src1->buffer());
+    ScalarType       *dst_ptr  = reinterpret_cast<ScalarType *>(dst->buffer());
+
+    constexpr int window_step_x  = 16 / sizeof(ScalarType);
+    const auto    window_start_x = static_cast<int>(window.x().start());
+    const auto    window_end_x   = static_cast<int>(window.x().end());
+
+    int x = window_start_x;
+    for(; x <= (window_end_x - window_step_x); x += window_step_x)
+    {
+        const auto val1 = wrapper::vloadq(src0_ptr + x);
+        const auto val2 = wrapper::vloadq(src1_ptr + x);
+        const auto res  = (policy == ConvertPolicy::SATURATE) ? wrapper::vqadd(val1, val2) : wrapper::vadd(val1, val2);
+        wrapper::vstore(dst_ptr + x, res);
+    }
+
+    // Compute left-over elements
+    for(; x < window_end_x; ++x)
+    {
+        const auto val1 = *(src0_ptr + x);
+        const auto val2 = *(src1_ptr + x);
+        *(dst_ptr + x)  = (policy == ConvertPolicy::SATURATE) ? wrapper::add_sat(val1, val2) : val1 + val2;
+    }
+}
+
 template void add_same_neon<float>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
 template void add_same_neon<uint8_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
 template void add_same_neon<int16_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
@@ -137,5 +166,14 @@
 template void add_same_neon<float16_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
 #endif /* (__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
 
+template void add_same_neon_as_1d_array<float>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+template void add_same_neon_as_1d_array<uint8_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+template void add_same_neon_as_1d_array<int16_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+template void add_same_neon_as_1d_array<int32_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
+template void add_same_neon_as_1d_array<float16_t>(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+#endif /* (__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
+
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/kernels/add/generic/neon/impl.h b/src/cpu/kernels/add/generic/neon/impl.h
index 07afdda225..f8f0f517b0 100644
--- a/src/cpu/kernels/add/generic/neon/impl.h
+++ b/src/cpu/kernels/add/generic/neon/impl.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -32,6 +32,9 @@ namespace cpu
 {
 template <typename ScalarType>
 void add_same_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
+
+template <typename ScalarType>
+void add_same_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window);
 } // namespace cpu
 } // namespace arm_compute
 #endif // SRC_CORE_NEON_KERNELS_ADD_IMPL_H
\ No newline at end of file
diff --git a/src/cpu/kernels/add/generic/neon/integer.cpp b/src/cpu/kernels/add/generic/neon/integer.cpp
index 62c19e66b1..ffead03474 100644
--- a/src/cpu/kernels/add/generic/neon/integer.cpp
+++ b/src/cpu/kernels/add/generic/neon/integer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
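The flat loop above only makes sense when both inputs and the output can genuinely be read as one contiguous run: same shape, no broadcasting, no padded rows. Those preconditions are decided elsewhere (by `CpuAddKernel::get_can_interpret_inputs_as_1d_array()`, used in `CpuAdd::run` later in this patch); the check sketched below is an assumption for illustration. A scalar model of the path, for integral element types:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Scalar model of add_same_neon_as_1d_array for an integral element type T.
// Assumed precondition: a, b and dst are dense and have the same element count.
template <typename T>
void add_as_1d_array(const std::vector<T> &a, const std::vector<T> &b, std::vector<T> &dst, bool saturate)
{
    for(std::size_t i = 0; i < dst.size(); ++i)
    {
        if(saturate) // models wrapper::vqadd / wrapper::add_sat
        {
            const int64_t wide = static_cast<int64_t>(a[i]) + static_cast<int64_t>(b[i]);
            const int64_t lo   = std::numeric_limits<T>::min();
            const int64_t hi   = std::numeric_limits<T>::max();
            dst[i] = static_cast<T>(std::min(hi, std::max(lo, wide)));
        }
        else // models wrapper::vadd: plain wrapping addition
        {
            dst[i] = static_cast<T>(a[i] + b[i]);
        }
    }
}

Compared with the windowed `add_same_neon`, this drops the per-row iterator bookkeeping entirely, which is the benefit for small, dense tensors.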
* * SPDX-License-Identifier: MIT * @@ -42,5 +42,20 @@ void add_s32_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const { return add_same_neon(src0, src1, dst, policy, window); } + +void add_u8_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) +{ + return add_same_neon_as_1d_array(src0, src1, dst, policy, window); +} + +void add_s16_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) +{ + return add_same_neon_as_1d_array(src0, src1, dst, policy, window); +} + +void add_s32_neon_as_1d_array(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) +{ + return add_same_neon_as_1d_array(src0, src1, dst, policy, window); +} } } // namespace arm_compute diff --git a/src/cpu/kernels/add/list.h b/src/cpu/kernels/add/list.h index 9d7c9a67ff..0285b231e0 100644 --- a/src/cpu/kernels/add/list.h +++ b/src/cpu/kernels/add/list.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -38,10 +38,15 @@ DECLARE_ADD_KERNEL(add_qasymm8_neon); DECLARE_ADD_KERNEL(add_qasymm8_signed_neon); DECLARE_ADD_KERNEL(add_qsymm16_neon); DECLARE_ADD_KERNEL(add_fp32_neon); +DECLARE_ADD_KERNEL(add_fp32_neon_as_1d_array); DECLARE_ADD_KERNEL(add_fp16_neon); +DECLARE_ADD_KERNEL(add_fp16_neon_as_1d_array); DECLARE_ADD_KERNEL(add_u8_neon); +DECLARE_ADD_KERNEL(add_u8_neon_as_1d_array); DECLARE_ADD_KERNEL(add_s16_neon); +DECLARE_ADD_KERNEL(add_s16_neon_as_1d_array); DECLARE_ADD_KERNEL(add_s32_neon); +DECLARE_ADD_KERNEL(add_s32_neon_as_1d_array); DECLARE_ADD_KERNEL(add_fp32_sve); DECLARE_ADD_KERNEL(add_fp16_sve); DECLARE_ADD_KERNEL(add_u8_sve); diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 200e04f9a8..4c127b4ec3 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -47,6 +47,46 @@ enum class GemmMethod GEMM_HYBRID_QUANTIZED }; +enum class WeightFormat +{ + UNSPECIFIED = 0x1, + ANY = 0x2, + OHWI = 0x100100, + OHWIo2 = 0x100200, + OHWIo4 = 0x100400, + OHWIo8 = 0x100800, + OHWIo16 = 0x101000, + OHWIo32 = 0x102000, + OHWIo64 = 0x104000, + OHWIo128 = 0x108000, + OHWIo4i2 = 0x200400, + OHWIo4i2_bf16 = 0x200410, + OHWIo8i2 = 0x200800, + OHWIo8i2_bf16 = 0x200810, + OHWIo16i2 = 0x201000, + OHWIo16i2_bf16 = 0x201010, + OHWIo32i2 = 0x202000, + OHWIo32i2_bf16 = 0x202010, + OHWIo64i2 = 0x204000, + OHWIo64i2_bf16 = 0x204010, + OHWIo4i4 = 0x400400, + OHWIo4i4_bf16 = 0x400410, + OHWIo8i4 = 0x400800, + OHWIo8i4_bf16 = 0x400810, + OHWIo16i4 = 0x401000, + OHWIo16i4_bf16 = 0x401010, + OHWIo32i4 = 0x402000, + OHWIo32i4_bf16 = 0x402010, + OHWIo64i4 = 0x404000, + OHWIo64i4_bf16 = 0x404010, + OHWIo2i8 = 0x800200, + OHWIo4i8 = 0x800400, + OHWIo8i8 = 0x800800, + OHWIo16i8 = 0x801000, + OHWIo32i8 = 0x802000, + OHWIo64i8 = 0x804000 +}; + struct KernelDescription { GemmMethod method = GemmMethod::DEFAULT; @@ -69,6 +109,7 @@ struct GemmConfig std::string filter = ""; unsigned int inner_block_size = 0; unsigned int outer_block_size = 0; + WeightFormat weight_format = WeightFormat::ANY; GemmConfig(GemmMethod method) : method(method) @@ -102,24 +143,25 @@ struct GemmArgs { public: const CPUInfo *_ci; - unsigned int _Msize; - unsigned int _Nsize; - unsigned int _Ksize; + unsigned int _Msize; // num of tiles + unsigned int _Nsize; // output channels + unsigned int _Ksize; // input 
channels unsigned int _Ksections; unsigned int _nbatches; - unsigned int _nmulti; + unsigned int _nmulti; // n_gemms to be performed bool _indirect_input; Activation _act; int _maxthreads; + bool _fixed_format; bool _fast_mode; const GemmConfig *_cfg; GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, unsigned int K, unsigned int Ksections, unsigned int nbatches, unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, - bool fast_mode = false, const GemmConfig *cfg = nullptr) - : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode), - _cfg(cfg) + bool fixed_format = false, bool fast_mode = false, const GemmConfig *cfg = nullptr) + : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), + _fixed_format(fixed_format), _fast_mode(fast_mode), _cfg(cfg) { } }; @@ -188,6 +230,6 @@ template std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); template -bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {}); +bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm diff --git a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp index aac4ef4ca0..eed537039f 100644 --- a/src/cpu/kernels/cast/generic/neon/bfloat16.cpp +++ b/src/cpu/kernels/cast/generic/neon/bfloat16.cpp @@ -21,7 +21,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) #include "arm_compute/core/TensorInfo.h" #include "src/core/NEON/wrapper/wrapper.h" @@ -142,4 +142,4 @@ void neon_bfloat16_to_fp32_cast(const ITensor *_src, ITensor *_dst, const Thread } // namespace cpu } // namespace arm_compute -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp index 8adacbfe67..85224351df 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve/fp16.cpp @@ -33,7 +33,7 @@ namespace cpu template void sve_fp16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_op(in1, in2, out, window); + return elementwise_arithmetic_op(in1, in2, out, op, window); } template void sve_fp16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -48,7 +48,7 @@ template void sve_fp16_elementwise_binary(const ITen template void sve_fp16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_op(in1, in2, out, window); + return elementwise_comparison_op(in1, in2, out, op, window); } template void sve_fp16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp index 0f80813d15..2b479f76f1 100644 --- 
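The `WeightFormat` values introduced in arm_gemm.hpp above encode the blocking of the fixed OHWIo&lt;N&gt;i&lt;K&gt; weight layouts directly in the hex constants: reading the listed values, the interleave factor N sits in bits 8-19 (OHWIo4 = 0x100400 yields 4, OHWIo128 = 0x108000 yields 128), the block factor K in the bits above (OHWIo4i2 = 0x200400 yields 2), and 0x10 flags the bf16 variants. A decoding sketch derived from those values; the helper names here are illustrative:

#include <cstdio>

enum class WeightFormat : int
{
    OHWIo4        = 0x100400,
    OHWIo8i2_bf16 = 0x200810,
};

int  interleave_by(WeightFormat wf) { return (static_cast<int>(wf) >> 8) & 0xFFF; }
int  block_by(WeightFormat wf)      { return (static_cast<int>(wf) >> 20) & 0xF; }
bool is_bf16(WeightFormat wf)       { return (static_cast<int>(wf) & 0x10) != 0; }

int main()
{
    std::printf("%d %d %d\n",
                interleave_by(WeightFormat::OHWIo8i2_bf16), // 8
                block_by(WeightFormat::OHWIo8i2_bf16),      // 2
                is_bf16(WeightFormat::OHWIo8i2_bf16));      // 1
    return 0;
}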
a/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve/fp32.cpp @@ -31,7 +31,7 @@ namespace cpu template void sve_fp32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_op(in1, in2, out, window); + return elementwise_arithmetic_op(in1, in2, out, op, window); } template void sve_fp32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -46,7 +46,7 @@ template void sve_fp32_elementwise_binary(const ITen template void sve_fp32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_op(in1, in2, out, window); + return elementwise_comparison_op(in1, in2, out, op, window); } template void sve_fp32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void sve_fp32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp index 2a8b155d14..c0515f2abc 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve/impl.cpp @@ -32,81 +32,117 @@ namespace cpu { using namespace arm_compute::wrapper; -template -struct LoopArguments +template +void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window) { - OperatorType op; - const InputScalarType *input1_ptr; - const InputScalarType *input2_ptr; - OutputScalarType *output_ptr; -}; - -template -struct BroadcastLoopArguments -{ - OperatorType op; - const InputScalarType *input1_ptr; - InputScalarType broadcast_value; - OutputScalarType *output_ptr; - bool reorder; -}; + using VectorType = typename sve_vector::type; -template -void arithmetic_op_loop(svbool_t pg, const LoopArguments &args) -{ - const auto in1 = svld1(pg, args.input1_ptr); - const auto in2 = svld1(pg, args.input2_ptr); - const auto res = elementwise_arithmetic_op::type>(pg, in1, in2, args.op); - svst1(pg, args.output_ptr, res); -} + const auto all_true_pg = svptrue(); -template -void arithmetic_op_broadcast_loop(svbool_t pg, const BroadcastLoopArguments &args) -{ - const auto non_broadcast_vector = svld1(pg, args.input1_ptr); - const auto broadcast_vector = svdup_n(args.broadcast_value); - const auto in1 = args.reorder ? broadcast_vector : non_broadcast_vector; - const auto in2 = args.reorder ? 
non_broadcast_vector : broadcast_vector; - const auto res = elementwise_arithmetic_op::type>(pg, in1, in2, args.op); - svst1(pg, args.output_ptr, res); -} + // Create input windows + Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); + Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()); -template -void comparison_op_loop(svbool_t pg, const LoopArguments &args) -{ - const auto in1 = svld1(pg, args.input1_ptr); - const auto in2 = svld1(pg, args.input2_ptr); - const auto res = elementwise_comparison_op::type, typename sve_vector::type>(pg, in1, in2, args.op); - const svbool_t output_pg = narrow_to_byte_predicate(pg); - svst1(output_pg, args.output_ptr, res); -} + // Clear X Dimension on execution window as we handle manually + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); -template -void comparison_op_broadcast_loop(svbool_t pg, const BroadcastLoopArguments &args) -{ - const auto non_broadcast_vector = svld1(pg, args.input1_ptr); - const auto broadcast_vector = svdup_n(args.broadcast_value); - const auto in1 = args.reorder ? broadcast_vector : non_broadcast_vector; - const auto in2 = args.reorder ? non_broadcast_vector : broadcast_vector; - const auto res = elementwise_comparison_op::type, typename sve_vector::type>(pg, in1, in2, args.op); - const svbool_t output_pg = narrow_to_byte_predicate(pg); - svst1(output_pg, args.output_ptr, res); -} + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x(); + + if(is_broadcast_across_x) + { + const bool is_broadcast_input_2 = input2_win.x().step() == 0; + Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win; + Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win; + const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1; + const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? 
in2 : in1; + + // Clear X Dimension on execution window as we handle manually + non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator broadcast_input(broadcast_tensor, broadcast_win); + Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); + Iterator output(out, win); + + execute_window_loop(win, [&](const Coordinates &) + { + auto output_ptr = reinterpret_cast(output.ptr()); + const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); + const ScalarType broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const auto broadcast_vector = svdup_n(broadcast_value); + + int x = window_start_x; -template -using LoopFuncType = void (*)(svbool_t, const LoopArguments &); + svbool_t pg = svwhilelt(x, window_end_x); + do + { + const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x); + VectorType res{}; -template -using BroadcastLoopFuncType = void (*)(svbool_t, const BroadcastLoopArguments &); + if(is_broadcast_input_2) + { + res = elementwise_arithmetic_op::type>(pg, non_broadcast_vector, broadcast_vector, op); + } + else + { + res = elementwise_arithmetic_op::type>(pg, broadcast_vector, non_broadcast_vector, op); + } + svst1(pg, output_ptr + x, res); -template ::type, - typename OutputScalarType = typename sve_scalar::type> -void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, - OperatorType op, - LoopFuncType func, - BroadcastLoopFuncType broadcast_func) + x += svcnt(); + pg = svwhilelt(x, window_end_x); + } + while(svptest_any(all_true_pg, pg)); + }, + broadcast_input, non_broadcast_input, output); + } + else + { + // Clear X Dimension on execution window as we handle manually + input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input1(in1, input1_win); + Iterator input2(in2, input2_win); + Iterator output(out, win); + + execute_window_loop(win, [&](const Coordinates &) + { + auto output_ptr = reinterpret_cast(output.ptr()); + const auto input1_ptr = reinterpret_cast(input1.ptr()); + const auto input2_ptr = reinterpret_cast(input2.ptr()); + + int x = window_start_x; + + svbool_t pg = svwhilelt(x, window_end_x); + do + { + const auto in1 = svld1(pg, input1_ptr + x); + const auto in2 = svld1(pg, input2_ptr + x); + const auto res = elementwise_arithmetic_op::type>(pg, in1, in2, op); + svst1(pg, output_ptr + x, res); + + x += svcnt(); + pg = svwhilelt(x, window_end_x); + } + while(svptest_any(all_true_pg, pg)); + }, + input1, input2, output); + } +} +template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window); +template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window); +template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window); +template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ArithmeticOperation op, const Window &window); + +template +void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window) { + static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width"); + + using OutputVectorType = typename sve_vector::type; const 
auto all_true_pg = svptrue(); // Create input windows @@ -141,20 +177,26 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const auto output_ptr = reinterpret_cast(output.ptr()); const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); const InputScalarType broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const auto broadcast_vector = svdup_n(broadcast_value); int x = window_start_x; svbool_t pg = svwhilelt(x, window_end_x); do { - broadcast_func(pg, + const auto non_broadcast_vector = svld1(pg, non_broadcast_input_ptr + x); + const svbool_t output_pg = narrow_to_byte_predicate(pg); + OutputVectorType res{}; + if(is_broadcast_input_2) { - op, - non_broadcast_input_ptr + x, - broadcast_value, - output_ptr + x, - !is_broadcast_input_2 - }); + res = elementwise_comparison_op::type, typename sve_vector::type>(pg, non_broadcast_vector, broadcast_vector, op); + } + else + { + res = elementwise_comparison_op::type, typename sve_vector::type>(pg, broadcast_vector, non_broadcast_vector, op); + } + svst1(output_pg, output_ptr + x, res); + x += svcnt(); pg = svwhilelt(x, window_end_x); } @@ -183,13 +225,12 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const svbool_t pg = svwhilelt(x, window_end_x); do { - func(pg, - { - op, - input1_ptr + x, - input2_ptr + x, - output_ptr + x - }); + const auto in1 = svld1(pg, input1_ptr + x); + const auto in2 = svld1(pg, input2_ptr + x); + const auto res = elementwise_comparison_op::type, typename sve_vector::type>(pg, in1, in2, op); + const svbool_t output_pg = narrow_to_byte_predicate(pg); + svst1(output_pg, output_ptr + x, res); + x += svcnt(); pg = svwhilelt(x, window_end_x); } @@ -199,97 +240,11 @@ void elementwise_op(const ITensor *in1, const ITensor *in2, ITensor *out, const } } -template -void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) -{ - using VectorType = typename sve_vector::type; - - elementwise_op(in1, in2, out, window, op, - &arithmetic_op_loop, - &arithmetic_op_broadcast_loop); -} -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void 
elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template -void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) -{ - static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width"); - using InputVectorType = typename sve_vector::type; - using OutputVectorType = typename sve_vector::type; - - elementwise_op(in1, in2, out, window, op, - &comparison_op_loop, - &comparison_op_broadcast_loop); -} - -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - 
-template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); - -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); -template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); +template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window); +template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window); +template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window); +template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, const Window &window); +template void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const ComparisonOperation op, 
const Window &window); template <> svint32_t elementwise_pow(svbool_t &pg, const svint32_t &a, const svint32_t &b) diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/impl.h b/src/cpu/kernels/elementwise_binary/generic/sve/impl.h index 606090d417..860c50a1e0 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve/impl.h +++ b/src/cpu/kernels/elementwise_binary/generic/sve/impl.h @@ -153,11 +153,11 @@ OutputVectorType elementwise_comparison_op(svbool_t &pg, const InputVectorType & return ret; } -template -void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); +template +void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window); -template -void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); +template +void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window); } // namespace cpu } // namespace arm_compute #endif /* SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H */ diff --git a/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp b/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp index 8f7e27184b..c313fc6e04 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve/integer.cpp @@ -31,7 +31,7 @@ namespace cpu template void sve_s32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_op(in1, in2, out, window); + return elementwise_arithmetic_op(in1, in2, out, op, window); } template void sve_s32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void sve_s32_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -45,7 +45,7 @@ template void sve_s32_elementwise_binary(const ITens template void sve_s16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_op(in1, in2, out, window); + return elementwise_arithmetic_op(in1, in2, out, op, window); } template void sve_s16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void sve_s16_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -59,7 +59,7 @@ template void sve_s16_elementwise_binary(const ITens template void sve_u8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_op(in1, in2, out, window); + return elementwise_comparison_op(in1, in2, out, op, window); } template void sve_u8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void sve_u8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -71,7 +71,7 @@ template void sve_u8_comparison_elementwise_binary void sve_s16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_op(in1, in2, out, window); + return elementwise_comparison_op(in1, in2, out, op, window); } template void sve_s16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void 
sve_s16_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -83,7 +83,7 @@ template void sve_s16_comparison_elementwise_binary void sve_s32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_op(in1, in2, out, window); + return elementwise_comparison_op(in1, in2, out, op, window); } template void sve_s32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); template void sve_s32_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h b/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h index f34d05eb37..41e0ac77db 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h +++ b/src/cpu/kernels/elementwise_binary/generic/sve2/impl.h @@ -31,37 +31,6 @@ namespace cpu { using namespace arm_compute::wrapper; -template -struct QuantizedLoopArguments -{ - OperatorType op; - const InputScalarType *input1_ptr; - const InputScalarType *input2_ptr; - OutputScalarType *output_ptr; - - const svint32_t &in1_offset; - const svint32_t &in2_offset; - const svint32_t &out_offset; - const svfloat32_t &in1_scale; - const svfloat32_t &in2_scale; - const svfloat32_t &out_scale; -}; - -template -struct BroadcastQuantizedLoopArguments -{ - OperatorType op; - const InputScalarType *input1_ptr; - float broadcast_value; - OutputScalarType *output_ptr; - bool reorder; - - const svint32_t &in1_offset; - const svint32_t &out_offset; - const svfloat32_t &in1_scale; - const svfloat32_t &out_scale; -}; - inline svfloat32x4_t load_quantized(const int8_t *ptr, svbool_t pg, const svint32_t &offset, const svfloat32_t &scale) { auto x = svld1(pg, ptr); @@ -131,98 +100,143 @@ inline void store_quantized(int8_t *ptr, svbool_t pg, svfloat32x4_t data, const svst1(pg, ptr, narrowed); } -template -inline void arithmetic_op_quantized_loop(svbool_t pg, const QuantizedLoopArguments &args) +template +void elementwise_arithmetic_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window) { - const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale); - const auto in2 = load_quantized(args.input2_ptr, pg, args.in2_offset, args.in2_scale); + const auto all_true_pg = wrapper::svptrue(); - const auto result = svcreate4( - elementwise_arithmetic_op(pg, svget4(in1, 0), svget4(in2, 0), args.op), - elementwise_arithmetic_op(pg, svget4(in1, 1), svget4(in2, 1), args.op), - elementwise_arithmetic_op(pg, svget4(in1, 2), svget4(in2, 2), args.op), - elementwise_arithmetic_op(pg, svget4(in1, 3), svget4(in2, 3), args.op)); + // Create input windows + Window input1_win = window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()); + Window input2_win = window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()); - store_quantized(args.output_ptr, pg, result, args.out_offset, args.out_scale); -} + // Clear X Dimension on execution window as we handle manually + Window win = window; + win.set(Window::DimX, Window::Dimension(0, 1, 1)); -template -inline void arithmetic_op_broadcast_quantized_loop(svbool_t pg, const BroadcastQuantizedLoopArguments &args) -{ - const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale); - const auto in2 = svcreate4( - svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), 
svdup_n(args.broadcast_value), svdup_n(args.broadcast_value)); + const auto window_start_x = static_cast(window.x().start()); + const auto window_end_x = static_cast(window.x().end()); + const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x(); - const auto &af = args.reorder ? in2 : in1; - const auto &bf = args.reorder ? in1 : in2; + const auto output_voffset = svdup_n(out->info()->quantization_info().uniform().offset); + const auto output_vscale = svdup_n(1.f / out->info()->quantization_info().uniform().scale); - const auto result = svcreate4( - elementwise_arithmetic_op(pg, svget4(af, 0), svget4(bf, 0), args.op), - elementwise_arithmetic_op(pg, svget4(af, 1), svget4(bf, 1), args.op), - elementwise_arithmetic_op(pg, svget4(af, 2), svget4(bf, 2), args.op), - elementwise_arithmetic_op(pg, svget4(af, 3), svget4(bf, 3), args.op)); + if(is_broadcast_across_x) + { + const bool is_broadcast_input_2 = input2_win.x().step() == 0; + Window broadcast_win = is_broadcast_input_2 ? input2_win : input1_win; + Window non_broadcast_win = !is_broadcast_input_2 ? input2_win : input1_win; + const ITensor *broadcast_tensor = is_broadcast_input_2 ? in2 : in1; + const ITensor *non_broadcast_tensor = !is_broadcast_input_2 ? in2 : in1; - store_quantized(args.output_ptr, pg, result, args.out_offset, args.out_scale); -} + const auto non_broadcast_qinfo = is_broadcast_input_2 ? in1->info()->quantization_info() : in2->info()->quantization_info(); + const auto broadcast_qinfo = is_broadcast_input_2 ? in2->info()->quantization_info() : in1->info()->quantization_info(); -template -inline void comparison_op_quantized_loop(svbool_t pg, const QuantizedLoopArguments &args) -{ - const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale); - const auto in2 = load_quantized(args.input2_ptr, pg, args.in2_offset, args.in2_scale); + const auto non_broadcast_voffset = svdup_n(non_broadcast_qinfo.uniform().offset); + const auto non_broadcast_vscale = svdup_n(non_broadcast_qinfo.uniform().scale); - using OutputVectorType = typename wrapper::traits::sve_vector::type; + // Clear X Dimension on execution window as we handle manually + non_broadcast_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - const auto result = svcreate4( - elementwise_comparison_op(pg, svget4(in1, 0), svget4(in2, 0), args.op), - elementwise_comparison_op(pg, svget4(in1, 1), svget4(in2, 1), args.op), - elementwise_comparison_op(pg, svget4(in1, 2), svget4(in2, 2), args.op), - elementwise_comparison_op(pg, svget4(in1, 3), svget4(in2, 3), args.op)); + Iterator broadcast_input(broadcast_tensor, broadcast_win); + Iterator non_broadcast_input(non_broadcast_tensor, non_broadcast_win); + Iterator output(out, win); - const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1)); - const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3)); - const auto zipped = svzip1(zipped_bottom, zipped_top); - svst1(pg, args.output_ptr, zipped); -} + execute_window_loop(win, [&](const Coordinates &) + { + auto output_ptr = reinterpret_cast(output.ptr()); + const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); + const ScalarType broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const float broadcast_value_f = Qasymm8QuantizationHelper::dequantize(broadcast_value, broadcast_qinfo); + const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f)); -template -inline void 
comparison_op_broadcast_quantized_loop(svbool_t pg, const BroadcastQuantizedLoopArguments &args) -{ - const auto in1 = load_quantized(args.input1_ptr, pg, args.in1_offset, args.in1_scale); - const auto in2 = svcreate4( - svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value), svdup_n(args.broadcast_value)); + int x = window_start_x; + + svbool_t pg = wrapper::svwhilelt(x, window_end_x); + do + { + const auto in1 = load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale); - const auto &af = args.reorder ? in2 : in1; - const auto &bf = args.reorder ? in1 : in2; + svfloat32x4_t result{}; - using OutputVectorType = typename wrapper::traits::sve_vector::type; + if(!is_broadcast_input_2) + { + result = svcreate4( + elementwise_arithmetic_op(pg, svget4(in2, 0), svget4(in1, 0), op), + elementwise_arithmetic_op(pg, svget4(in2, 1), svget4(in1, 1), op), + elementwise_arithmetic_op(pg, svget4(in2, 2), svget4(in1, 2), op), + elementwise_arithmetic_op(pg, svget4(in2, 3), svget4(in1, 3), op)); + } + else + { + result = svcreate4( + elementwise_arithmetic_op(pg, svget4(in1, 0), svget4(in2, 0), op), + elementwise_arithmetic_op(pg, svget4(in1, 1), svget4(in2, 1), op), + elementwise_arithmetic_op(pg, svget4(in1, 2), svget4(in2, 2), op), + elementwise_arithmetic_op(pg, svget4(in1, 3), svget4(in2, 3), op)); + } - const auto result = svcreate4( - elementwise_comparison_op(pg, svget4(af, 0), svget4(bf, 0), args.op), - elementwise_comparison_op(pg, svget4(af, 1), svget4(bf, 1), args.op), - elementwise_comparison_op(pg, svget4(af, 2), svget4(bf, 2), args.op), - elementwise_comparison_op(pg, svget4(af, 3), svget4(bf, 3), args.op)); + store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale); - const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1)); - const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3)); - const auto zipped = svzip1(zipped_bottom, zipped_top); - svst1(pg, args.output_ptr, zipped); -} + x += wrapper::svcnt(); + pg = wrapper::svwhilelt(x, window_end_x); + } + while(svptest_any(all_true_pg, pg)); + }, + broadcast_input, non_broadcast_input, output); + } + else + { + // Clear X Dimension on execution window as we handle manually + input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); + + Iterator input1(in1, input1_win); + Iterator input2(in2, input2_win); + Iterator output(out, win); + + const auto in1_voffset = svdup_n(in1->info()->quantization_info().uniform().offset); + const auto in1_vscale = svdup_n(in1->info()->quantization_info().uniform().scale); -template -using LoopQuantizedFuncType = void (*)(svbool_t, const QuantizedLoopArguments &); + const auto in2_voffset = svdup_n(in2->info()->quantization_info().uniform().offset); + const auto in2_vscale = svdup_n(in2->info()->quantization_info().uniform().scale); -template -using BroadcastQuantizedLoopFuncType = void (*)(svbool_t, const BroadcastQuantizedLoopArguments &); + execute_window_loop(win, [&](const Coordinates &) + { + auto output_ptr = reinterpret_cast(output.ptr()); + const auto input1_ptr = reinterpret_cast(input1.ptr()); + const auto input2_ptr = reinterpret_cast(input2.ptr()); -template ::type, - typename OutputScalarType = typename wrapper::sve_scalar::type> -void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window, - OperatorType op, - LoopQuantizedFuncType func, - BroadcastQuantizedLoopFuncType 
broadcast_func) + int x = window_start_x; + + svbool_t pg = wrapper::svwhilelt(x, window_end_x); + do + { + const auto in1 = load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale); + const auto in2 = load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale); + + const auto result = svcreate4( + elementwise_arithmetic_op(pg, svget4(in1, 0), svget4(in2, 0), op), + elementwise_arithmetic_op(pg, svget4(in1, 1), svget4(in2, 1), op), + elementwise_arithmetic_op(pg, svget4(in1, 2), svget4(in2, 2), op), + elementwise_arithmetic_op(pg, svget4(in1, 3), svget4(in2, 3), op)); + + store_quantized(output_ptr + x, pg, result, output_voffset, output_vscale); + + x += wrapper::svcnt(); + pg = wrapper::svwhilelt(x, window_end_x); + } + while(svptest_any(all_true_pg, pg)); + }, + input1, input2, output); + } +} + +template +void elementwise_comparison_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window) { + static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width"); + + using OutputVectorType = typename wrapper::traits::sve_vector::type; const auto all_true_pg = wrapper::svptrue(); // Create input windows @@ -237,9 +251,6 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o const auto window_end_x = static_cast(window.x().end()); const bool is_broadcast_across_x = in1->info()->tensor_shape().x() != in2->info()->tensor_shape().x(); - const auto output_voffset = svdup_n(out->info()->quantization_info().uniform().offset); - const auto output_vscale = svdup_n(1.f / out->info()->quantization_info().uniform().scale); - if(is_broadcast_across_x) { const bool is_broadcast_input_2 = input2_win.x().step() == 0; @@ -266,23 +277,40 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o auto output_ptr = reinterpret_cast(output.ptr()); const auto non_broadcast_input_ptr = reinterpret_cast(non_broadcast_input.ptr()); const InputScalarType broadcast_value = *reinterpret_cast(broadcast_input.ptr()); + const float broadcast_value_f = Qasymm8QuantizationHelper::dequantize(broadcast_value, broadcast_qinfo); + const auto in2 = svcreate4(svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f), svdup_n(broadcast_value_f)); int x = window_start_x; svbool_t pg = wrapper::svwhilelt(x, window_end_x); do { - const auto args = BroadcastQuantizedLoopArguments + const auto in1 = load_quantized(non_broadcast_input_ptr + x, pg, non_broadcast_voffset, non_broadcast_vscale); + + svuint8x4_t result{}; + + if(!is_broadcast_input_2) { - op, - non_broadcast_input_ptr + x, - Qasymm8QuantizationHelper::dequantize(broadcast_value, broadcast_qinfo), - output_ptr + x, - !is_broadcast_input_2, - non_broadcast_voffset, output_voffset, - non_broadcast_vscale, output_vscale - }; - broadcast_func(pg, args); + result = svcreate4( + elementwise_comparison_op(pg, svget4(in2, 0), svget4(in1, 0), op), + elementwise_comparison_op(pg, svget4(in2, 1), svget4(in1, 1), op), + elementwise_comparison_op(pg, svget4(in2, 2), svget4(in1, 2), op), + elementwise_comparison_op(pg, svget4(in2, 3), svget4(in1, 3), op)); + } + else + { + result = svcreate4( + elementwise_comparison_op(pg, svget4(in1, 0), svget4(in2, 0), op), + elementwise_comparison_op(pg, svget4(in1, 1), svget4(in2, 1), op), + elementwise_comparison_op(pg, svget4(in1, 2), svget4(in2, 2), op), + elementwise_comparison_op(pg, svget4(in1, 3), svget4(in2, 3), op)); + 
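Both rewritten broadcast loops above replace the old `reorder` flag with an explicit branch on `is_broadcast_input_2`, because operand order matters: SUB, DIV, POWER and every comparison are non-commutative, so when input 1 is the broadcast tensor the broadcast value must become the left operand. A scalar model of that dispatch:

#include <cassert>

enum class ArithmeticOperation { ADD, SUB };

float apply(ArithmeticOperation op, float lhs, float rhs)
{
    return op == ArithmeticOperation::SUB ? lhs - rhs : lhs + rhs;
}

// Mirrors the rewritten loops: 'scalar' plays the role of the broadcast vector.
float broadcast_apply(ArithmeticOperation op, float non_broadcast, float scalar, bool is_broadcast_input_2)
{
    // Broadcast on input 2 -> the scalar is the right operand, and vice versa.
    return is_broadcast_input_2 ? apply(op, non_broadcast, scalar)
                                : apply(op, scalar, non_broadcast);
}

int main()
{
    assert(broadcast_apply(ArithmeticOperation::SUB, 5.f, 2.f, true) == 3.f);   // 5 - 2
    assert(broadcast_apply(ArithmeticOperation::SUB, 5.f, 2.f, false) == -3.f); // 2 - 5
    return 0;
}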
} + + const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1)); + const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3)); + const auto zipped = svzip1(zipped_bottom, zipped_top); + svst1(pg, output_ptr + x, zipped); + x += wrapper::svcnt(); pg = wrapper::svwhilelt(x, window_end_x); } @@ -317,16 +345,19 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o svbool_t pg = wrapper::svwhilelt(x, window_end_x); do { - const auto args = QuantizedLoopArguments - { - op, - input1_ptr + x, - input2_ptr + x, - output_ptr + x, - in1_voffset, in2_voffset, output_voffset, - in1_vscale, in2_vscale, output_vscale - }; - func(pg, args); + const auto in1 = load_quantized(input1_ptr + x, pg, in1_voffset, in1_vscale); + const auto in2 = load_quantized(input2_ptr + x, pg, in2_voffset, in2_vscale); + const auto result = svcreate4( + elementwise_comparison_op(pg, svget4(in1, 0), svget4(in2, 0), op), + elementwise_comparison_op(pg, svget4(in1, 1), svget4(in2, 1), op), + elementwise_comparison_op(pg, svget4(in1, 2), svget4(in2, 2), op), + elementwise_comparison_op(pg, svget4(in1, 3), svget4(in2, 3), op)); + + const auto zipped_bottom = svzip1(svget4(result, 0), svget4(result, 1)); + const auto zipped_top = svzip1(svget4(result, 2), svget4(result, 3)); + const auto zipped = svzip1(zipped_bottom, zipped_top); + svst1(pg, output_ptr + x, zipped); + x += wrapper::svcnt(); pg = wrapper::svwhilelt(x, window_end_x); } @@ -335,26 +366,6 @@ void elementwise_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *o input1, input2, output); } } - -template -void elementwise_arithmetic_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) -{ - using VectorType = typename wrapper::traits::sve_vector::type; - elementwise_quantized_op(in1, in2, out, window, op, - &arithmetic_op_quantized_loop, - &arithmetic_op_broadcast_quantized_loop); -} - -template -void elementwise_comparison_quantized_op(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) -{ - static_assert(sizeof(InputScalarType) >= sizeof(OutputScalarType), "input data type's width should be equal to or greater than output data type's width"); - using InputVectorType = typename wrapper::traits::sve_vector::type; - using OutputVectorType = typename wrapper::traits::sve_vector::type; - elementwise_quantized_op(in1, in2, out, window, op, - &comparison_op_quantized_loop, - &comparison_op_broadcast_quantized_loop); -} } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp index 479f053685..7435bb4f29 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8.cpp @@ -31,7 +31,7 @@ namespace cpu template void sve2_qasymm8_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_quantized_op(in1, in2, out, window); + return elementwise_arithmetic_quantized_op(in1, in2, out, op, window); } template void sve2_qasymm8_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -46,7 +46,7 @@ template void sve2_qasymm8_elementwise_binary(const template void sve2_qasymm8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_quantized_op(in1, in2, out, window); + return 
elementwise_comparison_quantized_op(in1, in2, out, op, window); } template void sve2_qasymm8_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); diff --git a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp index ce6250be35..1027a1eed0 100644 --- a/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp +++ b/src/cpu/kernels/elementwise_binary/generic/sve2/qasymm8_signed.cpp @@ -31,7 +31,7 @@ namespace cpu template void sve2_qasymm8_signed_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_arithmetic_quantized_op(in1, in2, out, window); + return elementwise_arithmetic_quantized_op(in1, in2, out, op, window); } template void sve2_qasymm8_signed_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); @@ -46,7 +46,7 @@ template void sve2_qasymm8_signed_elementwise_binary template void sve2_qasymm8_signed_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window) { - return elementwise_comparison_quantized_op(in1, in2, out, window); + return elementwise_comparison_quantized_op(in1, in2, out, op, window); } template void sve2_qasymm8_signed_comparison_elementwise_binary(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window); diff --git a/src/cpu/kernels/softmax/generic/sve/impl.cpp b/src/cpu/kernels/softmax/generic/sve/impl.cpp index f1442224e8..2340a31cbd 100644 --- a/src/cpu/kernels/softmax/generic/sve/impl.cpp +++ b/src/cpu/kernels/softmax/generic/sve/impl.cpp @@ -94,8 +94,9 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co /* Compute exponentials and sum */ { /* Get max value */ - const auto max_val = *reinterpret_cast(max_it.ptr()); - const auto vec_max = wrapper::svdup_n(max_val); + const auto max_val = *reinterpret_cast(max_it.ptr()); + const auto vec_max = wrapper::svdup_n(max_val); + const auto vec_beta = wrapper::svdup_n(static_cast(beta)); /* Init sum to zero */ auto vec_sum = wrapper::svdup_n(static_cast(0)); @@ -106,19 +107,19 @@ void sve_softmax_logits_1d_float(const ITensor *in, const ITensor *max, void *co do { auto vec_elements = svld1(pg, in_ptr + x); - vec_elements = svsub_z(pg, vec_elements, vec_max); - if(is_log) - { - vec_elements = svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast(beta))); - vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); - } - else + vec_elements = svmul_z(pg, svsub_z(pg, vec_elements, vec_max), vec_beta); + if(!is_log) { - vec_elements = wrapper::svexp_z(pg, svmul_z(pg, vec_elements, wrapper::svdup_n(static_cast(beta)))); + vec_elements = wrapper::svexp_z(pg, vec_elements); vec_sum = svadd_m(pg, vec_sum, vec_elements); } svst1(pg, tmp_ptr + x, vec_elements); + if(is_log) + { + vec_sum = svadd_m(pg, vec_sum, wrapper::svexp_z(pg, vec_elements)); + } + x += wrapper::svcnt(); pg = wrapper::svwhilelt(x, input_width); } diff --git a/src/cpu/kernels/softmax/generic/sve2/impl.cpp b/src/cpu/kernels/softmax/generic/sve2/impl.cpp index 9cdfe61446..8f677c62d4 100644 --- a/src/cpu/kernels/softmax/generic/sve2/impl.cpp +++ b/src/cpu/kernels/softmax/generic/sve2/impl.cpp @@ -80,13 +80,13 @@ void sve2_softmax_logits_1d_quantized(const ITensor *in, const ITensor *max, voi svbool_t pg_3 = svunpkhi(svunpkhi(pg)); do { - auto vec_elements = svld1(pg, in_ptr + x); - vec_elements = svsub_z(pg, 
vec_max, vec_elements);
+            const auto vec_elements     = svld1(pg, in_ptr + x);
+            const auto vec_elements_sub = svreinterpret_u8(svsub_z(pg, vec_max, vec_elements));
 
-            auto vec_elements_flt_0 = svcvt_f32_z(pg_0, svunpklo(svunpklo(vec_elements)));
-            auto vec_elements_flt_1 = svcvt_f32_z(pg_1, svunpkhi(svunpklo(vec_elements)));
-            auto vec_elements_flt_2 = svcvt_f32_z(pg_2, svunpklo(svunpkhi(vec_elements)));
-            auto vec_elements_flt_3 = svcvt_f32_z(pg_3, svunpkhi(svunpkhi(vec_elements)));
+            auto vec_elements_flt_0 = svcvt_f32_z(pg_0, svunpklo(svunpklo(vec_elements_sub)));
+            auto vec_elements_flt_1 = svcvt_f32_z(pg_1, svunpkhi(svunpklo(vec_elements_sub)));
+            auto vec_elements_flt_2 = svcvt_f32_z(pg_2, svunpklo(svunpkhi(vec_elements_sub)));
+            auto vec_elements_flt_3 = svcvt_f32_z(pg_3, svunpkhi(svunpkhi(vec_elements_sub)));
 
             if(is_log)
             {
@@ -180,10 +180,10 @@ void sve2_softmax_logits_1d_quantized(const ITensor *in, const ITensor *max, voi
             if(is_qasymm8_signed)
             {
                 const auto offset_vec = svdup_n_f32(128.f);
-                res_0 = svsub_z(pg_0, vec_in_0, offset_vec);
-                res_1 = svsub_z(pg_1, vec_in_1, offset_vec);
-                res_2 = svsub_z(pg_2, vec_in_2, offset_vec);
-                res_3 = svsub_z(pg_3, vec_in_3, offset_vec);
+                res_0 = svsub_z(pg_0, res_0, offset_vec);
+                res_1 = svsub_z(pg_1, res_1, offset_vec);
+                res_2 = svsub_z(pg_2, res_2, offset_vec);
+                res_3 = svsub_z(pg_3, res_3, offset_vec);
             }
         }
diff --git a/src/cpu/operators/CpuAdd.cpp b/src/cpu/operators/CpuAdd.cpp
index 76ec7d7d8d..828361e7cf 100644
--- a/src/cpu/operators/CpuAdd.cpp
+++ b/src/cpu/operators/CpuAdd.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -27,6 +27,8 @@
 
 #include "src/common/utils/Log.h"
 
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
 namespace arm_compute
 {
 namespace cpu
@@ -45,5 +47,17 @@ Status CpuAdd::validate(const ITensorInfo *src0, const ITensorInfo *src1, const
     ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled());
     return kernels::CpuAddKernel::validate(src0, src1, dst, policy);
 }
+
+void CpuAdd::run(ITensorPack &tensors)
+{
+    if(static_cast<kernels::CpuAddKernel *>(_kernel.get())->get_can_interpret_inputs_as_1d_array())
+    {
+        NEScheduler::get().schedule_op(_kernel.get(), Window::DimX, _kernel->window(), tensors);
+    }
+    else
+    {
+        ICpuOperator::run(tensors);
+    }
+}
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/operators/CpuAdd.h b/src/cpu/operators/CpuAdd.h
index d8ec620aeb..4ad6d7fe65 100644
--- a/src/cpu/operators/CpuAdd.h
+++ b/src/cpu/operators/CpuAdd.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -62,6 +62,9 @@ class CpuAdd : public ICpuOperator
      * @return a status
      */
     static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, ConvertPolicy policy, const ActivationLayerInfo &act_info = ActivationLayerInfo());
+
+    // Inherited methods overridden:
+    void run(ITensorPack &tensors) override;
 };
 } // namespace cpu
 } // namespace arm_compute
diff --git a/src/cpu/operators/CpuFullyConnected.cpp b/src/cpu/operators/CpuFullyConnected.cpp
index 6d77c614f7..3172644488 100644
--- a/src/cpu/operators/CpuFullyConnected.cpp
+++ b/src/cpu/operators/CpuFullyConnected.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
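Stepping back to the softmax hunks above: the float kernel now stores `(x - max) * beta` once, hoists `vec_beta` out of the loop, and only differs between the plain and log variants in where the exponential is taken, while the sve2 quantized kernel fixes the signed-offset adjustment to operate on the computed results (`res_0..res_3`) rather than the raw inputs. In scalar form, the control flow of the patched float pass is (a model, not the library code):

#include <cmath>
#include <cstddef>
#include <vector>

// First softmax pass after the patch: tmp[i] keeps (x[i] - max) * beta for
// log-softmax and exp((x[i] - max) * beta) otherwise; 'sum' always
// accumulates the exponentials. tmp must have x.size() elements.
float exp_and_sum_pass(const std::vector<float> &x, float max_val, float beta, bool is_log, std::vector<float> &tmp)
{
    float sum = 0.f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        float v = (x[i] - max_val) * beta;
        if(!is_log)
        {
            v = std::exp(v);
            sum += v;
        }
        tmp[i] = v;
        if(is_log)
        {
            sum += std::exp(v); // exponential needed only for the normalizer
        }
    }
    return sum;
}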
* * SPDX-License-Identifier: MIT * @@ -53,7 +53,7 @@ std::pair get_quantized_asymmetric_output_min_max(const { PixelValue type_min{}; PixelValue type_max{}; - std::tie(type_min, type_max) = get_min_max(data_type); + std::tie(type_min, type_max) = get_min_max(data_type); const UniformQuantizationInfo q_unif = q_info.uniform(); if(act_info.enabled()) @@ -162,8 +162,9 @@ CpuFullyConnected::CpuFullyConnected() _is_fc_after_conv(false), _is_quantized_asymmetric(false), _is_prepared(false), - _enable_fast_math(false) - + _enable_fast_math(false), + _fixed_format(false), + _weight_format(arm_compute::WeightFormat::UNSPECIFIED) { } @@ -199,6 +200,8 @@ void CpuFullyConnected::configure_mm(const ITensorInfo *src, const ITensorInfo * GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); gemm_info.set_activation_info(act); gemm_info.set_fast_math(_enable_fast_math); + gemm_info.set_fixed_format(_fixed_format); + gemm_info.set_weight_format(_weight_format); _mm_gemm = std::make_unique(); _mm_gemm->configure(src, weights, biases, dst, 1.f, 1.0f, gemm_info); } @@ -229,7 +232,7 @@ void CpuFullyConnected::configure_fc_fc(const ITensorInfo *src, const ITensorInf } void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, - FullyConnectedLayerInfo fc_info) + FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info) { // Perform validate step ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); @@ -248,6 +251,8 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei _is_prepared = false; _trans_weights_idx = AuxTensorIdx::Count; _enable_fast_math = fc_info.enable_fast_math; + _fixed_format = weights_info.weight_format() != WeightFormat::UNSPECIFIED; + _weight_format = weights_info.weight_format(); // With the Fully Connected layer we can have 4 different cases: // 1) Convolution layer -> Fully Connected layer without batches @@ -261,9 +266,7 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei const bool is_batched_fc_layer = dst->dimension(1) > 1; if(is_batched_fc_layer) { - _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, - src->tensor_shape().cend(), - dst->tensor_shape().cbegin() + 1)); + _is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1)); } else { @@ -323,12 +326,10 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei { // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch // Do not release them if biases are dynamic and data type is quantized, since the weights tensor will be used for biases offset calculation - _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric - && biases && !(biases->are_values_constant())) ? - MemoryLifetime::Persistent : - MemoryLifetime::Prepare, + _aux_mem[TransposedWeights] = MemoryInfo(offset_int_vec(TransposedWeights), (_is_quantized_asymmetric && biases + && !(biases->are_values_constant())) ? 
MemoryLifetime::Persistent : MemoryLifetime::Prepare, _reshaped_weights.total_size()); - _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); + _aux_mem[ConvertedWeights] = MemoryInfo(offset_int_vec(ConvertedWeights), MemoryLifetime::Prepare, _converted_weights.total_size()); } else { @@ -338,6 +339,18 @@ void CpuFullyConnected::configure(const ITensorInfo *src, const ITensorInfo *wei _aux_mem[FlattenedSrc] = MemoryInfo(offset_int_vec(FlattenedSrc), MemoryLifetime::Temporary, _flattened_src.total_size()); } +Status CpuFullyConnected::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info, WeightsInfo weights_info) +{ + GEMMInfo gemm_info(false, false, true /* Reshape weights only for the first run */); + gemm_info.set_activation_info(fc_info.activation_info); + gemm_info.set_fast_math(fc_info.enable_fast_math); + gemm_info.set_fixed_format(weights_info.weight_format() != WeightFormat::UNSPECIFIED); + gemm_info.set_weight_format(weights_info.weight_format()); + + return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info); +} + Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info) { @@ -384,9 +397,7 @@ Status CpuFullyConnected::validate(const ITensorInfo *src, const ITensorInfo *we if(is_batched_fc_layer) { - is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, - src->tensor_shape().cend(), - dst->tensor_shape().cbegin() + 1)); + is_fc_after_conv = (TensorShape::num_max_dimensions >= 4) && (std::equal(src->tensor_shape().cbegin() + 3, src->tensor_shape().cend(), dst->tensor_shape().cbegin() + 1)); } else { diff --git a/src/cpu/operators/CpuFullyConnected.h b/src/cpu/operators/CpuFullyConnected.h index 44fa21f9f8..36511e9d32 100644 --- a/src/cpu/operators/CpuFullyConnected.h +++ b/src/cpu/operators/CpuFullyConnected.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -72,20 +72,21 @@ class CpuFullyConnected : public ICpuOperator * |QASYMM8 |QASYMM8 |S32 |QASYMM8 | * |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED | * - * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. - * @param[in] weights Weights tensor info. The weights must be 2 dimensional. - * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. - * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. - * Data type supported: Same as @p src. - * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. - * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between: - * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer - * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. - * Data type supported: Same as @p src. 
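The new CpuFullyConnected::has_opt_impl above mirrors configure_mm(): it builds a GEMMInfo from fc_info and weights_info and defers the decision to CpuGemm::has_opt_impl. A hypothetical caller-side sketch of the intended query flow follows; src, weights, biases, dst and fc_info are assumed to be valid descriptors in scope, and make_weights_info_requesting() is an assumed helper standing in for the elided WeightsInfo construction:

    // Probe for a fixed-format kernel before committing to a weight layout.
    arm_compute::WeightFormat expected_wf = arm_compute::WeightFormat::ANY;
    const WeightsInfo query_wi = make_weights_info_requesting(arm_compute::WeightFormat::ANY); // hypothetical helper
    const Status status = cpu::CpuFullyConnected::has_opt_impl(expected_wf, src, weights, biases, dst, fc_info, query_wi);
    if(bool(status) && expected_wf != arm_compute::WeightFormat::UNSPECIFIED)
    {
        // A fixed-format kernel exists; expected_wf names the blocked layout the
        // weights must be re-ordered into before calling configure() with it.
    }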
- * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] src Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[in] weights Weights tensor info. The weights must be 2 dimensional. + * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. + * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. + * Data type supported: Same as @p src. + * @param[in] biases Bias tensor info. Can be nullptr. Data type supported: Same as @p weights, S32 if @p weights is QASYMM8/QASYMM8_SIGNED. + * @param[out] dst Destination tensor info. Its shape should be equal to the output of a matrix multiplication between: + * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer + * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. + * Data type supported: Same as @p src. + * @param[in] fc_info (Optional) Fully connected layer additional info + * @param[in] weights_info (Optional) Stores necessary compute information when weights are already reshaped */ void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, - FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(), const WeightsInfo &weights_info = WeightsInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CpuFullyConnected * * Similar to @ref CpuFullyConnected @@ -95,9 +96,19 @@ class CpuFullyConnected : public ICpuOperator static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + /** Static function that queries whether a fixed-format kernel exists; if one does, the first argument is set to the format in which + * the weights are expected to be reshaped, as defined by the WeightFormat class. Apart from the first argument, the remaining arguments are the same + * as in @ref CpuFullyConnected::validate(), except that here all arguments are required.
+ * + * @return a status + */ + static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *dst, + FullyConnectedLayerInfo fc_info, WeightsInfo weights_info); + //Inherited methods override - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: @@ -136,12 +147,14 @@ class CpuFullyConnected : public ICpuOperator experimental::MemoryRequirements _aux_mem; - bool _needs_weights_conversion; - bool _needs_weights_reshape; - bool _is_fc_after_conv; - bool _is_quantized_asymmetric; - bool _is_prepared; - bool _enable_fast_math; + bool _needs_weights_conversion; + bool _needs_weights_reshape; + bool _is_fc_after_conv; + bool _is_quantized_asymmetric; + bool _is_prepared; + bool _enable_fast_math; + bool _fixed_format; + arm_compute::WeightFormat _weight_format; }; } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp index 9c7ad92761..f6582c73f8 100644 --- a/src/cpu/operators/CpuGemm.cpp +++ b/src/cpu/operators/CpuGemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -50,6 +50,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info) asm_info.depth_output_gemm3d = info.depth_output_gemm3d(); asm_info.activation_info = info.activation_info(); asm_info.fast_mode = info.fast_math(); + asm_info.fixed_format = info.fixed_format(); + asm_info.weight_format = info.weight_format(); return asm_info; } @@ -72,8 +74,7 @@ void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITenso _run_alpha_scale = alpha != 1.f; _run_bias_addition = c != nullptr && gemm_info.reshape_b_only_on_first_run(); _run_addition = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run(); - _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised - && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info()))); + _run_activation = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info()))); if(run_optimised) { @@ -177,7 +178,8 @@ Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITens if(d->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON(b->dimension(0) != d->dimension(0)); + // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more. 
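The comment closing the hunk above motivates the relaxed check that follows it: with a fixed-format kernel the B/RHS tensor is stored in a blocked layout, so its leading dimension is block-padded and need not equal the destination's width. For context, the flag plumbing these hunks establish is a straight pass-through from WeightsInfo into the assembly descriptor; a condensed sketch using only the setters shown in the diff (weights_info assumed in scope):

    // WeightsInfo -> GEMMInfo -> cpu::AsmGemmInfo, per the hunks above.
    GEMMInfo gemm_info{};
    gemm_info.set_fixed_format(weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED);
    gemm_info.set_weight_format(weights_info.weight_format());
    // init_assembly_metadata() then copies both fields verbatim:
    //   asm_info.fixed_format  = info.fixed_format();
    //   asm_info.weight_format = info.weight_format();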
+ ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0)); if(gemm_info.depth_output_gemm3d() != 0) { if(gemm_info.reinterpret_input_as_3d()) @@ -277,7 +279,7 @@ void CpuGemm::run(ITensorPack &tensors) auto c = tensors.get_const_tensor(ACL_SRC_2); auto d = tensors.get_tensor(ACL_DST); - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { // Pass c to asm dispatch only if it's the bias tensor ITensorPack asm_pack = tensors; @@ -343,7 +345,7 @@ void CpuGemm::prepare(ITensorPack &tensors) { if(!_is_prepared) { - if(_asm_glue->is_configured()) + if(_asm_glue && _asm_glue->is_configured()) { _asm_glue->prepare(tensors); } @@ -365,5 +367,18 @@ experimental::MemoryRequirements CpuGemm::workspace() const { return _aux_mem; } + +Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const GEMMInfo &gemm_info) +{ + const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info); + + return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info); +} + +bool CpuGemm::isVarWeightsKernel() const +{ + return _asm_glue && _asm_glue->isVarWeightsKernel(); +} } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemm.h b/src/cpu/operators/CpuGemm.h index 334ab6c647..8d34b22437 100644 --- a/src/cpu/operators/CpuGemm.h +++ b/src/cpu/operators/CpuGemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -101,11 +101,29 @@ class CpuGemm : public ICpuOperator static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info = GEMMInfo()); + /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. + * + * This method has the same use as @ref + * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that + * the value of arm_compute::WeightFormat needs to be passed via the + * parameter gemm_info. + */ + static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + const GEMMInfo &gemm_info = GEMMInfo()); + // Inherited methods overridden: void run(ITensorPack &tensors) override; void prepare(ITensorPack &constants) override; experimental::MemoryRequirements workspace() const override; + /** Indicates if the convolution executes in variable weights mode. + * + * When ACL executes convolution in variable weights mode, it does + * not perform any processing of the weights tensor. Instead, it + * utilizes the data as it is given by the user.
+ */ + bool isVarWeightsKernel() const; + private: enum AuxTensorIdx { diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp index c021d31059..f3a16f104f 100644 --- a/src/cpu/operators/CpuGemmConv2d.cpp +++ b/src/cpu/operators/CpuGemmConv2d.cpp @@ -62,13 +62,13 @@ CpuGemmConv2d::SkipInfo CpuGemmConv2d::skip_im_col_info(const ITensorInfo *src, const unsigned int kernel_height = weights->dimension(idx_height); unsigned int conv_w = 0; unsigned int conv_h = 0; - std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), - src->dimension(idx_height), - kernel_width, - kernel_height, - conv_info, - dilation); - const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); + std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), + src->dimension(idx_height), + kernel_width, + kernel_height, + conv_info, + dilation); + const bool skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1); if(skip_im2col) { @@ -99,15 +99,15 @@ CpuGemmConv2d::CpuGemmConv2d() CpuGemmConv2d::~CpuGemmConv2d() = default; void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *dst, const ActivationLayerInfo &act_info, - bool enable_fast_math, int gemm_3d_depth) + bool enable_fast_math, int gemm_3d_depth, bool fixed_format, arm_compute::WeightFormat weight_format) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights); - ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col)); + ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, act_info, enable_fast_math, gemm_3d_depth, _skip_im2col, fixed_format, weight_format)); // Create GEMMInfo structure const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList(), fixed_format, weight_format); // Supported activations in GEMM const std::set supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, @@ -139,8 +139,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig PixelValue type_min{}; PixelValue type_max{}; std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get(); - int32_t max_activation = type_max.get(); + int32_t min_activation = type_min.get(); + int32_t max_activation = type_max.get(); if(supported_acts.count(act_info.activation()) != 0) { @@ -156,7 +156,8 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig quantization::calculate_quantized_multipliers(iqinfo, wqinfo, oqinfo, output_info); _mm_gemmlowp = std::make_unique(); - _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info)); + _mm_gemmlowp->configure(&tmp_src, &tmp_weights, biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info, false, enable_fast_math, false, act_info, + experimental::PostOpList(), fixed_format, weight_format)); auto 
mm_mem_req = _mm_gemmlowp->workspace(); for(unsigned int cont = 0; cont < mm_mem_req.size(); ++cont) @@ -178,7 +179,7 @@ void CpuGemmConv2d::configure_mm(const ITensorInfo *src, const ITensorInfo *weig } Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col) + const ActivationLayerInfo &act_info, bool enable_fast_math, int gemm_3d_depth, bool skip_im2col, bool fixed_format, arm_compute::WeightFormat weight_format) { const DataType data_type = src->data_type(); const bool is_quantized = is_data_type_quantized_asymmetric(data_type); @@ -187,7 +188,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei // Create GEMMInfo structure const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, - false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info); + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList(), fixed_format, weight_format); if(is_quantized) { @@ -202,8 +203,8 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei PixelValue type_min{}; PixelValue type_max{}; std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get(); - int32_t max_activation = type_max.get(); + int32_t min_activation = type_min.get(); + int32_t max_activation = type_max.get(); const std::set supported_acts = { ActivationLayerInfo::ActivationFunction::RELU, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, @@ -227,6 +228,7 @@ Status CpuGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *wei std::unique_ptr weights_qa = weights->clone(); input_qa->set_quantization_info(QuantizationInfo(iqinfo.uniform().scale, -iqinfo.uniform().offset)); weights_qa->set_quantization_info(QuantizationInfo(wqinfo.uniform().scale, -wqinfo.uniform().offset)); + return CpuGemmLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, dst, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info, false, enable_fast_math, false, act_info)); } @@ -286,14 +288,15 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights ITensorInfo *gemm_output_to_use = dst; // Get convolved dimensions - unsigned int conv_w = 0; - unsigned int conv_h = 0; + unsigned int conv_w = 0; + unsigned int conv_h = 0; std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), src->dimension(idx_height), kernel_width, kernel_height, conv_info, dilation); + ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h), "Output shape does not match the expected one"); @@ -303,8 +306,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights _skip_col2im = skip_info.skip_col2im; // Get parameters from conv_info - unsigned int stride_x = 0; - unsigned int stride_y = 0; + unsigned int stride_x = 0; + unsigned int stride_y = 0; std::tie(stride_x, stride_y) = conv_info.stride(); unsigned int mat_weights_cols = weights->dimension(idx_kernels); @@ -357,7 +360,8 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights // Configure GEMM // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid 
reshaping the output matrix const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0; - configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth); + const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; + configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, enable_fast_math, gemm_3d_depth, fixed_format, weights_info.weight_format()); if(!_skip_col2im && _data_layout == DataLayout::NCHW) { @@ -384,6 +388,38 @@ void CpuGemmConv2d::configure(const ITensorInfo *src, const ITensorInfo *weights _aux_mem[GemmOutput] = MemoryInfo(offset_int_vec(GemmOutput), MemoryLifetime::Temporary, _gemm_output.total_size()); } +Status CpuGemmConv2d::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math) +{ + const DataLayout data_layout = src->data_layout(); + const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int kernel_width = weights->dimension(idx_width); + const unsigned int kernel_height = weights->dimension(idx_height); + unsigned int conv_w = 0; + unsigned int conv_h = 0; + std::tie(conv_w, conv_h) = scaled_dimensions(src->dimension(idx_width), + src->dimension(idx_height), + kernel_width, + kernel_height, + conv_info, + dilation); + + const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, + dilation, act_info); + + const bool skip_im2col = skip_info.skip_im2col; + const bool skip_col2im = skip_info.skip_col2im; + const unsigned int gemm_3d_depth = skip_col2im ? 
conv_h : 0; + const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; + const GEMMInfo gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */, + gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */, + false, GEMMLowpOutputStageInfo(), false, enable_fast_math, false, act_info, experimental::PostOpList(), fixed_format, weights_info.weight_format()); + + return CpuGemm::has_opt_impl(expected_weight_format, src, weights, biases, dst, gemm_info); +} + Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { @@ -428,9 +464,9 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight dilation); // Check if GEMM3D is supported - const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, - dilation, act_info); - const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im; + const CpuGemmConv2d::SkipInfo skip_info = CpuGemmConv2d::skip_im_col_info(src, weights, conv_info, + dilation, act_info); + const bool skip_im2col = skip_info.skip_im2col, skip_col2im = skip_info.skip_col2im; ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != src->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4); @@ -450,7 +486,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); } - ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels)); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != dst->dimension(idx_channel)); ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } @@ -472,7 +508,7 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type); im2col_reshaped_info.set_quantization_info(src->quantization_info()); - ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation)); + ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuIm2ColKernel::validate(src, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation, 1)); gemm_input_to_use = &im2col_reshaped_info; } @@ -490,8 +526,11 @@ Status CpuGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weight info_gemm = TensorInfo(dst->tensor_shape(), 1, output_data_type); } info_gemm.set_quantization_info(dst->quantization_info()).set_data_layout(src->data_layout()); - gemm_output_to_use = &info_gemm; - ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? conv_h : 0, skip_im2col)); + gemm_output_to_use = &info_gemm; + const bool fixed_format = weights_info.weight_format() != arm_compute::WeightFormat::UNSPECIFIED; + + ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, enable_fast_math, skip_col2im ? 
conv_h : 0, skip_im2col, fixed_format, + weights_info.weight_format())); // Validate Col2Im/ReshapeLayer if(!skip_col2im && (data_layout == DataLayout::NCHW)) @@ -519,7 +558,7 @@ void CpuGemmConv2d::run(ITensorPack &tensors) { // Run input reshaping unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); - ITensorPack pack = + ITensorPack pack = { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, im2col_output.get() } @@ -548,7 +587,10 @@ void CpuGemmConv2d::run(ITensorPack &tensors) // Runs CpuGemm or CpuGemmLowpMatrixMultiplyCore functions ITensorPack pack_mm = tensors; pack_mm.add_const_tensor(TensorType::ACL_SRC_0, gemm_input_to_use); - pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + if(!this->isVarWeightsKernel()) + { + pack_mm.add_const_tensor(TensorType::ACL_SRC_1, reshaped_wei.get()); + } pack_mm.add_tensor(TensorType::ACL_DST, gemm_output_to_use); if(_is_quantized) { @@ -598,22 +640,28 @@ void CpuGemmConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + // Variable weights executions that use fixed-format kernels + // need no reshaping of the weights. + if(this->isVarWeightsKernel()) + { + _is_quantized ? _mm_gemmlowp->prepare(tensors) : _mm_gemm->prepare(tensors); + _is_prepared = true; + return; + } + // Run weights reshaping and mark original weights tensor as unused CpuAuxTensorHandler weights_reshaped(offset_int_vec(WeightsReshaped), _weights_reshaped, tensors); auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); - ITensorPack pack = + ITensorPack pack = { { TensorType::ACL_SRC, weights }, { TensorType::ACL_DST, weights_reshaped.get() } }; NEScheduler::get().schedule_op(_weights_reshape_kernel.get(), 3, _weights_reshape_kernel->window(), pack); weights->mark_as_unused(); - - // Prepare GEMM ITensorPack gemm_pack = tensors; gemm_pack.add_const_tensor(TensorType::ACL_SRC_1, weights_reshaped.get()); _is_quantized ? _mm_gemmlowp->prepare(gemm_pack) : _mm_gemm->prepare(gemm_pack); - _is_prepared = true; } } @@ -621,5 +669,9 @@ experimental::MemoryRequirements CpuGemmConv2d::workspace() const { return _aux_mem; } +bool CpuGemmConv2d::isVarWeightsKernel() const +{ + return _mm_gemm && _mm_gemm->isVarWeightsKernel(); +} } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h index aec4a2ffa5..08b76a6c46 100644 --- a/src/cpu/operators/CpuGemmConv2d.h +++ b/src/cpu/operators/CpuGemmConv2d.h @@ -117,9 +117,20 @@ class CpuGemmConv2d : public ICpuOperator const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(), bool enable_fast_math = false, unsigned int num_groups = 1); + /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. + * + * The parameter list is the same as @ref NEGEMMConvolutionLayer::has_opt_impl + * + * @return a status.
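(The CpuGemmConv2d.h excerpt continues below.) Both validate() and the new has_opt_impl() recompute the convolved output extent with scaled_dimensions() before building the GEMM descriptor. As a reference restatement of that arithmetic for the FLOOR-rounding case (ACL also supports CEIL via DimensionRoundingType), not ACL's implementation:

    #include <cassert>

    // Output extent of one spatial dimension of a dilated convolution.
    unsigned int conv_out_dim_floor(unsigned int in, unsigned int kernel,
                                    unsigned int pad_before, unsigned int pad_after,
                                    unsigned int stride, unsigned int dilation)
    {
        const unsigned int dilated_kernel = (kernel - 1) * dilation + 1; // effective extent
        assert(stride > 0 && in + pad_before + pad_after >= dilated_kernel);
        return (in + pad_before + pad_after - dilated_kernel) / stride + 1;
    }
    // e.g. in = 56, kernel = 3, pads = 1/1, stride = 1, dilation = 1 -> 56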
+ */ + static Status has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info = WeightsInfo(), const Size2D &dilation = Size2D(1U, 1U), const ActivationLayerInfo &act_info = ActivationLayerInfo(), + const bool enable_fast_math = false); + // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: @@ -135,9 +146,11 @@ class CpuGemmConv2d : public ICpuOperator * @param[in] enable_fast_math (Optional) Enable fast math computation. In case this flag were set, the function could dispatch the fastest implementation * available which may introduce a drop of accuracy as well. Default is false * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) + * @param[in] fixed_format (Optional) Select GEMM execution with variable weights. + * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights. */ void configure_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(), - bool enable_fast_math = false, int gemm_3d_depth = 1); + bool enable_fast_math = false, int gemm_3d_depth = 1, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED); /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines * * @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32. @@ -151,11 +164,13 @@ class CpuGemmConv2d : public ICpuOperator * available which may introduce a drop of accuracy as well. Default is false * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1) * @param[in] skip_im2col (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false) + * @param[in] fixed_format (Optional) Select GEMM execution with variable weights. + * @param[in] weight_format (Optional) The layout to be used for the weights tensor when running GEMM with variable weights. * * @return a status */ static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const ActivationLayerInfo &act_info = ActivationLayerInfo(), - bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false); + bool enable_fast_math = false, int gemm_3d_depth = 1, bool skip_im2col = false, bool fixed_format = false, arm_compute::WeightFormat weight_format = arm_compute::WeightFormat::UNSPECIFIED); /** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref CpuGemmMLowpMatrixMultiplyCore * * @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/BFLOAT16/F16/F32. @@ -187,6 +202,11 @@ class CpuGemmConv2d : public ICpuOperator static SkipInfo skip_im_col_info(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info, const Size2D &dilation, const ActivationLayerInfo &act_info); + /** Indicates if the convolution executes in variable weights mode. 
+ * + * Similar to @ref CpuGemm::isVarWeightsKernel + */ + bool isVarWeightsKernel() const; enum AuxTensorIdx { // CpuGemmLowpMatrixMultiplyCore has up to 8 internal tensors diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp index 75c057e455..ee47a17d64 100644 --- a/src/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -57,11 +57,11 @@ GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU }; - PixelValue type_min{}; - PixelValue type_max{}; + PixelValue type_min{}; + PixelValue type_max{}; std::tie(type_min, type_max) = get_min_max(data_type); - int32_t min_activation = type_min.get(); - int32_t max_activation = type_max.get(); + int32_t min_activation = type_min.get(); + int32_t max_activation = type_max.get(); if(supported_acts.count(act.activation()) != 0) { std::tie(min_activation, max_activation) = get_quantized_activation_min_max(act, data_type, uoqinfo); @@ -88,6 +88,8 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect asm_info.padding_value = 0.f; asm_info.negated_offsets = false; asm_info.fast_mode = info.enable_fast_math; + asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED; + asm_info.weight_format = info.weights_info.weight_format(); return asm_info; } } // namespace @@ -146,7 +148,9 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w } else { - _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); + // We must permute weights if they are WeightFormat::UNSPECIFIED + if(info.weights_info.weight_format() == WeightFormat::UNSPECIFIED) + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size()); } } Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info) @@ -193,7 +197,9 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors) _gemm_asm_func->run(tensors); if(_run_activation) { - _activation_func->run(tensors); + ITensor *io = tensors.get_tensor(ACL_DST); + ITensorPack pack{ { ACL_SRC, io }, { ACL_DST, io } }; + _activation_func->run(pack); } } @@ -201,6 +207,13 @@ void CpuGemmDirectConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { + // If we are using fixed-format kernel the weights are already reshaped + if(_gemm_asm_func && _gemm_asm_func->isVarWeightsKernel()) + { + _gemm_asm_func->prepare(tensors); + _is_prepared = true; + return; + } const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1); ITensor *weights_aux = utils::cast::polymorphic_cast(tensors.get_tensor(offset_int_vec(PermutedWeights))); ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux); @@ -222,4 +235,4 @@ experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const return _aux_mem; } } // namespace cpu -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/src/cpu/operators/CpuWinogradConv2d.cpp b/src/cpu/operators/CpuWinogradConv2d.cpp index dcc18ce8fa..81cf651b76 100644 --- a/src/cpu/operators/CpuWinogradConv2d.cpp +++ b/src/cpu/operators/CpuWinogradConv2d.cpp @@ 
-1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -31,19 +31,19 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/common/utils/Log.h" #include "src/core/CPP/Validate.h" +#include "src/core/NEON/kernels/assembly/winograd.hpp" +#include "src/core/NEON/kernels/convolution/common/tensor.hpp" #include "src/core/NEON/kernels/convolution/common/utils.hpp" -#include "src/core/NEON/kernels/convolution/winograd/winograd.hpp" #include "src/core/helpers/MemoryHelpers.h" +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/utils/AssemblyUtils.h" #include "src/cpu/kernels/CpuWinogradConv2dKernel.h" +#include "src/cpu/kernels/assembly/arm_gemm.hpp" #include "src/cpu/operators/CpuActivation.h" #include "src/cpu/operators/CpuPermute.h" -#include "src/cpu/operators/CpuWinogradConv2d.h" #include "src/cpu/utils/CpuAuxTensorHandler.h" - #include "support/Cast.h" -#include - namespace arm_compute { namespace cpu @@ -53,174 +53,20 @@ using namespace arm_compute::utils::cast; namespace { -arm_gemm::Activation arm_gemm_activation_from_acl_activation(const ActivationLayerInfo &act_info) -{ - switch(act_info.activation()) - { - case ActivationLayerInfo::ActivationFunction::RELU: - { - return arm_gemm::Activation(arm_gemm::Activation::Type::ReLU, act_info.a(), act_info.b()); - } - case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU: - { - return arm_gemm::Activation(arm_gemm::Activation::Type::BoundedReLU, act_info.a(), act_info.b()); - } - default: - { - return arm_gemm::Activation(arm_gemm::Activation::Type::None); - } - } -} - -inline Status validate_kernel_3x3(const Size2D input_dims, const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); - - if(src->data_type() == DataType::F32) - { - if(input_dims.width > 4 && input_dims.height > 4) - { - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - } - else - { - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - } - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - else if(src->data_type() == DataType::F16) - { - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel<__fp16, 4, 4, 3, 3>::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel<__fp16, 4, 4, 3, 3>::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel<__fp16, 4, 4, 3, 3>::validate(batched_mm_output, biases, dst, winograd_info))); - } -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - - if(act_info.enabled()) - { - 
CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_5x5(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_3x1(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_1x3(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_5x1(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} -inline Status validate_kernel_1x5(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo 
*weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_7x1(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Status validate_kernel_1x7(const ITensorInfo *src, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output, - const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info) -{ - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F32); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformInputKernel::validate(src, input0, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformWeightsKernel::validate(weights, input1, winograd_info))); - ARM_COMPUTE_RETURN_ON_ERROR((CpuWinogradConv2dTransformOutputKernel::validate(batched_mm_output, biases, dst, winograd_info))); - - if(act_info.enabled()) - { - CpuActivation::validate(dst, nullptr, act_info); - } - return Status{}; -} - -inline Tensor4DShape internal_get_input_shape(const ITensorInfo *src) +inline Tensor4DShape internal_get_shape(const ITensorInfo *in) { - const DataLayout data_layout = src->data_layout(); - const int in_width = src->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)); - const int in_height = src->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)); - const int in_channels = src->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)); - const int in_batches = src->dimension(3); + const DataLayout data_layout = in->data_layout(); + const int in_width = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)); + const int in_height = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)); + const int in_channels = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)); + const int in_batches = in->dimension(get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)); return Tensor4DShape{ in_batches, in_height, in_width, in_channels }; } Status 
validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info) { - ARM_COMPUTE_UNUSED(dst); + ARM_COMPUTE_UNUSED(dst, weights); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd layer only supports unit strides."); @@ -229,108 +75,85 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); } - return ICpuWinogradConv2dTransformWeightsKernel::validate(src, weights); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights); + return Status{}; } -Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims, DataType data_type) + +bool get_winograd_kernel_implementation(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, bool enable_fast_math, + arm_conv::winograd::WinogradImpl *winograd_impl, std::unique_ptr &conv_args) { - Size2D output_tile = Size2D{}; - if(kernel_dims == Size2D(3U, 3U)) - { - output_tile = (input_dims.width <= 4 || input_dims.height <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U); - if(data_type == DataType::F16) - { - output_tile = Size2D(4U, 4U); - } - } - else if(kernel_dims == Size2D(5U, 5U)) - { - output_tile = Size2D(2U, 2U); - } - else if(kernel_dims == Size2D(1U, 3U)) - { - output_tile = Size2D(1U, 6U); - } - else if(kernel_dims == Size2D(3U, 1U)) - { - output_tile = Size2D(6U, 1U); - } - else if(kernel_dims == Size2D(1U, 5U)) - { - output_tile = Size2D(1U, 4U); - } - else if(kernel_dims == Size2D(5U, 1U)) - { - output_tile = Size2D(4U, 1U); - } - else if(kernel_dims == Size2D(7U, 1U)) + arm_conv::winograd::WinogradConfig winograd_cfg; + arm_gemm::GemmConfig cfg; + + const DataType data_type = src->data_type(); + Tensor4DShape in_shape{ internal_get_shape(src) }; + Tensor4DShape out_shape{ internal_get_shape(dst) }; + Tensor4DShape kernel_shape{ internal_get_shape(weights) }; + uint32_t nthreads = NEScheduler::get().num_threads(); + // Get configuration arguments for Winograd + winograd_cfg.output_rows = 0; + winograd_cfg.output_cols = 0; + conv_args = std::make_unique( + in_shape.n_batches, + arm_conv::Shape2D{ static_cast(in_shape.n_rows), static_cast(in_shape.n_cols) }, + in_shape.n_channels, + conv_info.pad_top(), + conv_info.pad_left(), + arm_conv::Shape2D{ static_cast(out_shape.n_rows), static_cast(out_shape.n_cols) }, + out_shape.n_channels, + arm_conv::Shape2D{ static_cast(kernel_shape.n_rows), static_cast(kernel_shape.n_cols) }, + assembly_utils::map_to_arm_gemm_activation(act_info)); + + bool success = false; + if(data_type == DataType::F32) { - output_tile = Size2D(2U, 1U); + success = arm_conv::winograd::get_implementation( + *winograd_impl, &CPUInfo::get(), *conv_args, nthreads, enable_fast_math, &winograd_cfg, nullptr); } - else if(kernel_dims == Size2D(1U, 7U)) +#if defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + else if(data_type == DataType::F16) { - output_tile = Size2D(1U, 2U); + success = arm_conv::winograd::get_implementation<__fp16>( + *winograd_impl, &CPUInfo::get(), *conv_args, nthreads, enable_fast_math, &winograd_cfg, nullptr); } - return output_tile; -} - -bool 
check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size, DataType data_type) -{ - // Check if we want to configure a Winograd configuration which requires fast math - using WinogradConfiguration = std::pair, std::pair>; - - const std::vector fast_math_winograd_f16 = - { - WinogradConfiguration(std::pair(4, 4), std::pair(3, 3)) - }; - - const std::vector fast_math_winograd_f32 = - { - WinogradConfiguration(std::pair(2, 2), std::pair(5, 5)), - WinogradConfiguration(std::pair(4, 4), std::pair(5, 5)) - }; - - auto p = std::make_pair(std::pair(output_tile.width, output_tile.height), - std::pair(kernel_size.width, kernel_size.height)); - - switch(data_type) +#endif // defined(__aarch64__) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + else { - case DataType::F16: - return std::find(fast_math_winograd_f16.begin(), fast_math_winograd_f16.end(), p) != fast_math_winograd_f16.end(); - case DataType::F32: - return std::find(fast_math_winograd_f32.begin(), fast_math_winograd_f32.end(), p) != fast_math_winograd_f32.end(); - default: - return false; + success = false; } + return success; } - inline bool fuse_function_supported(const ActivationLayerInfo &act_info) { return act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU || act_info.activation() == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU; } - } // namespace CpuWinogradConv2d::CpuWinogradConv2d() + : _gemm_function(std::make_unique()), _activation_func(std::make_unique()), + _transform_input_kernel(nullptr), + _transform_output_kernel(nullptr), _permute_input(std::make_unique()), _permute_output(std::make_unique()), _permute_weights(std::make_unique()), - _transform_input_kernel(nullptr), - _transform_weights_kernel(nullptr), - _transform_output_kernel(nullptr), - _data_layout(), _aux_mem(AuxTensorIdx::Count), - _input_nhwc(), - _output_nhwc(), + _conv_args{ nullptr }, + _winograd_impl{}, + _data_layout(), + _winograd_transformed_input{}, + _winograd_transformed_output{}, + _winograd_transformed_weights{}, _input_workspace(), - _kernel_storage(), _output_workspace(), - _input_transformed(), - _output_transformed(), _weights_hwio(), - _run_activation(false), - _is_prepared(false) + _input_nhwc(), + _output_nhwc(), + _is_prepared{ false }, + _run_activation{ false } { } @@ -342,464 +165,202 @@ void CpuWinogradConv2d::configure(const ITensorInfo *src, const ITensorInfo *wei ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info)); ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, act_info, enable_fast_math); + ARM_COMPUTE_UNUSED(biases); + const DataType data_type = src->data_type(); + uint32_t nthreads = NEScheduler::get().num_threads(); + _data_layout = src->data_layout(); + const Tensor4DShape kernel_shape{ internal_get_shape(weights) }; + + bool success = get_winograd_kernel_implementation(src, weights, dst, conv_info, act_info, enable_fast_math, &_winograd_impl, _conv_args); + + ARM_COMPUTE_EXIT_ON_MSG_VAR(!success, "Unsupported kernel size: %d x %d.\n", kernel_shape.n_rows, kernel_shape.n_cols); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using input transform: %s\n", _winograd_impl.input_transform->get_name().c_str()); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using weight transform: %s\n", _winograd_impl.input_transform->get_name().c_str()); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using output 
transform: %s\n", _winograd_impl.output_transform->get_name().c_str()); + + const bool has_impl = ((_winograd_impl.input_transform != nullptr) && (_winograd_impl.output_transform != nullptr) && (_winograd_impl.gemm_args != nullptr)); + if(has_impl) + { + // Determine how much working space is required, allocate it. + const size_t input_workspace_size = _winograd_impl.input_transform->get_working_space_size(*_conv_args, nthreads); + const size_t output_workspace_size = _winograd_impl.output_transform->get_working_space_size(*_conv_args, nthreads); + + TensorInfo input_workspace_info(TensorShape(input_workspace_size), 1, DataType::U8); + TensorInfo output_workspace_info(TensorShape(output_workspace_size), 1, DataType::U8); + _input_workspace = input_workspace_info; + _output_workspace = output_workspace_info; + + const auto &wds = _winograd_impl.winograd_spec; + + // Preparing winograd transformed input tensor + const size_t data_type_size = src->element_size(); + const uint32_t m = _winograd_impl.gemm_args->_Msize; // Total number of tiles + const uint32_t k = _winograd_impl.gemm_args->_Ksize; // Input channels + const uint32_t n = _winograd_impl.gemm_args->_Nsize; // Output channels + const uint32_t n_gemms = _winograd_impl.gemm_args->_nmulti; + const uint32_t n_batches = _winograd_impl.gemm_args->_nbatches; + constexpr size_t storage_alignment = 64; + + const TensorShape a_shape(k, m, n_batches, n_gemms); + Strides a_strides(data_type_size); + a_strides.set(1, data_type_size * _winograd_impl.winograd_spec.input_ld_row); + a_strides.set(2, data_type_size * _winograd_impl.winograd_spec.input_ld_batch); + a_strides.set(3, data_type_size * _winograd_impl.winograd_spec.input_ld_matrix); + + const TensorShape b_shape(n, k, n_gemms); + Strides b_strides(data_type_size); + b_strides.set(1, data_type_size * _winograd_impl.winograd_spec.weight_ld_row); + b_strides.set(2, data_type_size * _winograd_impl.winograd_spec.weight_ld_matrix); + + const TensorShape d_shape(n, m, n_batches, n_gemms); + Strides d_strides(data_type_size); + d_strides.set(1, data_type_size * _winograd_impl.winograd_spec.output_ld_row); + d_strides.set(2, data_type_size * _winograd_impl.winograd_spec.output_ld_batch); + d_strides.set(3, data_type_size * _winograd_impl.winograd_spec.output_ld_matrix); + + TensorInfo a_info{}; + TensorInfo b_info{}; + TensorInfo d_info{}; + a_info.init(a_shape, 1, data_type, a_strides, 0, wds.input_matrix_size_bytes); + b_info.init(b_shape, 1, data_type, b_strides, 0, wds.weight_matrix_size_bytes); + d_info.init(d_shape, 1, data_type, d_strides, 0, wds.output_matrix_size_bytes); + + _winograd_transformed_input = a_info; + _winograd_transformed_weights = b_info; + _winograd_transformed_output = d_info; + + PermutationVector weights_permutation_vector(3U, 0U, 1U, 2U); + + // Configure the kernel to transform the input tensor from NCHW -> NHWC + if(_data_layout == DataLayout::NCHW) + { + _permute_input->configure(src, &_input_nhwc, PermutationVector(2U, 0U, 1U)); + weights_permutation_vector = PermutationVector(3U, 2U, 0U, 1U); + } - // Get indices for the width and height - _data_layout = src->data_layout(); - const unsigned int width_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH); - const unsigned int height_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT); - const unsigned int channel_idx = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL); + // Re-order a weight tensor from [Output feature map x Input
feature map x Height x Width] to [Height x Width x Input feature map x Output feature map] + _permute_weights->configure(weights, &_weights_hwio, weights_permutation_vector); - const Size2D input_dims = Size2D(src->dimension(width_idx), src->dimension(height_idx)); - const Size2D kernel_size = Size2D(weights->dimension(width_idx), weights->dimension(height_idx)); - const DataType data_type = src->data_type(); - const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, data_type); + // Reorder the convoluted output to ACL's ordering NCHW + if(_data_layout == DataLayout::NCHW) + { + // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output() + TensorInfo info(TensorShape(dst->dimension(2), dst->dimension(0), + dst->dimension(1), dst->dimension(3)), + 1, dst->data_type()); + _output_nhwc = info; + _permute_output->configure(&_output_nhwc, dst, PermutationVector(1U, 2U, 0U)); + } - // Check if the Winograd configuration requires fast math - if(!enable_fast_math) - { - ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type), - "This Winograd configuration requires enable_fast_math=true"); - } + // Configure input transform kernel + _transform_input_kernel = std::make_unique(_winograd_impl, *_conv_args, nthreads); - _is_prepared = false; + // Configure GEMM function + _gemm_function->configure(&_winograd_transformed_input, &_winograd_transformed_weights, nullptr, &_winograd_transformed_output, 1.0f, 0.f); - std::unique_ptr transform_input_kernel; - std::unique_ptr transform_weights_kernel; - std::unique_ptr transform_output_kernel; + // Configure output transform kernel + _transform_output_kernel = std::make_unique(_winograd_impl, *_conv_args, nthreads); - int n_gemms = 1; - int N_BLOCK = 1; // Size of block used by GEMM. 
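With this rework, the batched-GEMM geometry configured above is taken from arm_gemm's GemmArgs (_Msize, _Ksize, _Nsize, _nbatches, _nmulti) rather than from the per-kernel-size configurations removed below. A standalone sketch of that geometry follows; the sizes are hypothetical (the real values come from get_winograd_kernel_implementation), and iceildiv mirrors the integer ceiling division used elsewhere in the library:

// Sketch only: illustrates the winograd-domain GEMM geometry, not ACL code.
#include <cstddef>
#include <cstdio>

// Integer ceiling division, as used by the library's shape arithmetic.
static std::size_t iceildiv(std::size_t a, std::size_t b)
{
    return (a + b - 1) / b;
}

int main()
{
    // Hypothetical convolution: 56x56 output, 64 input channels, 128 output
    // channels, batch of 2, 4x4 output tile with a 3x3 kernel (6x6 winograd tile).
    const std::size_t out_h = 56, out_w = 56, in_c = 64, out_c = 128, batches = 2;
    const std::size_t tile_h = 4, tile_w = 4, wino_tile = 6;

    const std::size_t m       = iceildiv(out_h, tile_h) * iceildiv(out_w, tile_w); // number of output tiles (_Msize)
    const std::size_t k       = in_c;                                              // accumulation depth (_Ksize)
    const std::size_t n       = out_c;                                             // output channels (_Nsize)
    const std::size_t n_gemms = wino_tile * wino_tile;                             // one GEMM per winograd-domain point (_nmulti)

    // A is (k x m), B is (n x k), D is (n x m), repeated n_gemms times per batch.
    std::printf("%zu GEMMs of M=%zu N=%zu K=%zu, %zu batches\n", n_gemms, m, n, k, batches);
    return 0;
}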
- if(data_type == DataType::F32) - { - if(kernel_size == Size2D(3, 3)) - { - if(src->dimension(width_idx) > 4 && src->dimension(height_idx) > 4) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - } - else if(kernel_size == Size2D(5, 5)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(1, 3)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(3, 1)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(1, 5)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(5, 1)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(1, 7)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else if(kernel_size == Size2D(7, 1)) - { - using config = CpuWinogradConv2dConfiguration; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms = config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; - } - else - { - ARM_COMPUTE_ERROR("Not supported."); - } - } -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - else if(data_type == DataType::F16) - { - if(kernel_size == Size2D(3, 3)) + //Configure Activation Layer + _run_activation = act_info.enabled() && !fuse_function_supported(act_info); + if(_run_activation) { - using config = CpuWinogradConv2dConfiguration<__fp16, __fp16, 4, 4, 3, 3>; - transform_input_kernel = std::make_unique(); - transform_weights_kernel = std::make_unique(); - transform_output_kernel = std::make_unique(); - n_gemms 
= config::WinogradBase::N_GEMMS; - N_BLOCK = config::WinogradConv::N_BLOCK; + _activation_func->configure(dst, nullptr, act_info); } - else + + auto asm_mem_req = _gemm_function->workspace(); + _aux_mem[GemmWorkspace] = asm_mem_req[GemmWorkspace]; + _aux_mem[Pretranspose] = asm_mem_req[Pretranspose]; + _aux_mem[InterleavedLHS] = asm_mem_req[InterleavedLHS]; + _aux_mem[TransposedRHS] = asm_mem_req[TransposedRHS]; + _aux_mem[TempResult] = asm_mem_req[TempResult]; + + // Request temporary memory. Overlap memory needed for Input/Output transformations as they run on different non-overlapping time-steps. + _aux_mem[TransformedInput] = MemoryInfo(offset_int_vec(TransformedInput), MemoryLifetime::Temporary, wds.input_matrix_size_bytes, storage_alignment); + _aux_mem[TransformedOutput] = MemoryInfo(offset_int_vec(TransformedOutput), MemoryLifetime::Temporary, wds.output_matrix_size_bytes, storage_alignment); + _aux_mem[WorkspaceIO] = MemoryInfo(offset_int_vec(WorkspaceIO), MemoryLifetime::Temporary, std::max(input_workspace_size, output_workspace_size)); + _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, _weights_hwio.total_size()); + _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Persistent, wds.weight_matrix_size_bytes, storage_alignment); + if(_data_layout == DataLayout::NCHW) { - ARM_COMPUTE_ERROR("Not supported."); + _aux_mem[PermutedInput].merge(offset_int_vec(PermutedInput), src->total_size()); + _aux_mem[PermutedOutput].merge(offset_int_vec(PermutedOutput), dst->total_size()); } } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - else - { - ARM_COMPUTE_ERROR("Not supported."); - } - - const PaddingType use_padding_type = (conv_info.pad_top() != 0u || conv_info.pad_left() != 0) ? PADDING_SAME : PADDING_VALID; - const bool use_same_padding = use_padding_type == PADDING_SAME; - - // Get convolved dimensions - const int in_channels = src->dimension(channel_idx); - const int out_channels = dst->dimension(channel_idx); - - const Tensor4DShape in_shape(internal_get_input_shape(src)); - const size_t data_type_size = src->element_size(); - // Get the memory required to instantiate a new Winograd operator. 
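The memory requests above encode how long each auxiliary buffer must stay alive; the comments in that hunk explain that the input and output transforms never run at the same time, which is why their buffers can be overlapped and the shared workspace is sized as the maximum of the two. A small standalone table (a sketch, not ACL code) summarising those lifetimes:

// Sketch only: summarises the auxiliary-buffer lifetimes requested above.
#include <cstdio>

enum class MemoryLifetime { Temporary, Prepare, Persistent };

struct AuxBuffer
{
    const char    *name;
    MemoryLifetime lifetime;
    const char    *why;
};

int main()
{
    const AuxBuffer buffers[] = {
        { "TransformedInput",   MemoryLifetime::Temporary,  "rebuilt on every run()" },
        { "TransformedOutput",  MemoryLifetime::Temporary,  "input/output transforms run on non-overlapping time-steps" },
        { "WorkspaceIO",        MemoryLifetime::Temporary,  "sized max(input_workspace, output_workspace) for the same reason" },
        { "PermutedWeights",    MemoryLifetime::Prepare,    "OHWI -> HWIO staging, only needed while prepare() runs" },
        { "TransformedWeights", MemoryLifetime::Persistent, "winograd-domain weights, reused by every subsequent run()" },
    };
    for(const AuxBuffer &b : buffers)
    {
        const char *lt = (b.lifetime == MemoryLifetime::Temporary) ? "Temporary" : (b.lifetime == MemoryLifetime::Prepare) ? "Prepare" : "Persistent";
        std::printf("%-18s %-10s %s\n", b.name, lt, b.why);
    }
    return 0;
}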
- constexpr size_t storage_alignment = 64; - - // Kernel Storage - const size_t kernel_storage_size = transform_weights_kernel->get_weight_storage_size(out_channels, in_channels) * data_type_size; - - // Input storage - const size_t input_storage_size = transform_input_kernel->get_input_storage_size(in_shape.n_batches, in_shape.n_channels, in_shape.n_rows, in_shape.n_cols, use_same_padding) * data_type_size; - - // Output storage - const size_t output_storage_size = transform_output_kernel->get_output_storage_size(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels) * data_type_size; - const int kernel_matrix_stride = transform_weights_kernel->get_matrix_stride(out_channels, in_channels); - const int output_matrix_stride = transform_output_kernel->get_matrix_stride(in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, out_channels); - const auto output_shape = transform_output_kernel->get_output_shape(in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME); - const int input_matrix_stride = transform_input_kernel->get_matrix_stride(in_shape.n_batches, in_channels, in_shape.n_rows, in_shape.n_cols, use_padding_type == PADDING_SAME); - - // Configure GEMM - const int tile_rows = iceildiv(output_shape.first, output_tile.height); - const int tile_cols = iceildiv(output_shape.second, output_tile.width); - const int m = in_shape.n_batches * tile_rows * tile_cols; - const int k = in_shape.n_channels; - const int n = out_channels; - const int kernel_matrix_row_stride = roundup(out_channels, N_BLOCK); - const int output_matrix_row_stride = kernel_matrix_row_stride; - - TensorShape a_shape(k, m, 1, n_gemms); - Strides a_strides(data_type_size); - a_strides.set(1, a_strides[0] * k); - //a_strides.set(2, data_type_size * input_matrix_stride / n_gemms); FIXME: This is the real batch size, but RSH's code crashes if it's not 0. - a_strides.set(2, 0); - a_strides.set(3, data_type_size * input_matrix_stride); - - TensorShape b_shape(n, k, n_gemms); - Strides b_strides(data_type_size); - b_strides.set(1, data_type_size * kernel_matrix_row_stride); - b_strides.set(2, data_type_size * kernel_matrix_stride); - - TensorShape d_shape(n, m, 1, n_gemms); - Strides d_strides(data_type_size); - d_strides.set(1, data_type_size * output_matrix_row_stride); - //d_strides.set(2, data_type_size * output_matrix_stride / n_gemms); FIXME: This is the real batch size, but RSH's code crashes if it's not 0. 
- d_strides.set(2, 0); - d_strides.set(3, data_type_size * output_matrix_stride); - - TensorInfo a_info{}; - TensorInfo b_info{}; - TensorInfo d_info{}; - a_info.init(a_shape, 1, data_type, a_strides, 0, input_storage_size); - b_info.init(b_shape, 1, data_type, b_strides, 0, kernel_storage_size); - d_info.init(d_shape, 1, data_type, d_strides, 0, output_storage_size); - - _input_transformed = a_info; - _kernel_storage = b_info; - _output_transformed = d_info; - - const ITensorInfo *input_to_use = src; - ITensorInfo *output_to_use = dst; - PermutationVector weights_permutation_vector(3U, 0U, 1U, 2U); - const unsigned int max_num_threads = NEScheduler::get().num_threads(); - - // Configure the kernel to transform the input tensor from NCHW -> NHWC - if(_data_layout == DataLayout::NCHW) - { - _permute_input->configure(src, &_input_nhwc, PermutationVector(2U, 0U, 1U)); - input_to_use = &_input_nhwc; - weights_permutation_vector = PermutationVector(3U, 2U, 0U, 1U); - } - - // Configure input transform kernel - transform_input_kernel->configure(input_to_use, in_shape.n_batches, in_shape.n_rows, in_shape.n_cols, in_shape.n_channels, use_padding_type, - &_input_transformed, input_matrix_stride, &_input_workspace); - const size_t input_workspace_size = transform_input_kernel->get_working_space_size(max_num_threads); - TensorInfo input_workspace_info(TensorShape(input_workspace_size), 1, DataType::U8); - _input_workspace = input_workspace_info; - - // Re-order a weight tensor from [Output feature map x Input feature map x Height x Width] to [Height x Width x Input feature map x Output feature map] - _permute_weights->configure(weights, &_weights_hwio, weights_permutation_vector); - transform_weights_kernel->configure(&_weights_hwio, &_kernel_storage, kernel_matrix_stride, out_channels, in_channels); - - // Configure GEMM function - _gemm_function->configure(&_input_transformed, &_kernel_storage, nullptr, &_output_transformed, 1.0f, 0.f); - - // Configure output transform function - // The biases tensor has not been allocated at this point in time, the output transform will add the biases to the final result in the run() method - if(_data_layout == DataLayout::NCHW) - { - // configure and allocate dst tensor to be used to convert from winograd domain to spatial domain when calling to reshape_output() - TensorInfo info(TensorShape(dst->dimension(2), dst->dimension(0), - dst->dimension(1), dst->dimension(3)), - 1, dst->data_type()); - _output_nhwc = info; - output_to_use = &_output_nhwc; - } - const arm_gemm::Activation activation = arm_gemm_activation_from_acl_activation(act_info); - - transform_output_kernel->configure(biases, - &_output_transformed, - output_matrix_stride, - output_to_use, - in_shape.n_batches, - output_shape.first, - output_shape.second, - out_channels, - &_output_workspace, - activation); - - const size_t output_workspace_size = transform_output_kernel->get_working_space_size(max_num_threads); - TensorInfo output_workspace_info(TensorShape(output_workspace_size), 1, DataType::U8); - _output_workspace = output_workspace_info; - - // Reorder the convoluted output to ACL's ordering NCHW - if(_data_layout == DataLayout::NCHW) - { - _permute_output->configure(&_output_nhwc, dst, PermutationVector(1U, 2U, 0U)); - } - - _transform_input_kernel = std::move(transform_input_kernel); - _transform_weights_kernel = std::move(transform_weights_kernel); - _transform_output_kernel = std::move(transform_output_kernel); - - //Configure Activation Layer - _run_activation = act_info.enabled() && 
!fuse_function_supported(act_info); - if(_run_activation) - { - _activation_func->configure(dst, nullptr, act_info); - } - - auto asm_mem_req = _gemm_function->workspace(); - _aux_mem[GemmWorkspace] = asm_mem_req[GemmWorkspace]; - _aux_mem[Pretranspose] = asm_mem_req[Pretranspose]; - _aux_mem[InterleavedLHS] = asm_mem_req[InterleavedLHS]; - _aux_mem[TransposedRHS] = asm_mem_req[TransposedRHS]; - _aux_mem[TempResult] = asm_mem_req[TempResult]; - - // Request temporary memory. Overlap memory needed for Input/Output transformations as they run on different non-overlapping time-steps. - _aux_mem[TransformedInput] = MemoryInfo(offset_int_vec(TransformedInput), MemoryLifetime::Temporary, input_storage_size, storage_alignment); - _aux_mem[TransformedOutput] = MemoryInfo(offset_int_vec(TransformedOutput), MemoryLifetime::Temporary, output_storage_size, storage_alignment); - _aux_mem[WorkspaceIO] = MemoryInfo(offset_int_vec(WorkspaceIO), MemoryLifetime::Temporary, std::max(input_workspace_size, output_workspace_size)); - _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, _weights_hwio.total_size()); - _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Persistent, kernel_storage_size, storage_alignment); - if(_data_layout == DataLayout::NCHW) - { - _aux_mem[PermutedInput].merge(offset_int_vec(PermutedInput), src->total_size()); - _aux_mem[PermutedOutput].merge(offset_int_vec(PermutedOutput), dst->total_size()); - } } - Status CpuWinogradConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, bool enable_fast_math) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info)); - // Get indices for the width and height - const size_t idx_width = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::WIDTH); - const size_t idx_height = get_data_layout_dimension_index(src->data_layout(), DataLayoutDimension::HEIGHT); - - // Input shape, kernel size and output tile - const Size2D input_dims = Size2D(src->dimension(idx_width), src->dimension(idx_height)); - const Size2D kernel_size = Size2D(weights->dimension(idx_width), weights->dimension(idx_height)); - const DataType data_type = src->data_type(); - const Size2D output_tile = winograd_output_tile(input_dims, kernel_size, data_type); + const Tensor4DShape kernel_shape{ internal_get_shape(weights) }; + arm_conv::winograd::WinogradImpl winograd_impl{}; - // Check if the Winograd configuration requires fast math - if(!enable_fast_math) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type), - "This Winograd configuration requires enable_fast_math=true"); - } + std::unique_ptr conv_args; + const bool success = get_winograd_kernel_implementation(src, weights, dst, conv_info, act_info, enable_fast_math, &winograd_impl, conv_args); - const WinogradInfo winograd_info = WinogradInfo(output_tile, - kernel_size, - input_dims, - conv_info, - src->data_layout()); - - // Validate input transform - const TensorShape input0_shape = misc::shape_calculator::compute_winograd_input_transform_shape(*src, winograd_info); - const TensorInfo input0 = src->clone()->set_tensor_shape(input0_shape); - // Validate filter transform - const TensorShape input1_shape = 
misc::shape_calculator::compute_winograd_filter_transform_shape(*weights, winograd_info); - const TensorInfo input1 = weights->clone()->set_tensor_shape(input1_shape); - // Validate batched matrix multiply - TensorShape batched_mm_output_shape = input0.tensor_shape(); - batched_mm_output_shape[0] = input1.tensor_shape()[0]; - const TensorInfo batched_mm_output = input0.clone()->set_tensor_shape(batched_mm_output_shape); - - if(kernel_size == Size2D(3, 3)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported"); - return validate_kernel_3x3(input_dims, src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(5, 5)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != conv_info.pad_left(), "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_bottom(), "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != conv_info.pad_left(), "Only SAME or VALID padding supported"); - return validate_kernel_5x5(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - if(kernel_size == Size2D(3, 1)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported"); - return validate_kernel_3x1(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(1, 3)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 1, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding 
supported"); - return validate_kernel_1x3(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(5, 1)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported"); - return validate_kernel_5x1(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(1, 5)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 2, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported"); - return validate_kernel_1x5(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(7, 1)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_left() != 3, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_right() != 0u && conv_info.pad_right() != 3, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_bottom() != 0, "Only SAME or VALID padding supported"); - return validate_kernel_7x1(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else if(kernel_size == Size2D(1, 7)) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_top() != 0u && conv_info.pad_top() != 3, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_bottom() != 0u && conv_info.pad_bottom() != 3, "Only SAME or VALID padding supported"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.pad_left() != 0u && conv_info.pad_right() != 0, "Only SAME or VALID padding supported"); - return validate_kernel_1x7(src, &input0, &input1, &batched_mm_output, weights, biases, dst, winograd_info, act_info); - } - else - { - ARM_COMPUTE_RETURN_ERROR_MSG("Kernel shape not supported"); - } + ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR(success == false, "Unsupported kernel size: %d x %d.\n", kernel_shape.n_rows, kernel_shape.n_cols); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using input transform: %s\n", winograd_impl.input_transform->get_name().c_str()); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using weight transform: %s\n", winograd_impl.input_transform->get_name().c_str()); + ARM_COMPUTE_LOG_MSG_WITH_FORMAT_ACL(arm_compute::logging::LogLevel::INFO, "Using output transform: %s\n", winograd_impl.input_transform->get_name().c_str()); + return Status{}; } void CpuWinogradConv2d::run(ITensorPack &tensors) { prepare(tensors); + auto src = tensors.get_const_tensor(ACL_SRC_0); + auto biases = tensors.get_const_tensor(ACL_SRC_2); + auto output = tensors.get_tensor(ACL_DST); + Window win; + + const uint32_t nthreads = NEScheduler::get().num_threads(); - auto a = tensors.get_const_tensor(ACL_SRC_0); - auto c = tensors.get_const_tensor(ACL_SRC_2); - auto d = tensors.get_tensor(ACL_DST); + // 
The Winograd transform implementation does fine-grain threading inside the transforms. Just pass thread_id and nthreads. + win.set(Window::DimX, Window::Dimension(0, nthreads, 1)); + // Wrap the winograd-domain tensorInfos created in configuration in tensors and allocate the required memory. CpuAuxTensorHandler input_nhwc(offset_int_vec(PermutedInput), _input_nhwc, tensors, true); - CpuAuxTensorHandler input_transformed(offset_int_vec(TransformedInput), _input_transformed, tensors, true); + CpuAuxTensorHandler winograd_input_transformed(offset_int_vec(TransformedInput), _winograd_transformed_input, tensors, true); CpuAuxTensorHandler input_workspace(offset_int_vec(WorkspaceIO), _input_workspace, tensors, true); - - const bool is_nchw = _data_layout == DataLayout::NCHW; + const bool is_nchw = _data_layout == DataLayout::NCHW; if(is_nchw) { //Bring channels to the front as Winograd code expects the tensor to be in the format NHWC - ITensorPack pack{ { ACL_SRC, a }, { ACL_DST, input_nhwc.get() } }; + ITensorPack pack{ { ACL_SRC, src }, { ACL_DST, input_nhwc.get() } }; _permute_input->run(pack); } - // Transform input tensor to the winograd domain - ITensorPack transform_input_pack{ { ACL_SRC, is_nchw ? input_nhwc.get() : a }, { ACL_DST, input_transformed.get() }, { ACL_INT, input_workspace.get() } }; - NEScheduler::get().schedule_op(_transform_input_kernel.get(), Window::DimX, _transform_input_kernel->window(), transform_input_pack); + CpuAuxTensorHandler winograd_output_transformed(offset_int_vec(TransformedOutput), _winograd_transformed_output, tensors, true); + CpuAuxTensorHandler output_workspace(offset_int_vec(WorkspaceIO), _output_workspace, tensors, true); + CpuAuxTensorHandler output_nhwc(offset_int_vec(PermutedOutput), _output_nhwc, tensors, true); - CpuAuxTensorHandler output_transformed(offset_int_vec(TransformedOutput), _output_transformed, tensors, true); - CpuAuxTensorHandler weights_transformed(offset_int_vec(TransformedWeights), _kernel_storage, tensors, true); + ITensorPack transform_input_pack{ { ACL_SRC, is_nchw ? input_nhwc.get() : src }, { ACL_DST, winograd_input_transformed.get() }, { ACL_INT, input_workspace.get() } }; + NEScheduler::get().schedule_op(_transform_input_kernel.get(), Window::DimX, win, transform_input_pack); + + CpuAuxTensorHandler winograd_weights_transformed(offset_int_vec(TransformedWeights), _winograd_transformed_weights, tensors, true); // Run 16 GEMMs in multiple threads, each kernel runs one or more GEMMs ITensorPack gemm_pack = tensors; - gemm_pack.add_const_tensor(ACL_SRC, input_transformed.get()); - gemm_pack.add_const_tensor(ACL_SRC_1, weights_transformed.get()); + gemm_pack.add_const_tensor(ACL_SRC, winograd_input_transformed.get()); + gemm_pack.add_const_tensor(ACL_SRC_1, winograd_weights_transformed.get()); gemm_pack.add_const_tensor(ACL_BIAS, nullptr); - gemm_pack.add_tensor(ACL_DST, output_transformed.get()); + gemm_pack.add_tensor(ACL_DST, winograd_output_transformed.get()); _gemm_function->run(gemm_pack); - // Transform output tensor to the spatial domain - CpuAuxTensorHandler output_workspace(offset_int_vec(WorkspaceIO), _output_workspace, tensors, true); - CpuAuxTensorHandler output_nhwc(offset_int_vec(PermutedOutput), _output_nhwc, tensors, true); - ITensorPack transform_output_pack{ { ACL_SRC_0, c }, { ACL_SRC_1, output_transformed.get() }, { ACL_DST, is_nchw ? 
output_nhwc.get() : d }, { ACL_INT, output_workspace.get() } }; - NEScheduler::get().schedule_op(_transform_output_kernel.get(), Window::DimX, _transform_output_kernel->window(), transform_output_pack); - + // Output transform + ITensorPack transform_output_pack{ { ACL_SRC_0, winograd_output_transformed.get() }, { ACL_DST, is_nchw ? output_nhwc.get() : output }, { ACL_SRC_1, biases }, { ACL_INT, output_workspace.get() } }; + NEScheduler::get().schedule_op(_transform_output_kernel.get(), Window::DimX, win, transform_output_pack); if(is_nchw) { // Reorder the convoluted output to ACL's ordering NCHW - ITensorPack pack{ { ACL_SRC, output_nhwc.get() }, { ACL_DST, d } }; + ITensorPack pack{ { ACL_SRC, output_nhwc.get() }, { ACL_DST, output } }; _permute_output->run(pack); } - if(_run_activation) { - ITensorPack pack{ { ACL_SRC, d }, { ACL_DST, d } }; + ITensorPack pack{ { ACL_SRC, output }, { ACL_DST, output } }; _activation_func->run(pack); } } @@ -808,34 +369,54 @@ void CpuWinogradConv2d::prepare(ITensorPack &tensors) { if(!_is_prepared) { - // Permute weights const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1); ITensor *weights_aux = utils::cast::polymorphic_cast(tensors.get_tensor(offset_int_vec(PermutedWeights))); - ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux); CpuAuxTensorHandler permuted_weights(_weights_hwio, *weights_aux); ITensorPack permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } }; _permute_weights->run(permute_tensors); + const int element_size_in_bytes = permuted_weights.get()->info()->element_size(); + // Weights were in OHWI format, before being permuted "permuted_weights" to be in HWIO format. + const unsigned int height_idx = 3; // H in HWIO + const unsigned int width_idx = 2; // W in HWIO + const unsigned int channel_idx = 1; // I in HWIO - // Transform weights + const int permuted_weight_row_stride = permuted_weights.get()->info()->strides_in_bytes()[height_idx] / element_size_in_bytes; + const int permuted_weight_col_stride = permuted_weights.get()->info()->strides_in_bytes()[width_idx] / element_size_in_bytes; + const int permuted_weight_channel_stride = permuted_weights.get()->info()->strides_in_bytes()[channel_idx] / element_size_in_bytes; + + // Wrap the winograd-domain transformed weight TensorInfo in Auxiliary tensor and allocate the required memory. 
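+        // Note: the weight transform below runs exactly once, on the caller's
+        // thread (the trailing arguments (0, 1) select thread 0 of a single
+        // thread), because prepare() is guarded by _is_prepared and the result
+        // lives in the persistent TransformedWeights buffer.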
ITensor *weights_transf = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransformedWeights)));
         ARM_COMPUTE_ERROR_ON_NULLPTR(weights_transf);
-
-        CpuAuxTensorHandler transformed_weights(_kernel_storage, *weights_transf);
-        ITensorPack         transform_tensors{ { ACL_SRC, permuted_weights.get() }, { ACL_DST, transformed_weights.get() } };
-        NEScheduler::get().schedule_op(_transform_weights_kernel.get(), Window::DimX, _transform_weights_kernel->window(), transform_tensors);
-
+        CpuAuxTensorHandler winograd_transformed_weights(_winograd_transformed_weights, *weights_transf);
+
+        const void *permuted_weights_ptr;
+        void       *win_wght_transf_ptr;
+
+        permuted_weights_ptr = reinterpret_cast<const void *>(permuted_weights.get()->buffer() + permuted_weights.get()->info()->offset_first_element_in_bytes());
+        win_wght_transf_ptr  = reinterpret_cast<void *>(winograd_transformed_weights.get()->buffer() + winograd_transformed_weights.get()->info()->offset_first_element_in_bytes());
+
+        // Prepare Weights
+        _winograd_impl.weight_transform->execute(
+            *_conv_args,
+            permuted_weights_ptr,
+            permuted_weight_row_stride,
+            permuted_weight_col_stride,
+            permuted_weight_channel_stride,
+            win_wght_transf_ptr,
+            _winograd_impl.winograd_spec,
+            0, 1 // Thread 1 of 1
+        );
         ITensorPack gemm_pack = tensors;
-        gemm_pack.add_const_tensor(ACL_SRC_1, transformed_weights.get());
+        gemm_pack.add_const_tensor(ACL_SRC_1, winograd_transformed_weights.get());
         _gemm_function->prepare(gemm_pack);
-
-        _is_prepared = true;
+        _is_prepared = true;
     }
 }
-
 experimental::MemoryRequirements CpuWinogradConv2d::workspace() const
 {
     return _aux_mem;
 }
+
 } // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
+} // namespace arm_compute
diff --git a/src/cpu/operators/CpuWinogradConv2d.h b/src/cpu/operators/CpuWinogradConv2d.h
index 0abd110f73..e0df34e2db 100644
--- a/src/cpu/operators/CpuWinogradConv2d.h
+++ b/src/cpu/operators/CpuWinogradConv2d.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -29,6 +29,7 @@
 #include "src/core/common/Macros.h"
 #include "src/cpu/ICpuOperator.h"
 #include "src/cpu/kernels/CpuWinogradConv2dKernel.h"
+#include "src/cpu/kernels/assembly/gemm_common.hpp"
 #include "src/cpu/operators/CpuActivation.h"
 #include "src/cpu/operators/CpuGemm.h"
 #include "src/cpu/operators/CpuPermute.h"
@@ -59,13 +60,13 @@ class CpuWinogradConv2d : public ICpuOperator
      * |F16            |F16            |F16    |F16            |
      * |F32            |F32            |F32    |F32            |
      *
-     * @param[in]  src         Source tensor info. 3 lower dimensions represent a single input [width, height, IFM],
+     * @param[in]  src         Source tensor Info. 3 lower dimensions represent a single input [width, height, IFM],
      *                         while every optional dimension from 4 and above represent a batch of inputs.
     *                         Data types supported: F16/F32.
-     * @param[in]  weights     Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: Same as @p input.
+     * @param[in]  weights     Weights tensor Info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: Same as @p input.
      *                         Currently only 3x3 and 5x5 kernels are supported.
-     * @param[in]  biases      Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p weights.
-     * @param[out] dst         Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
+     * @param[in]  biases      Biases tensor Info. Shared biases supported. 
Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p weights. + * @param[out] dst Destination tensor Info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs. * Data types supported: Same as @p input. * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. Currently only unit strides are supported. * @param[in] act_info (Optional) Activation layer information in case of a fused activation. @@ -107,28 +108,27 @@ class CpuWinogradConv2d : public ICpuOperator PermutedOutput = TransformedInput, Count = 10 }; - - std::unique_ptr _gemm_function; - std::unique_ptr _activation_func; - std::unique_ptr _permute_input; - std::unique_ptr _permute_output; - std::unique_ptr _permute_weights; - std::unique_ptr _transform_input_kernel; - std::unique_ptr _transform_weights_kernel; - std::unique_ptr _transform_output_kernel; - - DataLayout _data_layout; - experimental::MemoryRequirements _aux_mem{ Count }; - TensorInfo _input_nhwc; - TensorInfo _output_nhwc; - TensorInfo _input_workspace; - TensorInfo _kernel_storage; - TensorInfo _output_workspace; - TensorInfo _input_transformed; - TensorInfo _output_transformed; - TensorInfo _weights_hwio; - bool _run_activation; - bool _is_prepared; + std::unique_ptr _gemm_function; + std::unique_ptr _activation_func; + std::unique_ptr _transform_input_kernel; + std::unique_ptr _transform_output_kernel; + std::unique_ptr _permute_input; + std::unique_ptr _permute_output; + std::unique_ptr _permute_weights; + experimental::MemoryRequirements _aux_mem{ Count }; + std::unique_ptr _conv_args; // Make it unique ptr because this type does not have a default constructor + arm_conv::winograd::WinogradImpl _winograd_impl; + DataLayout _data_layout; + TensorInfo _winograd_transformed_input; + TensorInfo _winograd_transformed_output; + TensorInfo _winograd_transformed_weights; + TensorInfo _input_workspace; + TensorInfo _output_workspace; + TensorInfo _weights_hwio; + TensorInfo _input_nhwc; + TensorInfo _output_nhwc; + bool _is_prepared; + bool _run_activation; }; } // namespace cpu } // namespace arm_compute diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp index a97e53d4af..77da83070b 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp @@ -25,6 +25,7 @@ #include "arm_compute/runtime/NEON/NEScheduler.h" #include "src/core/CPP/Validate.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" #include "src/core/helpers/MemoryHelpers.h" #include "src/core/utils/AssemblyUtils.h" #include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h" @@ -156,10 +157,17 @@ class Fallback : public CpuGemmAssemblyDispatch::IFallback const std::vector &multipliers); // Inherited methods overridden: - void run(ITensorPack &tensors) override; - void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; bool is_configured() const override; experimental::MemoryRequirements workspace() const override; + bool isVarWeightsKernel() const override + { + if(!_gemm_kernel_asm) + return false; + const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format); + return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY; + } private: enum AuxTensorIdx @@ -203,12 
+211,12 @@ class Fallback : public CpuGemmAssemblyDispatch::IFallback
     /** Indirect buffer */
     std::unique_ptr _indirect_arg{};
     std::unique_ptr _indirect_buf{};
-    std::vector<TypeInput> _indirect_pad{};
-    arm_gemm::ConvolutionParameters _cp{};
-    experimental::MemoryRequirements _aux_mem{ Count };
-    bool _B_pretranspose_required{ false };
-    bool _is_b_constant{ true };
-    bool _is_c_constant{ true };
+    std::vector<TypeInput>           _indirect_pad{};
+    arm_gemm::ConvolutionParameters  _cp{};
+    experimental::MemoryRequirements _aux_mem{ Count };
+    bool                             _B_pretranspose_required{ false };
+    bool                             _is_b_constant{ true };
+    bool                             _is_c_constant{ true };
 };

 template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -420,6 +428,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
     // Pretranspose B if required
     if(_gemm_kernel_asm->B_pretranspose_required())
     {
+        // Fixed format kernels need no pretranspose.
+        ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
         const int  ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
         const auto in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
         const int  multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
@@ -483,9 +493,51 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
     // Check if B is pre-transposed and de-reference if not
     if(!_gemm_kernel_asm->B_is_pretransposed())
     {
-        ldb            = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
-        multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
-        in1_ptr        = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+        ldb                                = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+        multi_stride_b                     = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+        const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+        if(is_fixed_format(wf))
+        {
+            // The 4D tensor of dimension O'HWI' created for the
+            // OHWIo<interleave_by>i format is in reality seen
+            // as a 2D tensor at arm_gemm level, where the rows are
+            // O'/<interleave_by> and the columns are <interleave_by>
+            // * H * W * I'.
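+            //
+            // Worked example (hypothetical numbers): with interleave_by = 8 and
+            // block_by = 4, a 3x3 kernel and 30 input channels, the channels are
+            // first padded up to 32 (the next multiple of block_by), so the stride
+            // between consecutive groups of 8 output channels becomes
+            // ldb = 8 * 3 * 3 * 32 = 2304 elements.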
+ ITensorInfo *tensor_info = b->info(); + const DataLayout data_layout = tensor_info->data_layout(); + const TensorShape tensor_shape = tensor_info->tensor_shape(); + const int tensor_height = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const int tensor_width = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + int tensor_channels = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; + const int interleave_by = arm_compute::interleave_by(wf); + const int blocked_by = arm_compute::block_by(wf); + // We need to find a new stride that is distance from the data for one + // set of output channels to the next + if(ldb == tensor_channels && multi_stride_b == tensor_channels * tensor_width) + { + // In this case dimensions that are packed are height, width and channel + // so we need to stride it by interleave_by + if(tensor_channels % blocked_by != 0) + { + // We need to pad + tensor_channels = arm_gemm::iceildiv(tensor_channels, blocked_by) * blocked_by; + } + ldb = interleave_by * tensor_height * tensor_width * tensor_channels; + } + else if(multi_stride_b == 0 || (ldb == tensor_width && multi_stride_b == tensor_height * tensor_width)) + { + // In this case dimension that is packed is only height + // so we need to stride only height by interleave_by + ldb = interleave_by * tensor_height; + } + else + { + // If dimensions are not packed as above error is thrown + // as at the moment other forms of packing are not supported + ARM_COMPUTE_ERROR("Unsupported packing for fixed format kernel"); + } + } + in1_ptr = reinterpret_cast(b->buffer() + b->info()->offset_first_element_in_bytes()); } // If necessary, run pretranspose every time if either weights or biases are non-constant @@ -576,7 +628,9 @@ void create_arm_gemm(std::unique_ptr &arm_ge const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode); + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique>(); @@ -594,7 +648,9 @@ void create_arm_gemm_quant(std::unique_ptr & const CPUInfo &ci = NEScheduler::get().cpu_info(); const unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fast_mode); + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg); // Create arm_gemm fallback auto fallback = std::make_unique>(); @@ -635,7 +691,8 @@ CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() { } -Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info) +Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, + 
const AsmGemmInfo &info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_UNUSED(c); @@ -643,12 +700,14 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor Params p = extract_parameters(a, b, d, info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fast_mode); + arm_gemm::GemmConfig cfg; + cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format); + arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg); switch(a->data_type()) { case DataType::F32: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for F32 input"); break; #ifdef __aarch64__ @@ -656,12 +715,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor case DataType::QASYMM8: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8/QASYMM8 input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for U8 input and U8 output"); } break; @@ -669,27 +728,27 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor case DataType::QASYMM8_SIGNED: if(d->data_type() == DataType::S32) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output"); } else { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for S8 input and S32 output"); } break; #endif /* __aarch64__ */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); break; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(args, {})), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm(arm_gemm_expected_wf, args, {})), "We could not find an optimized kernel for BFLOAT16 input and F32 output"); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ @@ -697,6 +756,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(const ITensorInfo *a, const ITensor ARM_COMPUTE_RETURN_ERROR_ON_MSG(true, 
"Usupported type. Could not find a kernel"); break; } + expected_weight_format = assembly_utils::map_to_arm_compute_weight_format(arm_gemm_expected_wf); return Status{}; } @@ -729,7 +789,17 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && d->data_type() != DataType::QASYMM8, "Only QASYMM8 output supported for QASYMM8 input"); - return CpuGemmAssemblyDispatch::has_opt_impl(a, b, c, d, info); + arm_compute::WeightFormat expected_weight_format; + const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info); + if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY) + { + // Correctness check: if the format expected by the kernel is + // not "any", make sure that the one found matches the format + // intended by the caller. + ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format), + "The format expected by the kernel does not correspond with the one requested by the user."); + } + return ret; } bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &activation) @@ -778,11 +848,11 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo } break; #endif /* __aarch64__ */ -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) case DataType::BFLOAT16: create_arm_gemm(_arm_gemm, a, b, c, d, act, info); break; -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: create_arm_gemm(_arm_gemm, a, b, c, d, act, info); @@ -801,7 +871,7 @@ void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors) bool CpuGemmAssemblyDispatch::is_configured() const { - return _arm_gemm != nullptr && _arm_gemm->is_configured(); + return _arm_gemm && _arm_gemm->is_configured(); } void CpuGemmAssemblyDispatch::run(ITensorPack &tensors) diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 74359eee72..691eeff8d2 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -41,17 +41,19 @@ enum class AsmConvMethod struct AsmGemmInfo { - AsmConvMethod method{ AsmConvMethod::Im2Col }; - PadStrideInfo ps_info{}; - ActivationLayerInfo activation_info{}; - GEMMLowpOutputStageInfo output_stage{}; - bool negated_offsets{ true }; - bool reinterpret_input_as_3d{ false }; - bool depth_output_gemm3d{ false }; - int64_t padding_top{ 0 }; - int64_t padding_left{ 0 }; - float padding_value{ 0.f }; - bool fast_mode{ false }; + AsmConvMethod method{ AsmConvMethod::Im2Col }; + PadStrideInfo ps_info{}; + ActivationLayerInfo activation_info{}; + GEMMLowpOutputStageInfo output_stage{}; + bool negated_offsets{ true }; + bool reinterpret_input_as_3d{ false }; + bool depth_output_gemm3d{ false }; + int64_t padding_top{ 0 }; + int64_t padding_left{ 0 }; + float padding_value{ 0.f }; + bool fast_mode{ false }; + bool fixed_format{ false }; + arm_compute::WeightFormat weight_format{ 
arm_compute::WeightFormat::UNSPECIFIED }; }; /** Assembly kernel glue */ @@ -68,11 +70,12 @@ class CpuGemmAssemblyDispatch : public ICpuOperator class IFallback { public: - virtual void run(ITensorPack &tensors) = 0; - virtual void prepare(ITensorPack &tensors) = 0; - virtual experimental::MemoryRequirements workspace() const = 0; - virtual bool is_configured() const = 0; - virtual ~IFallback() = default; + virtual void run(ITensorPack &tensors) = 0; + virtual void prepare(ITensorPack &tensors) = 0; + virtual experimental::MemoryRequirements workspace() const = 0; + virtual bool is_configured() const = 0; + virtual bool isVarWeightsKernel() const = 0; + virtual ~IFallback() = default; }; public: @@ -100,15 +103,14 @@ class CpuGemmAssemblyDispatch : public ICpuOperator /** Indicates whether or not there is an optimal assembly implementation that can be used to process the given parameters. * - * @param[in] a Input tensor info (Matrix A) - * @param[in] b Input tensor info (Matrix B) - * @param[in] c Input tensor info (Matrix C) used to pass the bias for quantized calculations - * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] info GEMM meta-data + * This method has the same use of @ref + * NEGEMMConvolutionLayer::has_opt_impl, with the only caveat that + * the value of arm_compute::WeightFormat need to be passed via the + * parameter info. * * @return a status. */ - static Status has_opt_impl(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); + static Status has_opt_impl(arm_compute::WeightFormat &weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info); /** Checks if activation is supported by the gemm assembly dispatcher * * @param[in] activation Activation to check @@ -121,10 +123,18 @@ class CpuGemmAssemblyDispatch : public ICpuOperator * @return True if the function is configured and ready to run */ bool is_configured() const; + /** Indicates if the convolution executes in variable weights mode. 
+ * + * Similar to @ref CpuGemm::isVarWeightsKernel + */ + bool isVarWeightsKernel() const + { + return _arm_gemm && _arm_gemm->isVarWeightsKernel(); + } // Inherited methods overridden: - void prepare(ITensorPack &tensors) override; - void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + void run(ITensorPack &tensors) override; experimental::MemoryRequirements workspace() const override; private: diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 1bf7f2b3ac..0f08f5d044 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -272,6 +272,8 @@ const std::map ClKernelLibrary::_kernel_program_map = { "gemm_mv", "common/gemv.cl" }, { "gemm_mv_quantized", "common/gemv.cl" }, { "gemm_mm_native", "common/gemm.cl" }, + { "gemm_mm_reshaped_only_rhs_nt_mmul", "common/gemm_reshaped_only_rhs_mmul.cl" }, + { "gemm_mm_reshaped_only_rhs_nt_mmul_texture", "common/gemm_reshaped_only_rhs_mmul.cl" }, { "gemm_mm_native_post_act_eltwise_op_act", "common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl" }, { "gemm_mm_reshaped_lhs_nt_rhs_t", "common/gemm.cl" }, { "gemm_mm_reshaped_lhs_nt_rhs_t_texture", "common/gemm.cl" }, @@ -301,6 +303,7 @@ const std::map ClKernelLibrary::_kernel_program_map = { "gemmlowp_mm_reshaped_lhs_nt_rhs_t", "common/gemmlowp.cl" }, { "gemmlowp_mm_reshaped_only_rhs_t", "common/gemmlowp.cl" }, { "gemmlowp_mm_reshaped_only_rhs_t_fused_output_stage_fixedpoint", "common/gemmlowp.cl" }, + { "gemmlowp_mm_reshaped_only_rhs_mmul", "common/gemmlowp_reshaped_only_rhs_mmul.cl" }, { "gemmlowp_offset_contribution", "common/gemmlowp.cl" }, { "gemmlowp_offset_contribution_quantize_down", "common/gemmlowp.cl" }, { "gemmlowp_offset_contribution_quantize_down_fixedpoint", "common/gemmlowp.cl" }, @@ -582,6 +585,10 @@ const std::map ClKernelLibrary::_program_source_map = { "common/gemm.cl", #include "./cl_kernels/common/gemm.clembed" + }, + { + "common/gemm_reshaped_only_rhs_mmul.cl", +#include "./cl_kernels/common/gemm_reshaped_only_rhs_mmul.clembed" }, { "common/gemm_utils.cl", @@ -610,6 +617,10 @@ const std::map ClKernelLibrary::_program_source_map = { "common/gemmlowp.cl", #include "./cl_kernels/common/gemmlowp.clembed" + }, + { + "common/gemmlowp_reshaped_only_rhs_mmul.cl", +#include "./cl_kernels/common/gemmlowp_reshaped_only_rhs_mmul.clembed" }, { "common/gemv.cl", diff --git a/src/gpu/cl/kernels/ClCastKernel.cpp b/src/gpu/cl/kernels/ClCastKernel.cpp index bfcd152297..6baa31e710 100644 --- a/src/gpu/cl/kernels/ClCastKernel.cpp +++ b/src/gpu/cl/kernels/ClCastKernel.cpp @@ -52,7 +52,7 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *dst, Conver ARM_COMPUTE_RETURN_ERROR_ON(src == dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, - DataType::U8, DataType::S8, DataType::QASYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S16, + DataType::U8, DataType::S8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, diff --git a/src/gpu/cl/kernels/ClCastKernel.h b/src/gpu/cl/kernels/ClCastKernel.h index 5c223fc5fa..7fadfa73d0 100644 --- a/src/gpu/cl/kernels/ClCastKernel.h +++ b/src/gpu/cl/kernels/ClCastKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -49,6 +49,7 @@ class ClCastKernel : public IClKernel * * - QSYMM8_PER_CHANNEL -> QASYMM8 (ATTENTION: it is the user's responsibility to keep track of the quantization info in the TensorInfo meta-data) * - U8 -> S8, U16, S16, U32, S32, F16, F32 + * - S8 -> U8, U16, S16, U32, S32, F16, F32 * - U16 -> U8, S8, S16, U32, S32, F16, F32 * - S16 -> U8, S8, U16, U32, S32, F16, F32 * - U32 -> U8, S8, U16, S16, S32, F16, F32 diff --git a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp index ff8c2c32a0..c4b70ca82b 100644 --- a/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp +++ b/src/gpu/cl/kernels/ClDirectConv2dKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,11 +23,11 @@ */ #include "src/gpu/cl/kernels/ClDirectConv2dKernel.h" -#include "arm_compute/core/CL/CLHelpers.h" #include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/PixelValue.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" @@ -40,6 +40,7 @@ #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" #include "support/Cast.h" #include "support/StringSupport.h" + namespace arm_compute { namespace opencl @@ -49,7 +50,7 @@ namespace kernels namespace { Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const DirectConvComputeKernelInfo &desc) { ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8_SIGNED, DataType::QASYMM8, DataType::F16, DataType::F32); @@ -83,6 +84,21 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co } } + if(data_layout == DataLayout::NHWC) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.n0 != 1 && desc.n0 != 2 && desc.n0 != 3 && desc.n0 != 4 && desc.n0 != 8 && desc.n0 != 16, + "N0 can only be: 1, 2, 3, 4, 8, and 16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 1 && desc.k0 != 2 && desc.k0 != 3 && desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16, + "K0 can only be: 1, 2, 3, 4, 8, and 16"); + if(desc.export_weights_to_cl_image) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(desc.k0 != 4 && desc.k0 != 8 && desc.k0 != 16, + "K0 can only be: 4, 8, and 16"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!export_weights_to_cl_image(weights), + "Export to CLImage is not supported for this weight configuration"); + } + } + if(biases != nullptr) { if(is_data_type_quantized_asymmetric(src->data_type())) @@ -121,50 +137,6 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, co } return Status{}; } - -bool export_to_cl_image_support(ITensorInfo *tensor, GPUTarget gpu_target, DataLayout data_layout) -{ - if(tensor->tensor_shape()[0] % 4 || (data_layout != DataLayout::NHWC)) - { - return false; - } - - // If not floating point - if(!is_data_type_float(tensor->data_type())) - { - return false; - } - - if(gpu_target == GPUTarget::G71 || get_arch_from_target(gpu_target) == GPUTarget::MIDGARD) - { - return false; - } - - // Check if the cl_khr_image2d_from_buffer extension is supported on the target platform - 
if(!image2d_from_buffer_supported(CLKernelLibrary::get().get_device())) - { - return false; - } - - // Check cl image pitch alignment - if(get_cl_image_pitch_alignment(CLKernelLibrary::get().get_device()) == 0) - { - return false; - } - - const size_t image_w = tensor->tensor_shape()[0] / 4; - const size_t image_h = tensor->tensor_shape()[1] * tensor->tensor_shape()[2] * tensor->tensor_shape()[3]; - const size_t max_image_w = CLKernelLibrary::get().get_device().getInfo(); - const size_t max_image_h = CLKernelLibrary::get().get_device().getInfo(); - - if(image_w > max_image_w || image_h > max_image_h) - { - return false; - } - - return true; -} - } // namespace ClDirectConv2dKernel::ClDirectConv2dKernel() @@ -173,12 +145,12 @@ ClDirectConv2dKernel::ClDirectConv2dKernel() } void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, - const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const DirectConvComputeKernelInfo &desc) { ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); // Perform validation - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info, act_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, biases, dst, conv_info, act_info, desc)); const int conv_stride_x = std::get<0>(conv_info.stride()); const int conv_stride_y = std::get<1>(conv_info.stride()); @@ -208,15 +180,12 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT Window win; if(_data_layout == DataLayout::NHWC) { - const unsigned int vec_size = std::min(static_cast(dst->tensor_shape()[0]), 4u); - unsigned int num_rows = 1U; - if(dst->tensor_shape()[0] > 16) - { - num_rows = src->data_type() == DataType::F32 ? 2U : 4U; - } + output_shape.collapse(2U, 1U); + const unsigned int n0 = adjust_vec_size(desc.n0, output_shape[0]); + const unsigned int m0 = adjust_vec_size(desc.m0, output_shape[1]); // Create window and update padding - win = calculate_max_window(output_shape, Steps(vec_size, num_rows)); + win = calculate_max_window(output_shape, Steps(n0, m0)); } else if(_data_layout == DataLayout::NCHW) { @@ -233,16 +202,17 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT { kernel_name << "direct_convolution_nhwc"; - const unsigned int n0 = win.x().step(); - const unsigned int m0 = win.y().step(); - const unsigned int k0 = adjust_vec_size(is_data_type_quantized(data_type) ? 
16u : 8u, src->dimension(channel_idx)); - const unsigned int partial_store_n0 = dst->dimension(channel_idx) % n0; - const unsigned int pad_left = conv_info.pad_left(); - const unsigned int pad_top = conv_info.pad_top(); - const bool export_to_cl_image = export_to_cl_image_support(weights, gpu_target, _data_layout); + const unsigned int n0 = win.x().step(); + const unsigned int m0 = win.y().step(); + const unsigned int k0 = adjust_vec_size(desc.k0, src->dimension(channel_idx)); + const unsigned int partial_store_n0 = dst->dimension(channel_idx) % n0; + const unsigned int pad_left = conv_info.pad_left(); + const unsigned int pad_top = conv_info.pad_top(); + + _export_to_cl_image = desc.export_weights_to_cl_image; // Update the padding for the weights tensor if we can export to cl_image - if(export_to_cl_image) + if(_export_to_cl_image) { gemm::update_padding_for_cl_image(weights); } @@ -253,12 +223,28 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT { build_options.add_option(std::string("-DBIA_DATA_TYPE=" + get_cl_type_from_data_type(biases->data_type()))); } - build_options.add_option("-cl-fast-relaxed-math"); + // The conditions under which -cl-fast-relaxed-math causes accuracy issues are tracked in COMPMID-5324 + const auto act_function = act_info.activation(); + const auto dst_data_type = dst->data_type(); + + if((gpu_target != GPUTarget::G71 && (gpu_target & GPUTarget::GPU_ARCH_MASK) == GPUTarget::BIFROST) + && (act_function == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU || act_function == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) + && (dst_data_type == DataType::F32 || dst_data_type == DataType::F16)) + { + // -cl-fast-relaxed-math also sets -cl-finite-math-only and -cl-unsafe-math-optimizations. + // To avoid enabling -cl-finite-math-only, we include only -cl-unsafe-math-optimizations + build_options.add_option("-cl-unsafe-math-optimizations"); + } + else + { + build_options.add_option("-cl-fast-relaxed-math"); + } + build_options.add_option("-DSRC_TENSOR_TYPE=BUFFER"); build_options.add_option("-DSRC_DATA_TYPE=" + get_cl_type_from_data_type(src->data_type())); build_options.add_option("-DDST_TENSOR_TYPE=BUFFER"); - build_options.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); + build_options.add_option("-DDST_DATA_TYPE=" + get_cl_type_from_data_type(dst_data_type)); + build_options.add_option_if_else(_export_to_cl_image, "-DWEI_TENSOR_TYPE=IMAGE", "-DWEI_TENSOR_TYPE=BUFFER"); build_options.add_option("-DWEI_WIDTH=" + support::cpp11::to_string(weights->dimension(width_idx))); build_options.add_option("-DWEI_HEIGHT=" + support::cpp11::to_string(weights->dimension(height_idx))); build_options.add_option("-DWEI_DATA_TYPE=" + get_cl_type_from_data_type(weights->data_type())); @@ -271,7 +257,7 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT build_options.add_option("-DK0=" + support::cpp11::to_string(k0)); build_options.add_option("-DPARTIAL_N0=" + support::cpp11::to_string(partial_store_n0)); build_options.add_option_if((src->dimension(channel_idx) % k0) != 0, "-DLEFTOVER_LOOP"); - build_options.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_info.activation()))); + build_options.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(act_function))); if(is_data_type_quantized(data_type)) { @@ -309,6 +295,8 @@ void
ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT } else { + _export_to_cl_image = false; + kernel_name << "direct_convolution_nchw"; build_options.add_option_if(biases != nullptr, std::string("-DHAS_BIAS")); build_options.add_option("-DSRC_WIDTH=" + support::cpp11::to_string(src->dimension(width_idx))); @@ -377,9 +365,9 @@ void ClDirectConv2dKernel::configure(const CLCompileContext &compile_context, IT } Status ClDirectConv2dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const DirectConvComputeKernelInfo &desc) { - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info, act_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, biases, dst, conv_info, act_info, desc)); return Status{}; } @@ -400,13 +388,7 @@ void ClDirectConv2dKernel::run_op(ITensorPack &tensors, const Window &window, cl { cl::Image2D weights_cl_image; - const size_t dim_y_collapsed = ceil_to_multiple(dst->info()->dimension(1) * dst->info()->dimension(2), slice.y().step()); - const bool export_to_cl_image = export_to_cl_image_support(weights->info(), get_target(), _data_layout); - - slice.set(Window::DimY, Window::Dimension(0, dim_y_collapsed, slice.y().step())); - slice.set(Window::DimZ, Window::Dimension(0, dst->info()->dimension(3), 1)); - - if(export_to_cl_image) + if(_export_to_cl_image) { const size_t image_w = weights->info()->dimension(0) / 4; const size_t image_h = weights->info()->dimension(1) * weights->info()->dimension(2) * weights->info()->dimension(3); @@ -420,7 +402,7 @@ void ClDirectConv2dKernel::run_op(ITensorPack &tensors, const Window &window, cl unsigned int idx = 0; add_4d_tensor_nhwc_argument(idx, src); add_4d_tensor_nhwc_argument(idx, dst); - if(export_to_cl_image) + if(_export_to_cl_image) { _kernel.setArg(idx++, weights_cl_image); } diff --git a/src/gpu/cl/kernels/ClDirectConv2dKernel.h b/src/gpu/cl/kernels/ClDirectConv2dKernel.h index 5681927816..0cb8aebbe1 100644 --- a/src/gpu/cl/kernels/ClDirectConv2dKernel.h +++ b/src/gpu/cl/kernels/ClDirectConv2dKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,6 +30,9 @@ namespace arm_compute { +// Forward declaration +struct DirectConvComputeKernelInfo; + namespace opencl { namespace kernels @@ -62,9 +65,10 @@ class ClDirectConv2dKernel : public IClKernel * The 3rd dimension must be equal to the 4th dimension of the @p kernels tensor. Data types supported: Same as @p src. * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. * @param[in] act_info Contains activation information described in @ref ActivationLayerInfo. + * @param[in] desc Direct convolution descriptor used to build the NHWC direct convolution kernel. For NCHW, this parameter is ignored.
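+ *
+ * A minimal configuration sketch (hypothetical block sizes; the kernel object, tensor infos, compile context, conv_info and act_info are assumed to exist already):
+ * @code
+ * DirectConvComputeKernelInfo desc{};
+ * desc.m0 = 4;                              // rows of the dst window processed per work-item
+ * desc.n0 = 4;                              // output channels processed per work-item
+ * desc.k0 = 8;                              // accumulations per step; adjusted internally to the channel count
+ * desc.export_weights_to_cl_image = false;  // keep the weights in a buffer
+ * kernel.configure(compile_context, &src, &weights, &biases, &dst, conv_info, act_info, desc);
+ * @endcode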
*/ void configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, - const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info); + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const DirectConvComputeKernelInfo &desc); /** Static function to check if given info will lead to a valid configuration * * Similar to ClDirectConv2dKernel::configure() @@ -72,7 +76,7 @@ class ClDirectConv2dKernel : public IClKernel * @return a status */ static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, - const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info); + const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info, const DirectConvComputeKernelInfo &desc); // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; @@ -80,6 +84,7 @@ class ClDirectConv2dKernel : public IClKernel public: DataLayout _data_layout{}; PadStrideInfo _conv_info{}; + bool _export_to_cl_image{ false }; }; } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp index a0735b1112..79f425189a 100644 --- a/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp +++ b/src/gpu/cl/kernels/ClDirectConv3dKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -42,6 +42,8 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons { ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_LAYOUT(src0, src1, dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_layout() != DataLayout::NDHWC, "Only NDHWC layout supported"); + + // When fusing activation, the same workaround introduced for COMPMID-5324 may be necessary ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv3d_info.act_info.enabled(), "Fused activation not supported"); ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(src0); diff --git a/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp new file mode 100644 index 0000000000..cdd047cb28 --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp @@ -0,0 +1,480 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" + +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" + +#include "support/Cast.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +using namespace misc::shape_calculator; + +namespace +{ +using ElementsProcessed = Steps; + +Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, const GEMMKernelInfo &gemm_info, + const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, + const ITensorInfo *output_multipliers, const ITensorInfo *output_shifts) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); + + const GEMMRHSMatrixInfo rhs_info = gemm_info.rhs_info; + const GEMMLHSMatrixInfo lhs_info = gemm_info.lhs_info; + const GEMMLowpOutputStageInfo output_stage = gemm_info.output_stage; + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.k0 != 4 || lhs_info.k0 != 4, "Only 4 is supported as value for k0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(lhs_info.m0 == 1 || lhs_info.m0 == 2 || lhs_info.m0 == 4), "Only 1,2,4 are supported for m0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(rhs_info.n0 == 1 || rhs_info.n0 == 4 || rhs_info.n0 == 8), "Only 1,4,8 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.export_to_cl_image, "Export to CLImage not supported for quantized GEMM"); + + const int m = gemm_info.m; + const int n = gemm_info.n; + const int k = gemm_info.k; + + TensorShape tensor_shape1{ src1->tensor_shape() }; + tensor_shape1.set(0, n); + tensor_shape1.set(1, k); + + const TensorInfo tensor_info1 = src1->clone()->set_tensor_shape(tensor_shape1); + const TensorInfo tensor_info_reshaped1 = src1->clone()->set_tensor_shape(compute_rhs_reshaped_shape(tensor_info1, rhs_info)); + + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != static_cast(k)); + if(gemm_info.reinterpret_input_as_3d) + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != static_cast(m)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) != static_cast(m)); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src1, &tensor_info_reshaped1); + + const TensorShape expected_dst_shape = compute_mm_shape(*src0, *src1, gemm_info); + if(dst->total_size() != 0) + { + const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(expected_dst_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, 
&tensor_info_dst); + if(output_stage.type == GEMMLowpOutputStageType::NONE) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst); + } + } + + if(bias != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON(expected_dst_shape[0] != bias->dimension(0)); + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG((output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN) || (output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FLOAT), + "Only GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT is supported"); + + // Checks performed if the dst stage needs to be fused + if(output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + // If a_offset == 0, vector_sum_col can be a nullptr + if(gemm_info.a_offset != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_col, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON(vector_sum_col->dimension(0) != expected_dst_shape[0]); + } + + // If b_offset == 0, vector_sum_row can be a nullptr + if(gemm_info.b_offset != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(vector_sum_row, 1, DataType::S32); + + // Check if mm result is a 3D reinterpretation + const bool reinterpret_as_3d = expected_dst_shape.num_dimensions() > 1 && expected_dst_shape.y() != vector_sum_row->tensor_shape().x(); + + // Validate input + ARM_COMPUTE_RETURN_ERROR_ON(reinterpret_as_3d && vector_sum_row->dimension(0) != (expected_dst_shape[1] * expected_dst_shape[2])); + ARM_COMPUTE_RETURN_ERROR_ON(!reinterpret_as_3d && vector_sum_row->dimension(0) != expected_dst_shape[1]); + + if(expected_dst_shape.num_dimensions() > 1) + { + const unsigned int dst_batch_idx = reinterpret_as_3d ? 
3 : 2; + + TensorShape vector_sum_row_shape = vector_sum_row->tensor_shape(); + vector_sum_row_shape.collapse_from(1); + TensorShape collapsed_dst_shape(expected_dst_shape); + collapsed_dst_shape.collapse_from(dst_batch_idx); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_row_shape[1] != collapsed_dst_shape[dst_batch_idx], + "vector_sum_row must have the same number of batches as the dst tensor"); + + if(gemm_info.a_offset != 0) + { + TensorShape vector_sum_col_shape = vector_sum_col->tensor_shape(); + vector_sum_col_shape.collapse_from(1); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(vector_sum_col_shape[1] != 1 && vector_sum_col_shape[1] != vector_sum_row_shape[1], + "vector_sum_col must have the same number of batches as vector_sum_row, or its number of batches must be 1"); + } + } + } + + if(dst->total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(output_stage.output_data_type != dst->data_type()); + } + ARM_COMPUTE_RETURN_ERROR_ON(output_stage.gemmlowp_min_bound > output_stage.gemmlowp_max_bound); + + if(output_multipliers != nullptr && output_shifts != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_multipliers, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON(output_multipliers->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_shifts, 1, DataType::S32); + ARM_COMPUTE_RETURN_ERROR_ON(output_shifts->num_dimensions() > 1); + if(output_stage.is_quantized_per_channel) + { + ARM_COMPUTE_RETURN_ERROR_ON(expected_dst_shape[0] != output_shifts->dimension(0)); + ARM_COMPUTE_RETURN_ERROR_ON(expected_dst_shape[0] != output_multipliers->dimension(0)); + } + } + } + return Status{}; +} + +std::pair<Status, Window> validate_and_configure_window(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst, const GEMMKernelInfo &gemm_info, + ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, ITensorInfo *bias, + ITensorInfo *output_multipliers, ITensorInfo *output_shifts, ElementsProcessed &num_elements_processed) +{ + const GEMMLowpOutputStageInfo output_stage = gemm_info.output_stage; + + unsigned int &num_elems_processed_per_iteration_x = num_elements_processed[0]; + unsigned int &num_elems_processed_per_iteration_y = num_elements_processed[1]; + bool reinterpret_output_as_3d = (gemm_info.depth_output_gemm3d != 0); + + Window win{}; + bool window_changed = false; + + constexpr unsigned int mmul_n0 = 4; + constexpr unsigned int mmul_m0 = 4; + constexpr unsigned int mmul_k0 = 16; + + reinterpret_output_as_3d = false; + // dst tensor auto initialization if not yet initialized + const TensorShape expected_dst_shape = compute_mm_shape(*src0, *src1, gemm_info); + if(output_stage.type != GEMMLowpOutputStageType::NONE) + { + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(expected_dst_shape).set_data_type(output_stage.output_data_type)); + } + else + { + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(expected_dst_shape).set_data_type(DataType::S32)); + } + + TensorInfo tmp_info(*dst); + + if(reinterpret_output_as_3d) + { + // Since the dst tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM, + // the window needs to be constructed on the 2D collapsed version of the tensor + TensorShape tmp_shape(dst->tensor_shape()); + tmp_shape.collapse(2U, 1U); + tmp_info.set_tensor_shape(tmp_shape); + } + + // Configure kernel window + num_elems_processed_per_iteration_x = 1; + num_elems_processed_per_iteration_y = 1; + + win = calculate_max_window(tmp_info,
Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y)); + + if(output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + if(gemm_info.a_offset != 0) + { + AccessWindowHorizontal vector_sum_col_access(vector_sum_col, 0, num_elems_processed_per_iteration_x); + window_changed = window_changed || update_window_and_padding(win, vector_sum_col_access); + } + // No access window needed for vector_sum_row + ARM_COMPUTE_UNUSED(vector_sum_row); + + if(bias != nullptr) + { + AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration_x); + window_changed = window_changed || update_window_and_padding(win, bias_access); + } + + if(output_multipliers != nullptr && output_stage.is_quantized_per_channel) + { + AccessWindowHorizontal output_multipliers_access(output_multipliers, 0, num_elems_processed_per_iteration_x); + AccessWindowHorizontal output_shifts_access(output_shifts, 0, num_elems_processed_per_iteration_x); + window_changed = window_changed || update_window_and_padding(win, output_multipliers_access, output_shifts_access); + } + } + + // Collapse along the Z direction + // This collapse needs to be here in order to tune the Z dimension of LWS + const unsigned int dimension_to_collapse = std::min(static_cast(dst->num_dimensions()), 2u); + Window collapsed = win.collapse(win, dimension_to_collapse); + + // Reconfigure window size, one arm_matrix_multiply kernel needs 16 threads to finish. + Window::Dimension x_dimension = collapsed.x(); + Window::Dimension y_dimension = collapsed.y(); + + // Make M and N multiple of M0 and N0 respectively + const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(x_dimension.end(), gemm_info.rhs_info.n0); + const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(y_dimension.end(), gemm_info.lhs_info.m0); + + // Divide M and N by M0 and N0 respectively + const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / gemm_info.rhs_info.n0; + const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / gemm_info.lhs_info.m0; + + // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_k0 respectively + const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0); + const unsigned int ceil_to_multiple_m_div_m0_mmul_m0 = ceil_to_multiple(m_div_m0, mmul_k0); + + // Ensure x_dimension is multiple of MMUL block size (mmul_n0 * mmul_m0) + x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_n0); + y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_m0 / mmul_m0); + + collapsed.set(Window::DimX, x_dimension); + collapsed.set(Window::DimY, y_dimension); + + Status err = (window_changed) ? 
ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; + return std::make_pair(err, collapsed); +} +} // namespace + +ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel() +{ + _type = CLKernelType::GEMM; +} + +void ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst, + const GEMMKernelInfo &gemm_info, + ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, ITensorInfo *bias, + ITensorInfo *output_multipliers, ITensorInfo *output_shifts) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias, output_multipliers, output_shifts)); + + auto padding_info = get_padding_info({ src0, src1, dst, vector_sum_row }); + const GEMMRHSMatrixInfo rhs_info = gemm_info.rhs_info; + const GEMMLHSMatrixInfo lhs_info = gemm_info.lhs_info; + const GEMMLowpOutputStageInfo output_stage = gemm_info.output_stage; + const int32_t a_offset = gemm_info.a_offset; + const int32_t b_offset = gemm_info.b_offset; + constexpr int mmul_m0 = 4; + constexpr int mmul_n0 = 4; + constexpr int mmul_k0 = 16; + + _m = gemm_info.m; + _n = gemm_info.n; + _k = gemm_info.k; + + ElementsProcessed num_elements_processed{}; + + // Configure kernel window + auto win_config = validate_and_configure_window(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias, output_multipliers, output_shifts, num_elements_processed); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure_internal(win_config.second); + + const unsigned int m0_leftover = _m % lhs_info.m0; + const unsigned int n0_leftover = _n % rhs_info.n0; + + // Create build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src0->data_type())); + build_opts.add_option("-DVEC_TYPE=" + get_cl_type_from_data_type(src0->data_type()) + "4"); + build_opts.add_option("-DACC_DATA_TYPE=int"); + build_opts.add_option("-DOUT_DATA_TYPE=" + get_cl_type_from_data_type(dst->data_type())); + build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); + build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); + build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); + build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover)); + build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover)); + build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0)); + build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0)); + build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0)); + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation()))); + build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); + build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + + std::string kernel_name("gemmlowp_mm_reshaped_only_rhs_mmul"); + + if(output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + build_opts.add_option("-DFUSED_OUTPUT_STAGE_FIXED_POINT"); + _fuse_output_stage = true; + // If a_offset == 0, vector_sum_col can be a nullptr + if(a_offset != 0 && vector_sum_col != nullptr) + { + build_opts.add_option("-DA_OFFSET=" 
+ support::cpp11::to_string(a_offset)); + build_opts.add_option_if(vector_sum_col->tensor_shape().num_dimensions() > 1, "-DSUM_COL_HAS_BATCHES"); + } + // If b_offset == 0, vector_sum_row can be a nullptr + build_opts.add_option_if(b_offset != 0, "-DB_OFFSET=" + support::cpp11::to_string(b_offset)); + build_opts.add_option("-DK_OFFSET=" + support::cpp11::to_string(a_offset * b_offset * src0->dimension(0))); + build_opts.add_option_if(bias != nullptr, "-DADD_BIAS"); + build_opts.add_option_if(gemm_info.broadcast_bias == true, "-DBROADCAST_BIAS"); + build_opts.add_option("-DRESULT_OFFSET=" + support::cpp11::to_string(output_stage.gemmlowp_offset)); + build_opts.add_option("-DRESULT_MULTIPLIER=" + support::cpp11::to_string(output_stage.gemmlowp_multipliers[0])); + build_opts.add_option("-DRESULT_SHIFT=" + support::cpp11::to_string(output_stage.gemmlowp_shifts[0])); + + const int min = output_stage.gemmlowp_min_bound; + const int max = output_stage.gemmlowp_max_bound; + + PixelValue min_val{}; + PixelValue max_val{}; + std::tie(min_val, max_val) = get_min_max(dst->data_type()); + build_opts.add_option_if(min != min_val.get(), "-DMIN_BOUND=" + support::cpp11::to_string(min)); + build_opts.add_option_if(max != max_val.get(), "-DMAX_BOUND=" + support::cpp11::to_string(max)); + } + + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + + // Create kernel + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Set config_id for enabling LWS tuning + _config_id = kernel_name; + _config_id += "_"; + _config_id += (bias != nullptr ? "add_bias_" : ""); + _config_id += (gemm_info.broadcast_bias ? "broadcast_bias_" : ""); + _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : ""); + _config_id += lower_string(string_from_data_type(src0->data_type())); + _config_id += "_"; + _config_id += support::cpp11::to_string(_m); + _config_id += "_"; + _config_id += support::cpp11::to_string(_n); + _config_id += "_"; + _config_id += support::cpp11::to_string(_k); + _config_id += "_"; + _config_id += support::cpp11::to_string(lhs_info.m0); + _config_id += "_"; + _config_id += support::cpp11::to_string(rhs_info.n0); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); +} + +Status ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, const GEMMKernelInfo &gemm_info, + const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, + const ITensorInfo *output_multipliers, const ITensorInfo *output_shifts) +{ + ElementsProcessed num_elements_processed{}; + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst, gemm_info, vector_sum_col, vector_sum_row, bias, output_multipliers, output_shifts)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src0->clone().get(), + src1->clone().get(), + dst->clone().get(), + gemm_info, + vector_sum_col != nullptr ? vector_sum_col->clone().get() : nullptr, + vector_sum_row != nullptr ? vector_sum_row->clone().get() : nullptr, + bias != nullptr ? bias->clone().get() : nullptr, + output_multipliers != nullptr ? output_multipliers->clone().get() : nullptr, + output_shifts != nullptr ? 
output_shifts->clone().get() : nullptr, + num_elements_processed) + .first); + + return Status{}; +} + +void ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); + + const auto src0 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const auto src1 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const auto src2 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2)); + const auto vector_sum_col = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_VEC_COL_SUM)); + const auto vector_sum_row = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_VEC_ROW_SUM)); + auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + + if(src1->info()->num_dimensions() < 3) + { + // The stride_z for matrix B must be zero if we do not slice + ARM_COMPUTE_ERROR_ON(src1->info()->strides_in_bytes()[3] != 0); + } + + cl::Image2D src1_image2d; + + Window slice = window.first_slice_window_3D(); + + do + { + unsigned int idx = 0; + + add_3d_tensor_nhw_argument(idx, src0); + add_3d_tensor_nhw_argument(idx, src1); + + // Bias buffer (present only when bias addition is requested) + if(src2 != nullptr) + { + add_3d_tensor_nhw_argument(idx, src2); + } + // dst buffer + add_3d_tensor_nhw_argument(idx, dst); + + // Pass m, n and k at runtime as signed ints to ensure that the results of any subtractions they appear in remain signed. + _kernel.setArg(idx++, _m); + _kernel.setArg(idx++, _n); + _kernel.setArg(idx++, _k); + + if(_fuse_output_stage) + { + if(vector_sum_col != nullptr) + { + add_3d_tensor_nhw_argument(idx, vector_sum_col); + } + if(vector_sum_row != nullptr) + { + add_3d_tensor_nhw_argument(idx, vector_sum_row); + } + } + + enqueue(queue, *this, slice, cl::NDRange(32, 2), false); + } + while(window.slide_window_slice_3D(slice)); +} +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h new file mode 100644 index 0000000000..0ae549cd53 --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H +#define ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H + +#include "arm_compute/core/KernelDescriptors.h" +#include "src/core/common/Macros.h" +#include "src/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +/** OpenCL kernel to multiply QASYMM8/QASYMM8_SIGNED matrices using the MMUL extension, when only the RHS matrix (src1) has been reshaped + * + * @note The input matrix src1 must be reshaped through @ref opencl::kernels::ClGemmReshapeRhsMatrixKernel + * @note For fused output stage, only GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT type is supported + */ +class ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel : public IClKernel +{ +public: + ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel(); + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel); + /** Initialise the kernel's source and destination. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src0 Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED + * @param[in] src1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p src0 + * @param[out] dst Destination tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/S32. + * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices, output stage information and RHS/LHS info. + * Only the following values are supported for LHS info: + * lhs_info.m0: 1,2,4 + * Only the following values are supported for RHS info: + * rhs_info.n0: 1,4,8 + * rhs_info.k0: same as lhs_info.k0: 4 + * rhs_info.transpose: true + * @param[in] vector_sum_col (Optional) Input row-vector of sums of all the entries in each column of matrix B. + * Note: vector_sum_col can be a nullptr in case a_offset = 0. Data type supported: S32 + * @param[in] vector_sum_row (Optional) Input row-vector of sums of all the entries in each row of matrix A. + * Note: vector_sum_row can be a nullptr in case b_offset = 0. Data type supported: S32 + * @param[in] bias (Optional) Biases tensor. Can be a nullptr if the addition of biases is not required. + * Biases are a 1D tensor with dimensions [OFM], or have the same dimensionality as dst if gemm_info.broadcast_bias is false. Data type supported: S32. + * @param[in] output_multipliers (Optional) Output multipliers tensor. Supported data types: S32. + * @param[in] output_shifts (Optional) Output shifts tensor. Supported data types: S32.
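+ *
+ * A minimal configuration sketch (hypothetical sizes; the kernel object, tensor infos and compile context are assumed to exist):
+ * @code
+ * GEMMKernelInfo gemm_info{};
+ * gemm_info.m = 64;
+ * gemm_info.n = 64;
+ * gemm_info.k = 64;
+ * gemm_info.lhs_info.m0 = 4;
+ * gemm_info.lhs_info.k0 = 4; // only 4 is accepted
+ * gemm_info.rhs_info.n0 = 4;
+ * gemm_info.rhs_info.k0 = 4;
+ * kernel.configure(compile_context, &lhs, &rhs_reshaped, &dst, gemm_info);
+ * @endcode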
+ */ + void configure(const CLCompileContext &compile_context, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst, const GEMMKernelInfo &gemm_info, + ITensorInfo *vector_sum_col = nullptr, const ITensorInfo *vector_sum_row = nullptr, ITensorInfo *bias = nullptr, + ITensorInfo *output_multipliers = nullptr, ITensorInfo *output_shifts = nullptr); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst, const GEMMKernelInfo &gemm_info, + const ITensorInfo *vector_sum_col = nullptr, const ITensorInfo *vector_sum_row = nullptr, const ITensorInfo *bias = nullptr, + const ITensorInfo *output_multipliers = nullptr, const ITensorInfo *output_shifts = nullptr); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; + +private: + bool _fuse_output_stage{ false }; + signed int _m{ 1 }; + signed int _n{ 1 }; + signed int _k{ 1 }; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H */ \ No newline at end of file diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp new file mode 100644 index 0000000000..fe46913517 --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.cpp @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE.
+ */ +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/CL/OpenCL.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "src/core/CL/CLUtils.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/utils/helpers/float_ops.h" +#include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "support/Cast.h" +#include "support/StringSupport.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +namespace +{ +using ElementsProcessed = Steps; + +// Block size dimensions for the MMUL extension +constexpr int mmul_m0 = 4; +constexpr int mmul_n0 = 4; +constexpr int mmul_k0 = 4; + +Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform"); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, src1); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_info.m0 < 1, "Only values greater than 0 are supported for m0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.n0 != 1 && rhs_info.n0 != 2 && rhs_info.n0 != 3 && rhs_info.n0 != 4 && rhs_info.n0 != 8 && rhs_info.n0 != 16, "Only 1,2,3,4,8, and 16 are supported for n0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.k0 != 1 || lhs_info.k0 != 1), "Only 1 is supported for k0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((rhs_info.h0 != 4), "Only 4 is supported for h0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.interleave != true, "Only true is supported for interleave with mmul extension enabled"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_info.transpose != false, "Only false is supported for transpose with mmul extension enabled"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.fp_mixed_precision, "Mixed precision not supported"); + ARM_COMPUTE_RETURN_ON_ERROR(gemm::validate_image2d_support_on_rhs(*src1, rhs_info)); + + const unsigned int m = gemm_info.m; + const unsigned int n = gemm_info.n; + const unsigned int k = gemm_info.k; + + ARM_COMPUTE_UNUSED(m); + ARM_COMPUTE_UNUSED(n); + ARM_COMPUTE_UNUSED(k); + + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(0) != k); + + // Validate the reinterpreted-as-3D-case + if(gemm_info.depth_output_gemm3d != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) * src0->dimension(2) != m); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(1) != m); + } + + // Validate the gemm-batched case + if(src1->num_dimensions() > 2) + { + if(gemm_info.depth_output_gemm3d != 0) + { + 
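+ // When depth_output_gemm3d != 0, dimensions 1 and 2 of the LHS together hold the M rows,
+ // so the LHS batch count lives in dimension 3 and must match the RHS batch dimension.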
ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(3) != src1->dimension(2)); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON(src0->dimension(2) != src1->dimension(2)); + } + } + + if(src2 != nullptr && !(helpers::float_ops::is_zero(beta))) + { + const unsigned int src2_dim0 = src2->dimension(0); + const unsigned int src2_dim1 = src2->dimension(1); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src2, src1); + if(gemm_info.broadcast_bias) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim1 != 1 || src2_dim0 != n), "Incorrect dimension of bias matrix which is to be broadcasted"); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((src2_dim0 != n || src2_dim1 != m), "Incorrect dimension of bias matrix"); + } + } + + TensorShape tensor_shape1{ src1->tensor_shape() }; + tensor_shape1.set(0, n); + tensor_shape1.set(1, k); + + const TensorInfo tensor_info1 = src1->clone()->set_tensor_shape(tensor_shape1); + const TensorInfo tensor_info_reshaped1 = src1->clone()->set_tensor_shape(misc::shape_calculator::compute_rhs_reshaped_shape(tensor_info1, rhs_info)); + + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(src1, &tensor_info_reshaped1); + + if(dst->total_size() != 0) + { + const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src0, dst); + } + + return Status{}; +} + +std::pair validate_and_configure_window(ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(src0, src1, src2); + bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0; + + // dst tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info))); + + TensorInfo tmp_info(*dst); + + if(reinterpret_output_as_3d) + { + // Since the dst tensor has to be reinterpreted as 3D and the execute window is based on a 2D GEMM, + // the window needs to be constructed on the 2D collapsed version of the tensor + TensorShape tmp_shape(dst->tensor_shape()); + tmp_shape.collapse(2U, 1U); + tmp_info.set_tensor_shape(tmp_shape); + } + + Window win = calculate_max_window(tmp_info, Steps(1, 1)); + + // Collapse along the Z direction + // This collapse needs to be here in order to tune the Z dimension of LWS + const unsigned int dimension_to_collapse = std::min(static_cast(dst->num_dimensions()), 2u); + Window collapsed = win.collapse(win, dimension_to_collapse); + + // Reconfigure window size, one arm_matrix_multiply kernel needs 16 threads to finish. 
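+ // Worked example of the reshaping below (hypothetical sizes, with n0 = m0 = 4 and the mmul_* block sizes of 4 defined above):
+ //   N = 70: ceil_to_multiple(70, 4) = 72 -> 72 / 4 = 18 threads -> ceil_to_multiple(18, 4) = 20 -> X end = 20 * 4 = 80
+ //   M = 30: ceil_to_multiple(30, 4) = 32 -> 32 / 4 = 8 threads -> ceil_to_multiple(8, 4) = 8 -> Y end = 8 / 4 = 2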
+ Window::Dimension x_dimension = collapsed.x(); + Window::Dimension y_dimension = collapsed.y(); + + // Make M and N multiple of M0 and N0 respectively + const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(x_dimension.end(), rhs_info.n0); + const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(y_dimension.end(), lhs_info.m0); + + // Divide M and N by M0 and N0 respectively + const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / rhs_info.n0; + const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / lhs_info.m0; + + // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_k0 respectively + const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0); + const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0); + + // Ensure x_dimension is multiple of MMUL block size (mmul_n0 * mmul_k0) + x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_k0); + y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0); + + collapsed.set(Window::DimX, x_dimension); + collapsed.set(Window::DimY, y_dimension); + + return std::make_pair(Status{}, collapsed); +} +} // namespace + +ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel() +{ + _type = CLKernelType::GEMM; +} + +void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure(const CLCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, + float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + + // dst tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, src0->clone()->set_tensor_shape(misc::shape_calculator::compute_mm_shape(*src0, *src1, gemm_info))); + + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info)); + + auto padding_info = get_padding_info({ src0, src1, src2, dst }); + _add_bias = src2 != nullptr; + _export_to_cl_image = rhs_info.export_to_cl_image; + + // Configure kernel window + auto win_config = validate_and_configure_window(src0, src1, src2, dst, lhs_info, rhs_info, gemm_info); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + + IClKernel::configure_internal(win_config.second); + + _m = gemm_info.m; + _n = gemm_info.n; + _k = gemm_info.k; + + const unsigned int m0_leftover = _m % lhs_info.m0; + const unsigned int n0_leftover = _n % rhs_info.n0; + + // Create build options + CLBuildOptions build_opts; + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src0->data_type())); + build_opts.add_option_if(!(helpers::float_ops::is_one(alpha)), "-DALPHA=" + float_to_string_with_full_precision(alpha)); + build_opts.add_option_if(src2 != nullptr, "-DBETA=" + float_to_string_with_full_precision(beta)); + build_opts.add_option_if(helpers::float_ops::is_one(beta), "-DUNIT_BETA"); + build_opts.add_option_if(gemm_info.broadcast_bias, "-DBROADCAST_BIAS"); + build_opts.add_option_if(src0->data_type() == DataType::F16, "-DHALF_PRECISION"); + build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); + build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); + build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); + build_opts.add_option("-DM0_LEFTOVER=" + support::cpp11::to_string(m0_leftover)); + build_opts.add_option("-DN0_LEFTOVER=" + support::cpp11::to_string(n0_leftover)); + 
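+ // e.g. with hypothetical M = 30 and M0 = 4, M0_LEFTOVER is 2, so the kernel can guard its loads
+ // and stores on the boundary tiles instead of requiring padded tensors.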
build_opts.add_option("-DMMUL_M0=" + support::cpp11::to_string(mmul_m0)); + build_opts.add_option("-DMMUL_N0=" + support::cpp11::to_string(mmul_n0)); + build_opts.add_option("-DMMUL_K0=" + support::cpp11::to_string(mmul_k0)); + build_opts.add_option("-DACTIVATION_TYPE=" + lower_string(string_from_activation_func(gemm_info.activation_info.activation()))); + build_opts.add_option("-DA_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.a())); + build_opts.add_option("-DB_VAL=" + float_to_string_with_full_precision(gemm_info.activation_info.b())); + + std::string kernel_name("gemm_mm_reshaped_only_rhs_nt_mmul"); + kernel_name += rhs_info.export_to_cl_image ? "_texture" : ""; + + // A macro guard to compile ONLY the kernel of interest + build_opts.add_option("-D" + upper_string(kernel_name)); + + // Create kernel + _kernel = create_kernel(compile_context, kernel_name, build_opts.options()); + + // Set config_id for enabling LWS tuning + _config_id = kernel_name; + _config_id += "_"; + _config_id += (_add_bias ? "add_bias_" : ""); + _config_id += (gemm_info.broadcast_bias ? "broadcast_bias_" : ""); + _config_id += (gemm_info.activation_info.enabled() ? "fused_activation_" : ""); + _config_id += lower_string(string_from_data_type(src0->data_type())); + _config_id += "_"; + _config_id += support::cpp11::to_string(_m); + _config_id += "_"; + _config_id += support::cpp11::to_string(_n); + _config_id += "_"; + _config_id += support::cpp11::to_string(_k); + _config_id += "_"; + _config_id += support::cpp11::to_string(lhs_info.m0); + _config_id += "_"; + _config_id += support::cpp11::to_string(rhs_info.n0); + + ARM_COMPUTE_ERROR_ON(has_padding_changed(padding_info)); +} + +Status ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, const GEMMKernelInfo &gemm_info) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, src2, dst, alpha, beta, lhs_info, rhs_info, gemm_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src0->clone().get(), + src1->clone().get(), + src2 != nullptr ? 
src2->clone().get() : nullptr, + dst->clone().get(), + lhs_info, + rhs_info, + gemm_info) + .first); + + return Status{}; +} + +void ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) +{ + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); + + const auto src0 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_0)); + const auto src1 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_1)); + const auto src2 = utils::cast::polymorphic_downcast<const ICLTensor *>(tensors.get_const_tensor(TensorType::ACL_SRC_2)); + auto dst = utils::cast::polymorphic_downcast<ICLTensor *>(tensors.get_tensor(TensorType::ACL_DST)); + + ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); + ARM_COMPUTE_ERROR_ON(_add_bias && src2 == nullptr); + + if(src1->info()->num_dimensions() < 3) + { + // The stride_z for matrix B must be zero if we do not slice + ARM_COMPUTE_ERROR_ON(src1->info()->strides_in_bytes()[3] != 0); + } + + cl::Image2D src1_image2d; + + if(_export_to_cl_image) + { + const TensorShape shape2d(src1->info()->dimension(0) / 4, src1->info()->dimension(1) * src1->info()->dimension(2)); + const size_t image_row_pitch = src1->info()->strides_in_bytes()[1]; + + src1_image2d = create_image2d_from_buffer(CLKernelLibrary::get().context(), src1->cl_buffer(), shape2d, src1->info()->data_type(), image_row_pitch); + } + + Window slice = window.first_slice_window_3D(); + + do + { + unsigned int idx = 0; + + add_3d_tensor_nhw_argument(idx, src0); + if(_export_to_cl_image) + { + _kernel.setArg(idx++, src1_image2d); + } + add_3d_tensor_nhw_argument(idx, src1); + + // Bias buffer (_add_bias == true) + if(_add_bias) + { + add_3d_tensor_nhw_argument(idx, src2); + } + // dst buffer + add_3d_tensor_nhw_argument(idx, dst); + + // Pass m, n and k at runtime as signed ints to ensure that the results of any subtractions they appear in remain signed. + _kernel.setArg(idx++, _m); + _kernel.setArg(idx++, _n); + _kernel.setArg(idx++, _k); + + // LWS_x should be at least a multiple of 16. (32, 2) has been chosen to put more work-items on a single core. + // The LWS also enforces the execution order of the work-items, which improves cache utilization + enqueue(queue, *this, slice, cl::NDRange(32, 2), false); + } + while(window.slide_window_slice_3D(slice)); +} +} // namespace kernels +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h new file mode 100644 index 0000000000..59612fcf5d --- /dev/null +++ b/src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software.
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H +#define ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H + +#include "arm_compute/core/KernelDescriptors.h" +#include "src/core/common/Macros.h" +#include "src/gpu/cl/ClCompileContext.h" +#include "src/gpu/cl/IClKernel.h" + +namespace arm_compute +{ +namespace opencl +{ +namespace kernels +{ +/** OpenCL kernel to multiply matrices using MMUL when only the input matrix RHS (src1) has been reshaped */ +class ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel : public IClKernel +{ +public: + ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel(); + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel); + /** Initialize the kernel's input and dst. + * + * @param[in] compile_context The compile context to be used. + * @param[in] src0 Input tensor for the LHS matrix. Data type supported: F16/F32. + * @param[in] src1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p src0. + * @param[in] src2 Input tensor containing the bias matrix. Data type supported: same as @p src0. + * @param[out] dst dst tensor info. Data type supported: same as @p src0 + * @param[in] alpha Weight of the matrix product + * @param[in] beta Weight of the matrix bias + * @param[in] lhs_info LHS matrix information used to retrieve the number of rows and accumulations to be processed by each thread. Only the following values are supported: + * lhs_info.m0 > 0 + * lhs_info.k0: 1 + * @param[in] rhs_info RHS matrix information used to retrieve the number of columns and accumulations to be processed by each thread. 
Only the following values are supported: + * rhs_info.n0: 1,2,3,4,8,16 + * rhs_info.k0: same of lhs_info.k0 + * rhs_info.transpose: false + * @param[in] gemm_info GEMM information used to retrieve the original dimensions of the input matrices + */ + void configure(const ClCompileContext &compile_context, ITensorInfo *src0, ITensorInfo *src1, ITensorInfo *src2, ITensorInfo *dst, float alpha, float beta, + const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to @ref ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *src2, const ITensorInfo *dst, float alpha, float beta, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, + const GEMMKernelInfo &gemm_info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, cl::CommandQueue &queue) override; + +private: + bool _add_bias{ false }; + bool _export_to_cl_image{ false }; + signed int _m{ 1 }; + signed int _n{ 1 }; + signed int _k{ 1 }; +}; +} // namespace kernels +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_GEMM_MATRIXMULTIPLY_RESHAPED_ONLY_RHS_MMUL_KERNEL_H */ diff --git a/src/gpu/cl/kernels/ClWinogradOutputTransformKernel.cpp b/src/gpu/cl/kernels/ClWinogradOutputTransformKernel.cpp index ff57c83959..9eb249a66a 100644 --- a/src/gpu/cl/kernels/ClWinogradOutputTransformKernel.cpp +++ b/src/gpu/cl/kernels/ClWinogradOutputTransformKernel.cpp @@ -176,27 +176,46 @@ void ClWinogradOutputTransformKernel::configure(const ClCompileContext &compile_ build_opts.add_option("-DVEC_SIZE=4"); } + _num_tiles_x = num_tiles.width; + + // Conditions of -cl-fast-relaxed-math causing accuracy issues can be traced from COMPMID-5324 + const GPUTarget gpu_target = get_target(); + const auto act_function = act_info.activation(); + const auto src_data_type = src->data_type(); + + if((gpu_target != GPUTarget::G71 && (gpu_target & GPUTarget::GPU_ARCH_MASK) == GPUTarget::BIFROST) + && (act_function == ActivationLayerInfo::ActivationFunction::BOUNDED_RELU || act_function == ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) + && (src_data_type == DataType::F32 || src_data_type == DataType::F16)) + { + // -cl-fast-relaxed-math also sets -cl-finite-math-only and -cl-unsafe-math-optimizations + // to disable -cl-finite-math-only, we only include -cl-unsafe-math-optimizations + build_opts.add_option("-cl-unsafe-math-optimizations"); + } + else + { + build_opts.add_option("-cl-fast-relaxed-math"); + } + if(_is_nhwc) { build_opts.add_option_if(bias != nullptr, std::string("-DHAS_BIAS")); - build_opts.add_option("-cl-fast-relaxed-math"); build_opts.add_option("-DN0=" + support::cpp11::to_string(win_config.second.x().step())); build_opts.add_option("-DOUTPUT_TILE_W=" + support::cpp11::to_string(output_tile_size.width)); build_opts.add_option("-DOUTPUT_TILE_H=" + support::cpp11::to_string(output_tile_size.height)); - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src_data_type)); build_opts.add_option_if(total_batches > 1, "-DSRC_DEPTH=" + support::cpp11::to_string(src->dimension(2))); build_opts.add_option_if(winograd_info.kernel_size.height == 1, "-DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL"); 
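// Note on the -cl-fast-relaxed-math handling added earlier in this function: that flag
// implies both -cl-unsafe-math-optimizations and -cl-finite-math-only, and the latter
// lets the compiler assume NaN/Inf never occur. Bounded activations clamp values that
// may already have overflowed to +/-Inf (most easily in F16), so that assumption can
// fold the clamp away. A minimal sketch of the clamp that must survive compilation,
// using a hypothetical helper rather than kernel code:
//
//     float bounded_relu(float x, float upper)
//     {
//         // With finite-math-only the compiler may assume x is finite and simplify
//         // the fmin/fmax chain, letting an overflowed +Inf escape the clamp.
//         return fmin(fmax(x, 0.0f), upper);
//     }
//
// Hence only -cl-unsafe-math-optimizations is kept for BOUNDED_RELU/LU_BOUNDED_RELU
// on the affected Bifrost targets.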
build_opts.add_option_if(winograd_info.kernel_size.width == 1, "-DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL"); + build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(_num_tiles_x)); } else { build_opts.add_option_if(bias != nullptr, std::string("-DHAS_BIAS")); - build_opts.add_option("-cl-fast-relaxed-math"); build_opts.add_option("-DN0=" + support::cpp11::to_string(win_config.second.x().step())); build_opts.add_option("-DNUM_TILES_X=" + support::cpp11::to_string(num_tiles.width)); build_opts.add_option("-DOUTPUT_TILE_W=" + support::cpp11::to_string(output_tile_size.width)); build_opts.add_option("-DOUTPUT_TILE_H=" + support::cpp11::to_string(output_tile_size.height)); - build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src->data_type())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src_data_type)); build_opts.add_option("-DSRC_HEIGHT=" + support::cpp11::to_string(src->dimension(1))); build_opts.add_option("-DDST_WIDTH=" + support::cpp11::to_string(dst->dimension(idx_width))); build_opts.add_option("-DDST_HEIGHT=" + support::cpp11::to_string(dst->dimension(idx_height))); @@ -206,10 +225,9 @@ void ClWinogradOutputTransformKernel::configure(const ClCompileContext &compile_ } // Storing tensor dimensions to be sent later as kernel arguments - _src_height = src->dimension(1); - _dst_width = dst->dimension(idx_width); - _dst_height = dst->dimension(idx_height); - _num_tiles_x = num_tiles.width; + _src_height = src->dimension(1); + _dst_width = dst->dimension(idx_width); + _dst_height = dst->dimension(idx_height); // Create kernel std::string kernel_name = "winograd_output_transform_" + output_tile_size.to_string() + "_" + kernel_size.to_string() + "_" + lower_string(string_from_data_layout(winograd_info.output_data_layout)); @@ -221,7 +239,7 @@ void ClWinogradOutputTransformKernel::configure(const ClCompileContext &compile_ // Set config_id for enabling LWS tuning _config_id = kernel_name; _config_id += "_"; - _config_id += lower_string(string_from_data_type(src->data_type())); + _config_id += lower_string(string_from_data_type(src_data_type)); _config_id += "_"; _config_id += support::cpp11::to_string(src->dimension(0)); _config_id += "_"; @@ -279,7 +297,6 @@ void ClWinogradOutputTransformKernel::run_op(ITensorPack &tensors, const Window _kernel.setArg(idx2++, _src_height); _kernel.setArg(idx2++, _dst_width); _kernel.setArg(idx2++, _dst_height); - _kernel.setArg(idx2++, _num_tiles_x); } do diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp new file mode 100644 index 0000000000..4ea198133b --- /dev/null +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include + +namespace arm_compute +{ +namespace cl_direct_conv +{ +using namespace arm_compute::misc::shape_calculator; + +ClDirectConvDefaultConfigBifrost::ClDirectConvDefaultConfigBifrost(GPUTarget gpu) + : IClDirectConvKernelConfig(gpu) +{ +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + using ConfigurationFunctionExecutorPtr = DirectConvComputeKernelInfo (ClDirectConvDefaultConfigBifrost::*)(const ITensorInfo * src, const ITensorInfo * wei, const PadStrideInfo & conv_info); + + ClDirectConvConfigArray configs_G71(&ClDirectConvDefaultConfigBifrost::configure_G71_f32, + &ClDirectConvDefaultConfigBifrost::configure_G71_f16, + &ClDirectConvDefaultConfigBifrost::configure_G71_u8); + + ClDirectConvConfigArray configs_default(&ClDirectConvDefaultConfigBifrost::configure_default_f32, + &ClDirectConvDefaultConfigBifrost::configure_default_f16, + &ClDirectConvDefaultConfigBifrost::configure_G71_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + switch(_target) + { + case GPUTarget::G71: + func = configs_G71.get_function(src->data_type()); + break; + default: + func = configs_default.get_function(src->data_type()); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not supported for direct convolution"); + return (this->*func)(src, wei, conv_info); +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_G71_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 2; + } + + desc.k0 = 8; + + desc.export_weights_to_cl_image = false; + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_G71_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 4; + } + + desc.k0 = 8; + + desc.export_weights_to_cl_image = false; + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_G71_u8(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + 
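// The G71 configuration functions in this file share one pattern: n0 is fixed at 4,
// m0 is widened only when the output feature map is wide enough (output_shape[0] > 16)
// to keep the extra rows busy, and k0 grows for the narrower 8-bit type. A hedged
// sketch with hypothetical names, assuming DirectConvComputeKernelInfo defaults m0 to 1:
//
//     DirectConvComputeKernelInfo g71_desc(int32_t ofm, unsigned int m0_wide, unsigned int k0)
//     {
//         DirectConvComputeKernelInfo desc;
//         desc.n0 = 4;            // fixed block width on the OFM axis
//         if(ofm > 16)
//         {
//             desc.m0 = m0_wide;  // 2 for F32, 4 for F16 and U8
//         }
//         desc.k0 = k0;           // 8 for F32/F16, 16 for the 8-bit path
//         desc.export_weights_to_cl_image = false; // no cl_image weights path on G71
//         return desc;
//     }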
if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 4; + } + + desc.k0 = 16; + + desc.export_weights_to_cl_image = false; + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_default_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 2; + } + + desc.k0 = 8; + + desc.export_weights_to_cl_image = export_weights_to_cl_image(wei); + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigBifrost::configure_default_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 4; + } + + desc.k0 = 8; + + desc.export_weights_to_cl_image = export_weights_to_cl_image(wei); + } + + return desc; +} +} // namespace opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h new file mode 100644 index 0000000000..1e4cb66ec0 --- /dev/null +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#ifndef ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_BIFROST_H +#define ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_BIFROST_H + +#include "src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h" + +namespace arm_compute +{ +namespace cl_direct_conv +{ +/** Bifrost based OpenCL direct convolution configuration */ +class ClDirectConvDefaultConfigBifrost final : public IClDirectConvKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClDirectConvDefaultConfigBifrost(GPUTarget gpu); + + // Inherited overridden method + DirectConvComputeKernelInfo configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) override; + +private: + DirectConvComputeKernelInfo configure_G71_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G71_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G71_u8(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_default_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_default_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); +}; +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_BIFROST_H */ diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp new file mode 100644 index 0000000000..d87cada159 --- /dev/null +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.cpp @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h" + +#include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/GPUTarget.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/TensorShape.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include + +namespace arm_compute +{ +namespace cl_direct_conv +{ +using namespace arm_compute::misc::shape_calculator; + +ClDirectConvDefaultConfigValhall::ClDirectConvDefaultConfigValhall(GPUTarget gpu) + : IClDirectConvKernelConfig(gpu) +{ +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + using ConfigurationFunctionExecutorPtr = DirectConvComputeKernelInfo (ClDirectConvDefaultConfigValhall::*)(const ITensorInfo * src, const ITensorInfo * wei, const PadStrideInfo & conv_info); + + ClDirectConvConfigArray configs_G78(&ClDirectConvDefaultConfigValhall::configure_G78_f32, + &ClDirectConvDefaultConfigValhall::configure_G78_f16, + &ClDirectConvDefaultConfigValhall::configure_G78_u8); + + ClDirectConvConfigArray configs_G57(&ClDirectConvDefaultConfigValhall::configure_G57_f32, + &ClDirectConvDefaultConfigValhall::configure_G57_f16, + &ClDirectConvDefaultConfigValhall::configure_G78_u8); + + ConfigurationFunctionExecutorPtr func = nullptr; + switch(_target) + { + case GPUTarget::G57: + func = configs_G57.get_function(src->data_type()); + break; + case GPUTarget::G78: + default: + func = configs_G78.get_function(src->data_type()); + break; + } + + ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not supported for direct convolution"); + return (this->*func)(src, wei, conv_info); +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G78_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_to_cl_image = export_weights_to_cl_image(wei); + + const int32_t ofm = dst_shape[0]; + const int32_t m = dst_shape[1] * dst_shape[2]; + const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; + + desc.export_weights_to_cl_image = export_to_cl_image; + + if(dst_shape[0] <= 4) + { + if(is_pointwise) + { + if(ofm == 4) + { + desc.m0 = 1; + desc.n0 = 4; + desc.k0 = 16; + } + else + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + } + else + { + desc.m0 = 1; + desc.n0 = 2; + desc.k0 = 16; + } + } + else + { + if(m < 64) + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + desc.m0 = 4; + desc.n0 = 4; + desc.k0 = 4; + } + } + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G78_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_to_cl_image = export_weights_to_cl_image(wei); + + const int32_t ofm = dst_shape[0]; + const int32_t m = dst_shape[1] * dst_shape[2]; + const bool is_pointwise 
= (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; + + desc.export_weights_to_cl_image = export_to_cl_image; + + if(dst_shape[0] <= 4) + { + if(is_pointwise) + { + if(ofm == 4) + { + desc.m0 = 1; + desc.n0 = 4; + desc.k0 = 16; + } + else + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + } + else + { + desc.m0 = 1; + desc.n0 = dst_shape[0]; + desc.k0 = 16; + } + } + else + { + if(m < 64) + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + if(ofm > 16) + { + desc.m0 = 4; + desc.n0 = 4; + desc.k0 = 8; + } + else + { + desc.m0 = 4; + desc.n0 = 4; + desc.k0 = 16; + } + } + } + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G78_u8(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + + desc.n0 = 4; + + if(output_shape[0] > 16) + { + desc.m0 = 4; + } + + desc.k0 = 16; + + desc.export_weights_to_cl_image = false; + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G57_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_to_cl_image = export_weights_to_cl_image(wei); + + const int32_t m = dst_shape[1] * dst_shape[2]; + const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; + + desc.export_weights_to_cl_image = export_to_cl_image; + + if(dst_shape[0] <= 4) + { + if(is_pointwise) + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + desc.m0 = 1; + desc.n0 = dst_shape[0]; + desc.k0 = 16; + } + } + else + { + if(m < 64) + { + if(m == 1) + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + desc.m0 = 4; + desc.n0 = 2; + desc.k0 = 8; + } + } + else + { + desc.m0 = 4; + desc.n0 = 4; + desc.k0 = 4; + } + } + } + + return desc; +} + +DirectConvComputeKernelInfo ClDirectConvDefaultConfigValhall::configure_G57_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) +{ + DirectConvComputeKernelInfo desc; + + if(src->data_layout() == DataLayout::NHWC) + { + // Get the output shape + const TensorShape wei_shape = wei->tensor_shape(); + const TensorShape dst_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *wei, conv_info); + const bool export_to_cl_image = export_weights_to_cl_image(wei); + + const int32_t ofm = dst_shape[0]; + const int32_t m = dst_shape[1] * dst_shape[2]; + const bool is_pointwise = (wei_shape[1] == wei_shape[2]) && wei_shape[1] == 1; + + desc.export_weights_to_cl_image = export_to_cl_image; + + if(dst_shape[0] <= 4) + { + if(is_pointwise) + { + desc.m0 = 2; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + desc.m0 = 1; + desc.n0 = dst_shape[0]; + desc.k0 = 16; + } + } + else + { + if(m < 64) + { + if(m == 1) + { + desc.m0 = 1; + desc.n0 = 1; + desc.k0 = 16; + } + else + { + desc.m0 = 4; + desc.n0 = 2; + desc.k0 = 8; + } + } + else + { + if(ofm > 16) + { + desc.m0 = 4; + desc.n0 = 8; + desc.k0 = 8; + } + else + { + desc.m0 = 8; + desc.n0 = 4; + desc.k0 = 4; + } + } + } + } + + return desc; +} +} // namespace 
opencl +} // namespace arm_compute diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h new file mode 100644 index 0000000000..2c65b88846 --- /dev/null +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_VALHALL_H +#define ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_VALHALL_H + +#include "src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h" + +namespace arm_compute +{ +namespace cl_direct_conv +{ +/** Valhall based OpenCL direct convolution configuration */ +class ClDirectConvDefaultConfigValhall final : public IClDirectConvKernelConfig +{ +public: + /** Constructor + * + * @param[in] gpu GPU target + */ + ClDirectConvDefaultConfigValhall(GPUTarget gpu); + + // Inherited overridden method + DirectConvComputeKernelInfo configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) override; + +private: + DirectConvComputeKernelInfo configure_G78_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G78_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G78_u8(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G57_f32(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); + DirectConvComputeKernelInfo configure_G57_f16(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info); +}; +} // namespace opencl +} // namespace arm_compute +#endif /* ARM_COMPUTE_CL_DIRECT_CONV_DEFAULT_CONFIG_VALHALL_H */ diff --git a/src/gpu/cl/kernels/direct_conv/ClDirectConvKernelConfig.h b/src/gpu/cl/kernels/direct_conv/ClDirectConvKernelConfig.h new file mode 100644 index 0000000000..c1c2e439c6 --- /dev/null +++ b/src/gpu/cl/kernels/direct_conv/ClDirectConvKernelConfig.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2022 Arm Limited. 
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CL_DIRECT_CONV_KERNEL_CONFIGURATION_H
+#define ARM_COMPUTE_CL_DIRECT_CONV_KERNEL_CONFIGURATION_H
+
+#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h"
+#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h"
+#include "src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cl_direct_conv
+{
+/** ClDirectConvolution factory class */
+class ClDirectConvKernelConfigurationFactory final
+{
+public:
+    /** Static method that instantiates the direct convolution kernel configuration class for the given GPU target
+     *
+     * @param[in] gpu GPU target
+     *
+     * @return a unique pointer to an IClDirectConvKernelConfig instance
+     */
+    static std::unique_ptr<IClDirectConvKernelConfig> create(GPUTarget gpu)
+    {
+        switch(get_arch_from_target(gpu))
+        {
+            case GPUTarget::MIDGARD:
+                return std::make_unique<ClDirectConvDefaultConfigBifrost>(GPUTarget::G71);
+            case GPUTarget::BIFROST:
+                return std::make_unique<ClDirectConvDefaultConfigBifrost>(gpu);
+            case GPUTarget::VALHALL:
+                return std::make_unique<ClDirectConvDefaultConfigValhall>(gpu);
+            default:
+                ARM_COMPUTE_ERROR("Not supported GPU target");
+        }
+    }
+};
+} // namespace cl_direct_conv
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CL_DIRECT_CONV_KERNEL_CONFIGURATION_H */
diff --git a/src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h b/src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h
new file mode 100644
index 0000000000..837fa35341
--- /dev/null
+++ b/src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_ICL_DIRECT_CONV_KERNEL_CONFIG_H
+#define ARM_COMPUTE_ICL_DIRECT_CONV_KERNEL_CONFIG_H
+
+#include "arm_compute/core/GPUTarget.h"
+#include "arm_compute/core/KernelDescriptors.h"
+#include "arm_compute/core/Types.h"
+#include "src/core/common/Macros.h"
+
+namespace arm_compute
+{
+namespace cl_direct_conv
+{
+/** Basic container for the OpenCL direct convolution configuration functions */
+template <typename T>
+class ClDirectConvConfigArray
+{
+public:
+    /** Alias for F32 index */
+    static constexpr size_t DT_F32 = 0;
+    /** Alias for F16 index */
+    static constexpr size_t DT_F16 = 1;
+    /** Alias for Int8 index */
+    static constexpr size_t DT_INT8 = 2;
+
+    /** Constructor
+     *
+     * @param[in] func_f32  Function to call for direct convolution F32
+     * @param[in] func_f16  Function to call for direct convolution F16
+     * @param[in] func_int8 Function to call for direct convolution Int8 (QASYMM8, QASYMM8_SIGNED, QSYMM8_PER_CHANNEL)
+     *
+     */
+    ClDirectConvConfigArray(T func_f32, T func_f16, T func_int8)
+        : _configs{ func_f32, func_f16, func_int8 }
+    {
+    }
+
+    /** Method to return the direct convolution configuration function based on data type
+     *
+     * @param[in] data_type Input data type
+     *
+     * @return the configuration function for the given data type, or nullptr if the data type is not supported
+     */
+    T get_function(DataType data_type)
+    {
+        switch(data_type)
+        {
+            case DataType::F32:
+                return _configs.at(DT_F32);
+            case DataType::F16:
+                return _configs.at(DT_F16);
+            case DataType::QASYMM8:
+            case DataType::QASYMM8_SIGNED:
+            case DataType::QSYMM8_PER_CHANNEL:
+                return _configs.at(DT_INT8);
+            default:
+                return nullptr;
+        }
+    }
+
+private:
+    std::array<T, 3> _configs;
+};
+
+/** Basic interface for the direct convolution kernel configuration */
+class IClDirectConvKernelConfig
+{
+public:
+    /** Constructor
+     *
+     * @param[in] arch GPU target
+     */
+    IClDirectConvKernelConfig(GPUTarget arch)
+        : _target(arch)
+    {
+    }
+    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(IClDirectConvKernelConfig);
+    /** Virtual destructor */
+    virtual ~IClDirectConvKernelConfig() = default;
+    /** This method returns the @ref DirectConvComputeKernelInfo for the given inputs
+     *
+     * @param[in] src       Source tensor (activation tensor)
+     * @param[in] wei       Weights tensor
+     * @param[in] conv_info Convolution info
+     *
+     * @return a DirectConvComputeKernelInfo configured for the given inputs
+     */
+    virtual DirectConvComputeKernelInfo configure(const ITensorInfo *src, const ITensorInfo *wei, const PadStrideInfo &conv_info) = 0;
+
+protected:
+    GPUTarget _target;
+};
+} // namespace cl_direct_conv
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_ICL_DIRECT_CONV_KERNEL_CONFIG_H */
diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
index 1bf27ba277..67da06102d 100644
--- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
+++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
* * SPDX-License-Identifier: MIT * @@ -110,6 +110,23 @@ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, return Status{}; } + +bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, + const DataType data_type, unsigned int &best_m0, unsigned int &best_n0) +{ + ARM_COMPUTE_UNUSED(n, k, b, data_type); + + const unsigned int mmul_k0 = 4; + best_m0 = 4; + best_n0 = 4; + + const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(m, best_m0); + const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / best_m0; + const unsigned int ceil_to_multiple_m_div_m0_mmul_k0 = ceil_to_multiple(m_div_m0, mmul_k0); + const unsigned int gws_y = ceil_to_multiple_m_div_m0_mmul_k0 / mmul_k0; + + return ((k % mmul_k0) == 0) && (gws_y > 4); +} } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h index 3fce8c9173..bf1e8fce82 100644 --- a/src/gpu/cl/kernels/gemm/ClGemmHelpers.h +++ b/src/gpu/cl/kernels/gemm/ClGemmHelpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -88,6 +88,21 @@ void update_padding_for_cl_image(ITensorInfo *tensor); * @return Status reporting if we can use the image2d OpenCL object on the RHS reshaped matrix */ Status validate_image2d_support_on_rhs(const ITensorInfo &tensor_reshaped_info, const GEMMRHSMatrixInfo &rhs_info); + +/** Determine if the MMUL kernels should be preferred + * + * @param[in] m Number of rows of the LHS matrix + * @param[in] n Number of columns of the RHS matrix + * @param[in] k Number of columns of the LHS matrix, rows of the RHS matrix + * @param[in] b Batch size + * @param[in] data_type Data type FP32/FP16 + * @param[in, out] best_m0 Suggested M0 (number of rows of the output block) for the kernel + * @param[in, out] best_n0 Suggested N0 (number of columns of the output block) for the kernel + * + * @return true if MMUL kernel is preferred over kernels w/o MMUL, false otherwise + */ +bool is_mmul_kernel_preferred(const unsigned int m, const unsigned int n, const unsigned int k, const unsigned int b, + const DataType data_type, unsigned int &best_m0, unsigned int &best_n0); } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp index a82084a8df..97762980be 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -29,7 +29,9 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" + #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h" +#include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h" #include @@ -61,6 +63,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16, &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + CLGEMMConfigArray configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16, + &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8); + ConfigurationFunctionExecutorPtr func = nullptr; switch(_target) @@ -68,6 +74,10 @@ std::pair ClGemmDefaultConfigReshapedRhsOn case GPUTarget::G78: func = configs_G78.get_function(data_type); break; + case GPUTarget::G715: + case GPUTarget::G615: + func = configs_G715.get_function(data_type); + break; case GPUTarget::G77: default: func = configs_G77.get_function(data_type); @@ -564,6 +574,36 @@ std::pair ClGemmDefaultConfigReshapedRhsOn } } } + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G77_f32(m, n, k, b); + } +} + +std::pair ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b) +{ + unsigned int best_m0; + unsigned int best_n0; + + if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) + { + return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true); + } + else + { + return configure_G78_f16(m, n, k, b); + } +} } // namespace gemm } // namespace kernels } // namespace opencl diff --git a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h index c5e80a7ddc..0ec068fffd 100644 --- a/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h +++ b/src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -53,6 +53,8 @@ class ClGemmDefaultConfigReshapedRhsOnlyValhall final : public IClGemmKernelConf std::pair configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); std::pair configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b); + std::pair configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b); }; } // namespace gemm } // namespace kernels diff --git a/src/gpu/cl/operators/ClConv2d.cpp b/src/gpu/cl/operators/ClConv2d.cpp index 23c1b8af9a..8119fc8e3d 100644 --- a/src/gpu/cl/operators/ClConv2d.cpp +++ b/src/gpu/cl/operators/ClConv2d.cpp @@ -261,33 +261,49 @@ ConvolutionMethod ClConv2d::get_convolution_method(const ITensorInfo *src, const const bool is_ifm_ge_16 = src->dimension(idx_c) >= 16; const bool is_ofm_lte_8 = weights->dimension(3U) <= 8; const bool workload_gte_8192 = (output_shape[0] * output_shape[1] * output_shape[2]) / 16 >= 8192; - const bool is_ifm_gt_ofm = src->dimension(idx_c) > weights->dimension(3U); + const bool is_ifm_gt_ofm = src->dimension(idx_c) > weights->dimension(3U); + const bool is_m_one = output_shape[1] * output_shape[2] == 1; // Run Winograd if valid and IFM >= 16 if(is_wino_valid && is_ifm_ge_16) { - return ConvolutionMethod::WINOGRAD; + if(is_ofm_lte_8) + { + if(gpu_target == arm_compute::GPUTarget::G71 || gpu_target == arm_compute::GPUTarget::G72 || get_arch_from_target(gpu_target) == arm_compute::GPUTarget::MIDGARD) + { + return ConvolutionMethod::WINOGRAD; + } + } + else + { + return ConvolutionMethod::WINOGRAD; + } } // Direct convolution case if(is_direct_valid) { - if((gpu_target == arm_compute::GPUTarget::G71 || - gpu_target == arm_compute::GPUTarget::G72 || - gpu_target == arm_compute::GPUTarget::MIDGARD)) + if((gpu_target == arm_compute::GPUTarget::G71 || gpu_target == arm_compute::GPUTarget::G72 || get_arch_from_target(gpu_target) == arm_compute::GPUTarget::MIDGARD)) { if(is_large_kernel_sz && is_ifm_ge_16 && is_ifm_gt_ofm) { return ConvolutionMethod::DIRECT; } } - else + else if(gpu_target == arm_compute::GPUTarget::G76) { if((is_large_kernel_sz && workload_gte_8192 && is_ifm_ge_16) || (is_ofm_lte_8 && is_ifm_ge_16)) { return ConvolutionMethod::DIRECT; } } + else + { + if( ((is_large_kernel_sz || is_m_one) && workload_gte_8192) || is_ofm_lte_8 ) + { + return ConvolutionMethod::DIRECT; + } + } } // Default case diff --git a/src/gpu/cl/operators/ClDirectConv2d.cpp b/src/gpu/cl/operators/ClDirectConv2d.cpp index 53de6fc403..ded275dbae 100644 --- a/src/gpu/cl/operators/ClDirectConv2d.cpp +++ b/src/gpu/cl/operators/ClDirectConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -23,13 +23,22 @@ */ #include "src/gpu/cl/operators/ClDirectConv2d.h" +#include "arm_compute/core/KernelDescriptors.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" #include "src/core/CL/kernels/CLFillBorderKernel.h" +#include "src/core/helpers/AutoConfiguration.h" #include "src/gpu/cl/kernels/ClActivationKernel.h" #include "src/gpu/cl/kernels/ClDirectConv2dKernel.h" +#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigBifrost.h" +#include "src/gpu/cl/kernels/direct_conv/ClDirectConvDefaultConfigValhall.h" +#include "src/gpu/cl/kernels/direct_conv/ClDirectConvKernelConfig.h" +#include "src/gpu/cl/kernels/direct_conv/IClDirectConvKernelConfig.h" #include "src/common/utils/Log.h" +using namespace arm_compute::cl_direct_conv; + namespace arm_compute { namespace opencl @@ -43,6 +52,17 @@ ITensorPack select_activation_src_dst(ITensorPack &tensors) pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(TensorType::ACL_DST)); return pack; } + +DirectConvComputeKernelInfo config_direct_convolution_nhwc(const ITensorInfo *src, const ITensorInfo *weights, const PadStrideInfo &conv_info) +{ + // Get GPU target + GPUTarget gpu_target = CLScheduler::get().target(); + + std::unique_ptr t = ClDirectConvKernelConfigurationFactory::create(gpu_target); + + return t->configure(src, weights, conv_info); +} + } // namespace void ClDirectConv2d::configure(const CLCompileContext &compile_context, ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst, @@ -51,11 +71,14 @@ void ClDirectConv2d::configure(const CLCompileContext &compile_context, ITensorI ARM_COMPUTE_ERROR_ON_NULLPTR(src); ARM_COMPUTE_LOG_PARAMS(src, weights, biases, dst, conv_info, act_info); + // Initialize the direct convolution descriptor + const DirectConvComputeKernelInfo desc = config_direct_convolution_nhwc(src, weights, conv_info); + // Configure direct convolution kernel const ActivationLayerInfo conv2d_act_info = (src->data_layout() == DataLayout::NHWC && is_data_type_float(src->data_type())) ? 
act_info : ActivationLayerInfo(); auto k = std::make_unique(); k->set_target(CLScheduler::get().target()); - k->configure(compile_context, src, weights, biases, dst, conv_info, conv2d_act_info); + k->configure(compile_context, src, weights, biases, dst, conv_info, conv2d_act_info, desc); _direct_conv_kernel = std::move(k); // Configure border handler @@ -83,7 +106,10 @@ void ClDirectConv2d::configure(const CLCompileContext &compile_context, ITensorI Status ClDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info) { - ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv2dKernel::validate(src, weights, biases, dst, conv_info, ActivationLayerInfo())); + // Initialize the direct convolution descriptor + const DirectConvComputeKernelInfo desc = config_direct_convolution_nhwc(src, weights, conv_info); + + ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClDirectConv2dKernel::validate(src, weights, biases, dst, conv_info, ActivationLayerInfo(), desc)); if(act_info.enabled()) { ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClActivationKernel::validate(dst, dst, act_info)); diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp index 88f6b79b56..4db39a635d 100644 --- a/src/gpu/cl/operators/ClGemm.cpp +++ b/src/gpu/cl/operators/ClGemm.cpp @@ -191,6 +191,7 @@ ClGemm::ClGemm() _mm_native_kernel(std::make_unique()), _mm_reshaped_kernel(std::make_unique()), _mm_reshaped_only_rhs_kernel(std::make_unique()), + _mm_reshaped_only_rhs_mmul_kernel(std::make_unique()), _tmp_a(), _tmp_b(), _reshape_b_only_on_first_run(false), @@ -324,6 +325,53 @@ void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); } +void ClGemm::configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, + const GEMMInfo &gemm_info) +{ + DataType data_type = a->data_type(); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); + const unsigned int n = b->dimension(0); + const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); + const GPUTarget gpu_target = CLScheduler::get().target(); + bool broadcast_bias = gemm_info.broadcast_bias(); + + GEMMKernelInfo kernel_info; + kernel_info.m = m; + kernel_info.n = n; + kernel_info.k = k; + kernel_info.depth_output_gemm3d = depth_output_gemm3d; + kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); + + // Set the target for the kernels + _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target); + + GEMMLHSMatrixInfo lhs_info{}; + GEMMRHSMatrixInfo rhs_info{}; + + // Pick up the GEMM configuration + auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; + // Force H0 to 4 in order to use the MMUL extension + rhs_info.h0 = 4; + + // Reshape Rhs matrix + _reshape_rhs_kernel->configure(compile_context, b, &_tmp_b, rhs_info); + + // Configure matrix multiply kernel with no y padding support + kernel_info.has_pad_y = false; + _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, &_tmp_b, c, output, alpha, beta, lhs_info, rhs_info, kernel_info); + + // Request memory for RHS reshape matrix + _aux_mem[RhsReshape] = MemoryInfo(offset_int_vec(RhsReshape), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _tmp_b.total_size()); +} + Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); @@ -458,6 +506,54 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf return Status{}; } +Status ClGemm::validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha); + ARM_COMPUTE_UNUSED(output); + TensorInfo tmp_b_info{}; + + // Get the GPU target + const GPUTarget gpu_target = CLScheduler::get().target(); + const DataType data_type = a->data_type(); + bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d(); + const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1); + const unsigned int n = b->dimension(0); + const unsigned int k = a->dimension(0); + const unsigned int batch_size = reinterpret_input_as_3d ? 
a->dimension(3) : a->dimension(2); + const int depth_output_gemm3d = gemm_info.depth_output_gemm3d(); + const bool broadcast_bias = gemm_info.broadcast_bias(); + + GEMMKernelInfo kernel_info; + kernel_info.m = m; + kernel_info.n = n; + kernel_info.k = k; + kernel_info.depth_output_gemm3d = depth_output_gemm3d; + kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = gemm_info.activation_info(); + kernel_info.post_ops = gemm_info.post_ops(); + + GEMMLHSMatrixInfo lhs_info; + GEMMRHSMatrixInfo rhs_info; + + // Pick up the GEMM configuration + // NOTE: No need to validate mlgo configurations as they automatically fall back to default heuristics if validation fails + const auto gemm_config = select_default_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery{ gpu_target, data_type, m, n, k, batch_size }); + lhs_info = gemm_config.lhs_info; + rhs_info = gemm_config.rhs_info; + // Force H0 to 4 in order to use the MMUL extension + rhs_info.h0 = 4; + + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); + ARM_COMPUTE_RETURN_ON_ERROR(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info)); + + // Validate matrix multiply + kernel_info.has_pad_y = false; + ARM_COMPUTE_RETURN_ON_ERROR(ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, c, output, alpha, beta, lhs_info, rhs_info, kernel_info)); + + return Status{}; +} + void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); @@ -501,6 +597,11 @@ void ClGemm::configure(const CLCompileContext &compile_context, ITensorInfo *a, configure_reshaped_only_rhs(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + configure_reshaped_only_rhs_mmul(compile_context, a, b, c_to_use, output, alpha, beta, gemm_info); + break; + } default: { ARM_COMPUTE_ERROR("GEMMType not supported"); @@ -545,6 +646,11 @@ Status ClGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs(a, b, c_to_use, output, alpha, beta, gemm_info)); break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + ARM_COMPUTE_RETURN_ON_ERROR(validate_reshaped_only_rhs_mmul(a, b, c_to_use, output, alpha, beta, gemm_info)); + break; + } default: { ARM_COMPUTE_RETURN_ERROR_MSG("GEMMType not supported"); @@ -627,6 +733,34 @@ void ClGemm::run(ITensorPack &tensors) } break; } + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: + { + if(!_reshape_b_only_on_first_run) + { + // Run transpose kernel + ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } }; + CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false); + } + // In case of RESHAPED_ONLY_RHS, we need to check the padding requirement + // Check if the lhs or dst tensors have padding + const unsigned int cross_plane_pad_lhs = lhs->info()->padding().top + lhs->info()->padding().bottom; + const unsigned int cross_plane_pad_dst = dst->info()->padding().top + dst->info()->padding().bottom; + bool has_pad_y = (cross_plane_pad_lhs != 0) || (cross_plane_pad_dst != 0); + + // Copy original tensor pack and overwrite rhs with reshaped counterpart + ITensorPack gemm_reshaped_onlyrhs_pack(tensors); + 
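// A hedged sketch of the padding rule enforced just below: the MMUL matrix multiply
// kernel is built with has_pad_y == false and, unlike the plain RESHAPED_ONLY_RHS
// path, has no padded variant to fall back to, so cross-plane (top/bottom) padding
// on lhs or dst is treated as a hard error. Illustration with a hypothetical helper,
// not library code:
//
//     bool has_cross_plane_padding(const ITensorInfo &info)
//     {
//         const PaddingSize pad = info.padding();
//         return (pad.top + pad.bottom) != 0; // only Y padding matters here
//     }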
gemm_reshaped_onlyrhs_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); + + if(has_pad_y) + { + ARM_COMPUTE_ERROR_ON(has_pad_y); + } + else + { + CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_onlyrhs_pack, true); + } + break; + } default: { ARM_COMPUTE_ERROR("GEMMType not supported"); diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h index 3c0cad3ca4..aac463f0b8 100644 --- a/src/gpu/cl/operators/ClGemm.h +++ b/src/gpu/cl/operators/ClGemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,6 +34,7 @@ #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyNativeKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h" #include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" #include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" @@ -50,6 +51,7 @@ namespace opencl * -# @ref kernels::ClGemmMatrixMultiplyNativeKernel (only if NATIVE is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedKernel (only if RESHAPED is selected by the select_gemm_kernel method()) * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsKernel (only if RESHAPED_ONLY_RHS is selected by the select_gemm_kernel method()) + * -# @ref kernels::ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel (only if RESHAPED_ONLY_RHS_MMUL is selected by the select_gemm_kernel method()) */ class ClGemm : public IClOperator { @@ -102,10 +104,12 @@ class ClGemm : public IClOperator void configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); void configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + void configure_reshaped_only_rhs_mmul(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_native(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); static Status validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); + static Status validate_reshaped_only_rhs_mmul(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info); private: enum AuxTensorIdx @@ -116,17 +120,18 @@ class ClGemm : public IClOperator }; private: - std::unique_ptr _reshape_lhs_kernel; - std::unique_ptr _reshape_rhs_kernel; - std::unique_ptr _mm_native_kernel; - std::unique_ptr _mm_reshaped_kernel; - std::unique_ptr _mm_reshaped_only_rhs_kernel; - TensorInfo _tmp_a; - TensorInfo 
_tmp_b; - bool _reshape_b_only_on_first_run; - CLGEMMKernelType _gemm_kernel_type; - bool _is_prepared; - experimental::MemoryRequirements _aux_mem{}; + std::unique_ptr _reshape_lhs_kernel; + std::unique_ptr _reshape_rhs_kernel; + std::unique_ptr _mm_native_kernel; + std::unique_ptr _mm_reshaped_kernel; + std::unique_ptr _mm_reshaped_only_rhs_kernel; + std::unique_ptr _mm_reshaped_only_rhs_mmul_kernel; + TensorInfo _tmp_a; + TensorInfo _tmp_b; + bool _reshape_b_only_on_first_run; + CLGEMMKernelType _gemm_kernel_type; + bool _is_prepared; + experimental::MemoryRequirements _aux_mem{}; }; } // namespace opencl } // namespace arm_compute diff --git a/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.cpp b/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.cpp index 7a62186677..2622274587 100644 --- a/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.cpp +++ b/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,23 +23,15 @@ */ #include "src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h" -#include "arm_compute/core/CL/ICLTensor.h" -#include "arm_compute/core/Error.h" -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/Log.h" -#include "arm_compute/core/TensorInfo.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/core/utils/quantization/AsymmHelpers.h" -#include "arm_compute/runtime/CL/CLScheduler.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/MemoryHelpers.h" #include "src/gpu/cl/kernels/ClCastKernel.h" #include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.h" #include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h" #include "src/gpu/cl/kernels/ClGemmLowpOffsetContributionKernel.h" #include "src/gpu/cl/kernels/ClGemmLowpOffsetContributionOutputStageKernel.h" #include "src/gpu/cl/kernels/ClGemmLowpReductionKernel.h" @@ -47,9 +39,6 @@ #include "src/gpu/cl/utils/ClAuxTensorHandler.h" #include "src/runtime/CL/gemm_auto_heuristics/CLGEMMAutoHeuristics.h" -#include "src/common/utils/Log.h" -#include "utils/TypePrinter.h" - namespace arm_compute { namespace opencl @@ -67,6 +56,7 @@ inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type) { case CLGEMMKernelType::NATIVE: case CLGEMMKernelType::RESHAPED_ONLY_RHS: + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: { return true; } @@ -165,6 +155,41 @@ inline bool validate_lhs_rhs_info_reshaped_only_rhs(const GEMMLHSMatrixInfo &lhs return true; } +// Validate lhs_info and rhs_info for reshaped only rhs kernel +inline bool validate_lhs_rhs_info_reshaped_only_rhs_mmul(const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *output, + unsigned int m, unsigned int n, unsigned int k, bool reinterpret_input_as_3d, int depth_output_gemm3d) +{ + // Validate GEMMLHSMatrixInfo and GEMMRHSMatrixInfo for reshaped only rhs kernel + TensorInfo tmp_b_info{}; + // Validate reshape RHS kernel + auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_rhs_reshaped_shape(*b, rhs_info))); + if(!bool(ClGemmReshapeRhsMatrixKernel::validate(b, &tmp_b_info, rhs_info))) + { + return false; + } + // Validate mm kernel + // NOTE: Ignore 
all other parameters (e.g. depth_output_gemm3d, output stage, etc.) and only validate lhs and rhs info + // NOTE: This assumes: + // 1. lhs and rhs info's validity does not depend on these other parameters and vice versa (in ClGemmLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp validate_arguments). + // 2. lhs and rhs info does not cause window and padding issues through side effects (in ClGemmLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp validate_and_configure_window). + GEMMKernelInfo gemm_kernel_info; + gemm_kernel_info.m = m; + gemm_kernel_info.n = n; + gemm_kernel_info.k = k; + gemm_kernel_info.reinterpret_input_as_3d = reinterpret_input_as_3d; + gemm_kernel_info.depth_output_gemm3d = depth_output_gemm3d; + gemm_kernel_info.lhs_info = lhs_info; + gemm_kernel_info.rhs_info = rhs_info; + // Since we ignore the output stage, output data type has to be S32 to pass the validation + TensorInfo output_info_copy(*output); + output_info_copy.set_data_type(DataType::S32); + if(!bool(ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel::validate(a, &tmp_b_info, &output_info_copy, gemm_kernel_info))) + { + return false; + } + return true; +} + // Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs kernel configs std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs(auto_heuristics::CommonQuery query, bool reinterpret_input_as_3d, int depth_output_gemm3d, const ITensorInfo *a, @@ -184,6 +209,19 @@ std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped return { config.lhs_info, config.rhs_info }; } +// Automatically select between mlgo (prioritized) and default heuristics for reshaped only rhs mmul kernel configs +std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> auto_select_gemm_config_reshaped_only_rhs_mmul(auto_heuristics::CommonQuery query, bool reinterpret_input_as_3d, int depth_output_gemm3d, + const ITensorInfo *a, + const ITensorInfo *b, const ITensorInfo *output) +{ + ARM_COMPUTE_UNUSED(a, b, output, reinterpret_input_as_3d, depth_output_gemm3d); + auto config = auto_heuristics::select_default_gemm_config_reshaped_only_rhs(query); + validate_lhs_rhs_info_reshaped_only_rhs_mmul(config.lhs_info, config.rhs_info, a, b, output, query.m, query.n, query.k, reinterpret_input_as_3d, depth_output_gemm3d); + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Use reshaped_only_rhs_mmul config from default heuristics: LHS info: %s ; RHS info: %s ", to_string(config.lhs_info).c_str(), + to_string(config.rhs_info).c_str()); + return { config.lhs_info, config.rhs_info }; +} + inline bool is_gemm_reshaped(CLGEMMKernelType kernel_type) { switch(kernel_type) @@ -191,6 +229,7 @@ inline bool is_gemm_reshaped(CLGEMMKernelType kernel_type) case CLGEMMKernelType::NATIVE: return false; case CLGEMMKernelType::RESHAPED_ONLY_RHS: + case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: return true; default: ARM_COMPUTE_ERROR("Not supported gemmlowp kernel!"); @@ -202,6 +241,7 @@ ClGemmLowpMatrixMultiplyCore::ClGemmLowpMatrixMultiplyCore() : _weights_to_qasymm8(std::make_unique<kernels::ClCastKernel>()), _mm_native_kernel(std::make_unique<kernels::ClGemmLowpMatrixMultiplyNativeKernel>()), _mm_reshaped_only_rhs_kernel(std::make_unique<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel>()), + _mm_reshaped_only_rhs_mmul_kernel(std::make_unique<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel>()), _mtx_b_reshape_kernel(std::make_unique<kernels::ClGemmReshapeRhsMatrixKernel>()), _mtx_a_reduction_kernel(std::make_unique<kernels::ClGemmLowpMatrixAReductionKernel>()), _mtx_b_reduction_kernel(std::make_unique<kernels::ClGemmLowpMatrixBReductionKernel>()), @@ -218,7 +258,7 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output); - ARM_COMPUTE_ERROR_THROW_ON(ClGemmLowpMatrixMultiplyCore::validate(a, b, c != nullptr ?
c : nullptr, output, gemm_info)); + ARM_COMPUTE_ERROR_THROW_ON(ClGemmLowpMatrixMultiplyCore::validate(a, b, c, output, gemm_info)); ARM_COMPUTE_LOG_PARAMS(a, b, c, output, gemm_info); _reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); @@ -234,6 +274,7 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con // Set the target for the kernels _mm_native_kernel->set_target(gpu_target); _mm_reshaped_only_rhs_kernel->set_target(gpu_target); + _mm_reshaped_only_rhs_mmul_kernel->set_target(gpu_target); GEMMRHSMatrixInfo rhs_info; GEMMLHSMatrixInfo lhs_info; @@ -249,8 +290,7 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con const auto reshape_info = GEMMReshapeInfo(m, n, k, 1, 1, depth_output_gemm3d, reinterpret_input_as_3d); - // Check if we need to reshape the matrix A and matrix B - _is_gemm_reshaped = is_gemm_reshaped(auto_select_gemm_kernel(auto_heuristics::CommonQuery{ gpu_target, a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run)); + _gemm_kernel_type = auto_select_gemm_kernel(auto_heuristics::CommonQuery{ gpu_target, a->data_type(), m, n, k, batch_size }, _reshape_b_only_on_first_run); if(_convert_to_qasymm8) { @@ -261,7 +301,7 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con } ITensorInfo *matrix_b = _convert_to_qasymm8 ? &_qasymm8_weights : b; - if(_is_gemm_reshaped) + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS) { matrix_b = &_tmp_b; @@ -274,6 +314,19 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con // Configure reshape RHS kernel _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info); } + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL) + { + matrix_b = &_tmp_b; + + // Pick up the GEMM configuration + // It doesn't matter whether the data type is DataType::QASYMM8 or DataType::QASYMM8_SIGNED, since it only affects the shape configuration + std::tie(lhs_info, rhs_info) = auto_select_gemm_config_reshaped_only_rhs_mmul(auto_heuristics::CommonQuery{ gpu_target, DataType::QASYMM8, m, n, k, batch_size }, reinterpret_input_as_3d, + depth_output_gemm3d, + a, _convert_to_qasymm8 ? &_qasymm8_weights : b, output); + + // Configure reshape RHS kernel + _mtx_b_reshape_kernel->configure(compile_context, _convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info); + } // Using default reduction info const GEMMLowpReductionKernelInfo reduction_info {}; @@ -326,20 +379,30 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con gemm_kernel_info.output_stage = gemmlowp_output_stage; - if(_is_gemm_reshaped && gemmlowp_output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS && gemmlowp_output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) { // Configure and tune matrix multiply kernel with fused output stage _mm_reshaped_only_rhs_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c != nullptr ?
c : nullptr, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts); } + else if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL && gemmlowp_output_stage.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT) + { + // Configure and tune matrix multiply kernel with fused output stage + _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info, _a_offset == 0 ? nullptr : &_vector_sum_col, + _b_offset == 0 ? nullptr : &_vector_sum_row, c, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts); + } else { _run_output_stage = true; - if(_is_gemm_reshaped) + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS) { _mm_reshaped_only_rhs_kernel->configure(compile_context, a, matrix_b, &_mm_result_s32, gemm_kernel_info); } + else if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL) + { + _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, matrix_b, &_mm_result_s32, gemm_kernel_info); + } else { // Pick up the GEMM configuration @@ -359,11 +422,16 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con else { _run_offset_contribution = true; - if(_is_gemm_reshaped) + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS) { // Configure and tune matrix multiply kernel _mm_reshaped_only_rhs_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info); } + else if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL) + { + // Configure and tune matrix multiply kernel + _mm_reshaped_only_rhs_mmul_kernel->configure(compile_context, a, matrix_b, output, gemm_kernel_info); + } else { // Pick up the GEMM configuration @@ -382,7 +450,7 @@ void ClGemmLowpMatrixMultiplyCore::configure(const CLCompileContext &compile_con // Request memory _aux_mem[RhsQAsymm8] = MemoryInfo(offset_int_vec(RhsQAsymm8), _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary, _qasymm8_weights.total_size()); - if(_is_gemm_reshaped) + if(is_gemm_reshaped(_gemm_kernel_type)) { // Overwrite Rhs as prepare if gemm is reshaped as there will be a two-step transformation _aux_mem[RhsQAsymm8] = MemoryInfo(offset_int_vec(RhsQAsymm8), _reshape_b_only_on_first_run ? MemoryLifetime::Prepare : MemoryLifetime::Temporary, _qasymm8_weights.total_size()); @@ -607,7 +675,7 @@ void ClGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) const ITensor *matrix_a = a; const ITensor *matrix_b = _convert_to_qasymm8 ?
rhs_qasymm8.get() : b; - if(_is_gemm_reshaped) + if(is_gemm_reshaped(_gemm_kernel_type)) { matrix_b = tmp_b.get(); if(!_reshape_b_only_on_first_run) @@ -645,7 +713,7 @@ void ClGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) } // Run matrix multiply - if(_is_gemm_reshaped) + if(is_gemm_reshaped(_gemm_kernel_type)) { ITensorPack gemm_reshaped_pack; if(_run_offset_contribution) @@ -669,7 +737,18 @@ void ClGemmLowpMatrixMultiplyCore::run(ITensorPack &tensors) { TensorType::ACL_DST, dst }, }); } - CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_pack, false); + if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS) + { + CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_kernel, gemm_reshaped_pack, false); + } + else if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL) + { + CLScheduler::get().enqueue_op(*_mm_reshaped_only_rhs_mmul_kernel, gemm_reshaped_pack, false); + } + else + { + ARM_COMPUTE_ERROR("Invalid reshaped kernel"); + } } else { @@ -728,7 +807,7 @@ void ClGemmLowpMatrixMultiplyCore::prepare(ITensorPack &tensors) b->mark_as_unused(); } - if(_is_gemm_reshaped && _reshape_b_only_on_first_run) + if(is_gemm_reshaped(_gemm_kernel_type) && _reshape_b_only_on_first_run) { // Run reshape kernel and mark original weights tensor as unused ITensorPack mtx_b_pack = diff --git a/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h b/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h index 1965e3f97b..6fa4352bf8 100644 --- a/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h +++ b/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -40,6 +40,7 @@ namespace kernels class ClCastKernel; class ClGemmLowpMatrixMultiplyNativeKernel; class ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel; +class ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel; class ClGemmReshapeRhsMatrixKernel; class ClGemmLowpMatrixAReductionKernel; class ClGemmLowpMatrixBReductionKernel; @@ -120,14 +121,15 @@ class ClGemmLowpMatrixMultiplyCore : public IClOperator private: // Kernels used - std::unique_ptr _weights_to_qasymm8; - std::unique_ptr _mm_native_kernel; - std::unique_ptr _mm_reshaped_only_rhs_kernel; - std::unique_ptr _mtx_b_reshape_kernel; - std::unique_ptr _mtx_a_reduction_kernel; - std::unique_ptr _mtx_b_reduction_kernel; - std::unique_ptr _offset_contribution_kernel; - std::unique_ptr _offset_contribution_output_stage_kernel; + std::unique_ptr _weights_to_qasymm8; + std::unique_ptr _mm_native_kernel; + std::unique_ptr _mm_reshaped_only_rhs_kernel; + std::unique_ptr _mm_reshaped_only_rhs_mmul_kernel; + std::unique_ptr _mtx_b_reshape_kernel; + std::unique_ptr _mtx_a_reduction_kernel; + std::unique_ptr _mtx_b_reduction_kernel; + std::unique_ptr _offset_contribution_kernel; + std::unique_ptr _offset_contribution_output_stage_kernel; // Temporary tensors TensorInfo _qasymm8_weights{}; @@ -138,15 +140,15 @@ class ClGemmLowpMatrixMultiplyCore : public IClOperator TensorInfo _gemm_output_stage_multipliers{}; TensorInfo _gemm_output_stage_shifts{}; - int32_t _a_offset{ 0 }; - int32_t _b_offset{ 0 }; - bool _is_gemm_reshaped{ true }; - bool _reshape_b_only_on_first_run{ false }; - bool _run_output_stage{ false }; - bool _convert_to_qasymm8{ false }; - bool _run_offset_contribution{ false }; - bool _is_prepared{ false }; - GEMMInfo _gemm_info{}; + int32_t _a_offset{ 0 }; + int32_t _b_offset{ 0 }; + bool _reshape_b_only_on_first_run{ false }; 
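Taken together, the hunks above replace the old `_is_gemm_reshaped` boolean with a stored `CLGEMMKernelType`: the run path can then pick the exact matrix-multiply kernel, while the prepare and memory paths only ask whether any RHS reshape happens at all. A minimal, self-contained sketch of that dispatch pattern follows (the enum values match the library; the stub kernels and `main` are purely illustrative):

```cpp
#include <cstdio>

// Stub of the kernel-type enum; the real one lives in arm_compute/runtime/CL/CLTypes.h
enum class CLGEMMKernelType { NATIVE, RESHAPED_ONLY_RHS, RESHAPED_ONLY_RHS_MMUL };

// Single source of truth for "does this kernel type consume a reshaped RHS?"
static bool is_gemm_reshaped(CLGEMMKernelType t)
{
    return t == CLGEMMKernelType::RESHAPED_ONLY_RHS || t == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
}

static void enqueue(CLGEMMKernelType t)
{
    // Exact kernel dispatch, needed only where the two reshaped variants differ
    switch(t)
    {
        case CLGEMMKernelType::RESHAPED_ONLY_RHS:      std::puts("enqueue reshaped_only_rhs kernel");      break;
        case CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL: std::puts("enqueue reshaped_only_rhs_mmul kernel"); break;
        default:                                       std::puts("enqueue native kernel");                 break;
    }
}

int main()
{
    const CLGEMMKernelType type = CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
    if(is_gemm_reshaped(type))
    {
        std::puts("run RHS reshape first"); // shared path for both reshaped variants
    }
    enqueue(type);
}
```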
+ bool _run_output_stage{ false }; + bool _convert_to_qasymm8{ false }; + bool _run_offset_contribution{ false }; + bool _is_prepared{ false }; + GEMMInfo _gemm_info{}; + CLGEMMKernelType _gemm_kernel_type{}; experimental::MemoryRequirements _aux_mem{}; }; diff --git a/src/gpu/cl/operators/ClWinogradConv2d.cpp b/src/gpu/cl/operators/ClWinogradConv2d.cpp index ffa1effc74..b4163a5986 100644 --- a/src/gpu/cl/operators/ClWinogradConv2d.cpp +++ b/src/gpu/cl/operators/ClWinogradConv2d.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -214,6 +214,7 @@ void ClWinogradConv2d::configure(const ClCompileContext &compile_context, ITenso (src->data_type() == DataType::F16))); // Configure output transform + _output_transform->set_target(CLScheduler::get().target()); _output_transform->configure(compile_context, &_batched_mm_output, biases, dst, winograd_info, act_info); _aux_mem = _batched_mm.workspace(); diff --git a/src/runtime/CL/CLScheduler.cpp b/src/runtime/CL/CLScheduler.cpp index 26124ed7e9..8d30c05361 100644 --- a/src/runtime/CL/CLScheduler.cpp +++ b/src/runtime/CL/CLScheduler.cpp @@ -137,8 +137,9 @@ void CLScheduler::default_init(ICLTuner *cl_tuner, CLGEMMHeuristicsHandle *gemm_ init(ctx, queue, dev, cl_tuner, gemm_h); } - // Set CL tuner - _cl_tuner = cl_tuner; + // Set CL tuner and GEMM heuristics + _cl_tuner = cl_tuner; + _gemm_heuristics = gemm_h; } void CLScheduler::default_reinit(ICLTuner *cl_tuner, CLGEMMHeuristicsHandle *gemm_h, CLBackendType cl_backend_type) diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index 2ee23c4262..e821726d0e 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -227,6 +227,7 @@ void CLDepthwiseConvolutionLayer::configure(const CLCompileContext &compile_cont const ConvolutionInfo conv_kernel_info{ conv_info, depth_multiplier, act_info, dilation }; + _dwc_native_kernel->set_target(gpu_target); _dwc_native_kernel->configure(compile_context, input_to_use, weights_to_use, biases, output_to_use, dwc_native_compute_info, conv_kernel_info, output_multipliers_to_use, output_shifts_to_use); diff --git a/src/runtime/CL/functions/CLGEMM.cpp b/src/runtime/CL/functions/CLGEMM.cpp index cc6689c504..427ea51ab9 100644 --- a/src/runtime/CL/functions/CLGEMM.cpp +++ b/src/runtime/CL/functions/CLGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,7 +30,6 @@ #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" -#include "arm_compute/runtime/CL/functions/CLGEMM.h" #include "src/core/helpers/MemoryHelpers.h" #include "src/gpu/cl/operators/ClGemm.h" diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp index 64271a8801..4c7daf916e 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
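A detail worth noting in the ClWinogradConv2d and CLDepthwiseConvolutionLayer fixes above: `set_target()` is now called on the kernel before `configure()`, because configuration bakes target-dependent choices into the kernel. A toy sketch of why the ordering matters (the `ExampleKernel` class and its tile-size rule are invented for illustration, not library code):

```cpp
#include <cstdio>

enum class GPUTarget { MIDGARD, G78, G715 };

// Hypothetical kernel: configure() makes target-dependent decisions once,
// so a target set after configure() would simply be ignored.
class ExampleKernel
{
public:
    void set_target(GPUTarget t) { _target = t; }
    void configure(int width)
    {
        _tile = (_target == GPUTarget::G715) ? 8 : 4; // decision frozen here
        std::printf("configured: width=%d tile=%d\n", width, _tile);
    }

private:
    GPUTarget _target{ GPUTarget::MIDGARD };
    int       _tile{ 0 };
};

int main()
{
    ExampleKernel k;
    k.set_target(GPUTarget::G715); // the patch adds exactly this call, in exactly this position
    k.configure(256);
}
```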
* * SPDX-License-Identifier: MIT * @@ -79,10 +79,28 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelec { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } }; + // Mali-G715 and Mali-G615 configurations + static std::map gemm_g715_configs = + { + { DataType::F32, &CLGEMMDefaultTypeValhall::g715_f32 }, + { DataType::F16, &CLGEMMDefaultTypeValhall::g715_f16 }, + { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 }, + { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 } + }; + const DataType data_type = params.data_type; switch(_target) { + case GPUTarget::G715: + case GPUTarget::G615: + if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end()) + { + return (this->*gemm_g715_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant); + } + ARM_COMPUTE_ERROR("Not supported data type"); case GPUTarget::G78: if(gemm_g78_configs.find(data_type) != gemm_g78_configs.end()) { @@ -306,5 +324,46 @@ CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f16(unsigned int m, unsigned int return CLGEMMKernelType::RESHAPED_ONLY_RHS; } + +CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + if(!is_rhs_constant) + { + return default_f32(m, n, k, b, is_rhs_constant); + } + + unsigned int best_m0; + unsigned int best_n0; + + if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0)) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL; + } + else + { + return default_f32(m, n, k, b, is_rhs_constant); + } +} + +CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant) +{ + if(!is_rhs_constant) + { + return g78_f16(m, n, k, b, is_rhs_constant); + } + + unsigned int best_m0; + unsigned int best_n0; + + if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0)) + { + return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL; + } + else + { + return g78_f16(m, n, k, b, is_rhs_constant); + } +} + } // namespace cl_gemm } // namespace arm_compute diff --git a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h index c88fbcf557..0893f11132 100644 --- a/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h +++ b/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
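The G715/G615 handlers above only switch to the MMUL variant when the RHS is constant and `opencl::kernels::gemm::is_mmul_kernel_preferred()` approves the problem shape; otherwise they fall back to the pre-G715 heuristics. A standalone sketch of that decision shape (the `mmul_preferred` predicate below is a made-up stand-in, not the library's heuristic):

```cpp
#include <cstdio>

enum class CLGEMMKernelType { RESHAPED_ONLY_RHS, RESHAPED_ONLY_RHS_MMUL };

// Stand-in for is_mmul_kernel_preferred(): assume MMUL wins when K is a
// multiple of the MMUL block and there is enough work to keep the unit busy.
static bool mmul_preferred(unsigned m, unsigned n, unsigned k, unsigned b)
{
    return (k % 4 == 0) && (m * n * b >= 16 * 16);
}

static CLGEMMKernelType select_g715(unsigned m, unsigned n, unsigned k, unsigned b, bool rhs_constant)
{
    if(!rhs_constant || !mmul_preferred(m, n, k, b))
    {
        return CLGEMMKernelType::RESHAPED_ONLY_RHS; // same answer as before G715
    }
    return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
}

int main()
{
    const bool use_mmul = select_g715(49, 257, 192, 2, true) == CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
    std::puts(use_mmul ? "RESHAPED_ONLY_RHS_MMUL" : "RESHAPED_ONLY_RHS");
}
```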
* * SPDX-License-Identifier: MIT * @@ -50,6 +50,8 @@ class CLGEMMDefaultTypeValhall final : public ICLGEMMKernelSelection CLGEMMKernelType g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); CLGEMMKernelType g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); + CLGEMMKernelType g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant); }; } // namespace cl_gemm } // namespace arm_compute diff --git a/src/runtime/IScheduler.cpp b/src/runtime/IScheduler.cpp index 1d068c9b38..39f41555fa 100644 --- a/src/runtime/IScheduler.cpp +++ b/src/runtime/IScheduler.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -188,7 +188,7 @@ std::size_t IScheduler::adjust_num_of_windows(const Window &window, std::size_t recommended_split_dim = dims; } } - ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("%lu dimension is not a suitable dimension to split the workload. Recommended: %lu recommended_split_dim", split_dimension, + ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("%zu dimension is not a suitable dimension to split the workload. Recommended: %zu recommended_split_dim", split_dimension, recommended_split_dim); } diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 77028d96a2..4f858fb54b 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,7 +61,7 @@ NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr mem } void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, - FullyConnectedLayerInfo fc_info) + FullyConnectedLayerInfo fc_info, const WeightsInfo &weights_info) { // Perform validate step ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -76,7 +76,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->original_weights = weights; _impl->is_prepared = false; - _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), fc_info); + _impl->op->configure(input->info(), weights->info(), (biases != nullptr) ? 
biases->info() : nullptr, output->info(), fc_info, weights_info); if(_impl->weights_manager != nullptr) { @@ -88,6 +88,13 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh _impl->workspace = manage_workspace(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->run_pack); } +Status NEFullyConnectedLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *input, const ITensorInfo *weights, + const ITensorInfo *biases, const ITensorInfo *output, const FullyConnectedLayerInfo &fc_info, + const WeightsInfo &weights_info) +{ + return cpu::CpuFullyConnected::has_opt_impl(expected_weight_format, input, weights, biases, output, fc_info, weights_info); +} + Status NEFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, FullyConnectedLayerInfo fc_info) { diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 58ade9fb3a..0266c48f86 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -84,6 +84,13 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso return cpu::CpuGemm::validate(a, b, c, output, alpha, beta, gemm_info); } +Status NEGEMM::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, + float alpha, float beta, const GEMMInfo &gemm_info) +{ + ARM_COMPUTE_UNUSED(alpha, beta); + return cpu::CpuGemm::has_opt_impl(expected_weight_format, a, b, c, output, gemm_info); +} + void NEGEMM::run() { prepare(); diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index c780d63763..fe3ea6a767 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -58,6 +58,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + _impl->weights = weights; _impl->op = std::make_unique(); _impl->op->configure(input->info(), weights->info(), (biases != nullptr ? 
biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups); @@ -79,6 +80,13 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI return cpu::CpuGemmConv2d::validate(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups); } +Status NEGEMMConvolutionLayer::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, + const PadStrideInfo &conv_info, + const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, const bool enable_fast_math) +{ + return cpu::CpuGemmConv2d::has_opt_impl(expected_weight_format, src, weights, biases, dst, conv_info, weights_info, dilation, act_info, enable_fast_math); +} + void NEGEMMConvolutionLayer::run() { prepare(); diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp index f0c153d4f4..a8eded29ff 100644 --- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -34,7 +34,6 @@ #include "src/cpu/operators/CpuWinogradConv2d.h" #include "src/core/NEON/kernels/convolution/common/utils.hpp" -#include "src/core/NEON/kernels/convolution/winograd/winograd.hpp" namespace arm_compute { diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp index e9b0bf4426..aad24b4f01 100644 --- a/src/runtime/OMP/OMPScheduler.cpp +++ b/src/runtime/OMP/OMPScheduler.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -89,20 +89,21 @@ void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const Win #ifndef DOXYGEN_SKIP_THIS void OMPScheduler::run_workloads(std::vector &workloads) { - const unsigned int num_threads = std::min(_num_threads, static_cast(workloads.size())); - if(num_threads < 1) + const unsigned int amount_of_work = static_cast(workloads.size()); + if(amount_of_work < 1 || _num_threads == 1) { return; } ThreadInfo info; info.cpu_info = &cpu_info(); - info.num_threads = num_threads; - #pragma omp parallel firstprivate(info) num_threads(num_threads) + info.num_threads = _num_threads; + #pragma omp parallel for firstprivate(info) num_threads(_num_threads) default(shared) proc_bind(close) schedule(static, 1) + for(unsigned int wid = 0; wid < amount_of_work; ++wid) { const int tid = omp_get_thread_num(); info.thread_id = tid; - workloads[tid](info); + workloads[wid](info); } } #endif /* DOXYGEN_SKIP_THIS */ diff --git a/support/Bfloat16.h b/support/Bfloat16.h index 173f2d16e2..5fd45cf209 100644 --- a/support/Bfloat16.h +++ b/support/Bfloat16.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021 Arm Limited. + * Copyright (c) 2020-2022 Arm Limited. 
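The OMPScheduler rewrite above moves from spawning one parallel region per thread and indexing `workloads[tid]`, which silently assumed exactly one workload per thread, to a `parallel for` over the workloads themselves, so any workload count is handled. A self-contained sketch of the new scheme in plain OpenMP (stand-in `Workload` type; compile with `-fopenmp`):

```cpp
#include <cstdio>
#include <functional>
#include <vector>
#include <omp.h>

using Workload = std::function<void(int /*thread_id*/)>;

// New scheme: one loop iteration per workload, statically chunked one at a
// time, so the workload count no longer has to equal the thread count.
static void run_workloads(std::vector<Workload> &workloads, unsigned num_threads)
{
    const unsigned amount_of_work = static_cast<unsigned>(workloads.size());
    if(amount_of_work < 1 || num_threads == 1)
    {
        return; // single-threaded callers run the work inline elsewhere
    }
    #pragma omp parallel for num_threads(num_threads) default(shared) schedule(static, 1)
    for(unsigned wid = 0; wid < amount_of_work; ++wid)
    {
        workloads[wid](omp_get_thread_num());
    }
}

int main()
{
    // 6 workloads over 4 threads: impossible with the old workloads[tid] indexing
    std::vector<Workload> w(6, [](int tid) { std::printf("ran on thread %d\n", tid); });
    run_workloads(w, 4);
}
```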
* * SPDX-License-Identifier: MIT * @@ -40,7 +40,7 @@ namespace inline uint16_t float_to_bf16(const float v) { const uint32_t *fromptr = reinterpret_cast(&v); -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) uint16_t res; __asm __volatile( @@ -50,7 +50,7 @@ inline uint16_t float_to_bf16(const float v) : : [fromptr] "r"(fromptr), [toptr] "r"(&res) : "v0", "memory"); -#else /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#else /* defined(ARM_COMPUTE_ENABLE_BF16) */ uint16_t res = (*fromptr >> 16); const uint16_t error = (*fromptr & 0x0000ffff); uint16_t bf_l = res & 0x0001; @@ -58,7 +58,7 @@ inline uint16_t float_to_bf16(const float v) { res += 1; } -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ return res; } diff --git a/support/ToolchainSupport.h b/support/ToolchainSupport.h index 8ea50ebe15..0557d1d775 100644 --- a/support/ToolchainSupport.h +++ b/support/ToolchainSupport.h @@ -297,7 +297,7 @@ inline bfloat16 lowest() template ::value>::type> inline bool isfinite(T value) { - return std::isfinite(value); + return std::isfinite(static_cast(value)); } inline bool isfinite(half_float::half value) @@ -310,12 +310,11 @@ inline bool isfinite(bfloat16 value) return std::isfinite(float(value)); } -#if !defined(_WIN64) // std::signbit template ::value>::type> inline bool signbit(T value) { - return std::signbit(value); + return std::signbit(static_cast(value)); } inline bool signbit(half_float::half value) @@ -327,7 +326,6 @@ inline bool signbit(bfloat16 value) { return std::signbit(float(value)); } -#endif // !defined(_WIN64) } // namespace cpp11 } // namespace support } // namespace arm_compute diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h index 80b9ecbd92..bd97cb7bd4 100644 --- a/tests/AssetsLibrary.h +++ b/tests/AssetsLibrary.h @@ -725,7 +725,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t case DataType::U8: case DataType::QASYMM8: { - std::uniform_int_distribution distribution_u8(std::numeric_limits::lowest(), std::numeric_limits::max()); + std::uniform_int_distribution distribution_u8(std::numeric_limits::lowest(), std::numeric_limits::max()); fill(tensor, distribution_u8, seed_offset); break; } @@ -734,7 +734,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t case DataType::QSYMM8_PER_CHANNEL: case DataType::QASYMM8_SIGNED: { - std::uniform_int_distribution distribution_s8(std::numeric_limits::lowest(), std::numeric_limits::max()); + std::uniform_int_distribution distribution_s8(std::numeric_limits::lowest(), std::numeric_limits::max()); fill(tensor, distribution_s8, seed_offset); break; } @@ -826,20 +826,20 @@ void AssetsLibrary::fill_tensor_uniform_ranged(T case DataType::U8: case DataType::QASYMM8: { - const auto converted_pairs = detail::convert_range_pair(excluded_range_pairs); - RangedUniformDistribution distribution_u8(std::numeric_limits::lowest(), - std::numeric_limits::max(), - converted_pairs); + const auto converted_pairs = detail::convert_range_pair(excluded_range_pairs); + RangedUniformDistribution distribution_u8(std::numeric_limits::lowest(), + std::numeric_limits::max(), + converted_pairs); fill(tensor, distribution_u8, seed_offset); break; } case DataType::S8: case DataType::QSYMM8: { - const auto converted_pairs = detail::convert_range_pair(excluded_range_pairs); - 
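The non-accelerated branch of `float_to_bf16()` above truncates to the high 16 bits of the float and then rounds to nearest-even using the discarded low half; the comparison that drives the increment sits in context lines just outside the hunk shown. A standalone version with a worked value (the function name and the test value are mine):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Round-to-nearest-even float -> bfloat16, mirroring the fallback branch:
// keep the high 16 bits, then bump by one when the discarded half is above
// the halfway point, or exactly halfway with an odd kept LSB (tie to even).
static uint16_t float_to_bf16_rne(float v)
{
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits)); // well-defined type pun
    uint16_t       res   = static_cast<uint16_t>(bits >> 16);
    const uint16_t error = static_cast<uint16_t>(bits & 0x0000ffffu);
    const uint16_t lsb   = res & 0x0001u;
    if(error > 0x8000u || (error == 0x8000u && lsb != 0))
    {
        res += 1;
    }
    return res;
}

int main()
{
    // 1.005859375f is 0x3F80C000: the discarded half (0xC000) is above the
    // halfway point, so the result rounds up to 0x3F81 instead of truncating.
    std::printf("0x%04X\n", float_to_bf16_rne(1.005859375f));
}
```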
RangedUniformDistribution distribution_s8(std::numeric_limits::lowest(), - std::numeric_limits::max(), - converted_pairs); + const auto converted_pairs = detail::convert_range_pair(excluded_range_pairs); + RangedUniformDistribution distribution_s8(std::numeric_limits::lowest(), + std::numeric_limits::max(), + converted_pairs); fill(tensor, distribution_s8, seed_offset); break; } @@ -918,7 +918,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t case DataType::QASYMM8: { ARM_COMPUTE_ERROR_ON(!(std::is_same::value)); - std::uniform_int_distribution distribution_u8(low, high); + std::uniform_int_distribution distribution_u8(low, high); fill(tensor, distribution_u8, seed_offset); break; } @@ -927,7 +927,7 @@ void AssetsLibrary::fill_tensor_uniform(T &&tensor, std::random_device::result_t case DataType::QASYMM8_SIGNED: { ARM_COMPUTE_ERROR_ON(!(std::is_same::value)); - std::uniform_int_distribution distribution_s8(low, high); + std::uniform_int_distribution distribution_s8(low, high); fill(tensor, distribution_s8, seed_offset); break; } diff --git a/tests/SConscript b/tests/SConscript index 87907f40fc..b848f27043 100644 --- a/tests/SConscript +++ b/tests/SConscript @@ -1,3 +1,4 @@ +#!/usr/bin/python # -*- coding: utf-8 -*- # Copyright (c) 2017-2022 Arm Limited. @@ -67,7 +68,8 @@ Import("arm_compute_test_framework") test_env.Append(LIBS = arm_compute_test_framework) # Disable floating-point expression contraction (e.g. fused multiply-add operations) -test_env.Append(CXXFLAGS = ['-ffp-contract=off']) +if not 'windows' in env['os']: + test_env.Append(CXXFLAGS = ['-ffp-contract=off']) # Remove -Wnoexcept from tests if 'g++' in test_env['CXX'] and '-Wnoexcept' in test_env['CXXFLAGS']: @@ -83,7 +85,10 @@ if env['os'] in ['android', 'macos', 'bare_metal'] or env['standalone']: Import("arm_compute_a") Import("arm_compute_core_a") Import("arm_compute_graph_a") - test_env.Append(LIBS = [arm_compute_graph_a, arm_compute_a, arm_compute_core_a]) + if env['os']=='windows': + test_env.Append(LIBS = [arm_compute_graph_a, arm_compute_a]) + else: + test_env.Append(LIBS = [arm_compute_graph_a, arm_compute_a, arm_compute_core_a]) arm_compute_lib = arm_compute_graph_a else: Import("arm_compute_graph_so") @@ -156,7 +161,7 @@ if env['neon']: extra_link_flags = [] if env['os'] == 'android': test_env.Append(LIBS = ["log"]) -elif env['os'] not in ['bare_metal', 'macos']: +elif env['os'] not in ['windows','bare_metal', 'macos']: test_env.Append(LIBS = ["rt"]) extra_link_flags += ['-fstack-protector-strong'] @@ -172,13 +177,17 @@ bm_link_flags = [] if test_env['linker_script']: bm_link_flags += ['-Wl,--build-id=none', '-T', env['linker_script']] -if test_env['reference_openmp'] and env['os'] not in ['bare_metal', 'macos']: +if test_env['reference_openmp'] and env['os'] not in ['bare_metal', 'macos','windows']: test_env['CXXFLAGS'].append('-fopenmp') test_env['LINKFLAGS'].append('-fopenmp') if 'ndk_above_r21' in env: test_env['LINKFLAGS'].append('-static-openmp') +# Testing for fixed format GEMM kernels. 
+if env['experimental_fixed_format_kernels'] and test_env['validation_tests']: + test_env.Append(CPPDEFINES = ['ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS']) + if test_env['validation_tests']: arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LINKFLAGS=test_env['LINKFLAGS'], CXXFLAGS=test_env['CXXFLAGS'], LIBS= [ arm_compute_test_framework, arm_compute_core_a]) Depends(arm_compute_validation_framework , arm_compute_test_framework) @@ -299,4 +308,4 @@ if test_env['benchmark_examples']: Depends(arm_compute_benchmark_examples, arm_compute_test_framework) Depends(arm_compute_benchmark_examples, arm_compute_lib) Default(arm_compute_benchmark_examples) - Export('arm_compute_benchmark_examples') \ No newline at end of file + Export('arm_compute_benchmark_examples') diff --git a/tests/datasets/GatherDataset.h b/tests/datasets/GatherDataset.h index 29a99d5239..8fec5441b1 100644 --- a/tests/datasets/GatherDataset.h +++ b/tests/datasets/GatherDataset.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -106,6 +106,19 @@ class GatherDataset std::vector _axis{}; }; +class SmallGatherMultiDimIndicesDataset final : public GatherDataset +{ +public: + SmallGatherMultiDimIndicesDataset() + { + add_config(TensorShape(2U, 6U), TensorShape(4U, 9U), 1); + add_config(TensorShape(15U, 15U), TensorShape(3U, 2U, 2U), 1); + add_config(TensorShape(15U, 15U), TensorShape(2U, 11U), 1); + add_config(TensorShape(5U, 3U, 4U), TensorShape(2U, 7U), 1); + add_config(TensorShape(1U, 5U, 3U), TensorShape(1U, 7U, 3U), 1); + } +}; + class SmallGatherDataset final : public GatherDataset { public: diff --git a/tests/framework/Asserts.h b/tests/framework/Asserts.h index 28d3da9a85..7adfa8f2f3 100644 --- a/tests/framework/Asserts.h +++ b/tests/framework/Asserts.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -30,6 +30,8 @@ #include #include +#include "utils/TypePrinter.h" + namespace arm_compute { namespace test @@ -42,6 +44,11 @@ inline int make_printable(int8_t value) return value; } +inline std::string make_printable(const arm_compute::WeightFormat wf) +{ + return arm_compute::to_string(wf); +} + inline unsigned int make_printable(uint8_t value) { return value; diff --git a/tests/framework/SConscript b/tests/framework/SConscript index c4fe50db05..450ffd77b0 100644 --- a/tests/framework/SConscript +++ b/tests/framework/SConscript @@ -1,6 +1,7 @@ +#!/usr/bin/python # -*- coding: utf-8 -*- -# Copyright (c) 2017-2021 Arm Limited. +# Copyright (c) 2017-2022 Arm Limited. # # SPDX-License-Identifier: MIT # @@ -29,8 +30,8 @@ Import('vars') # vars is imported from arm_compute: variables = [ - BoolVariable("pmu", "Enable PMU counters", False), - BoolVariable("mali", "Enable Arm® Mali™ hardware counters", False), + BoolVariable("pmu", "Enable the PMU cycle counter to measure execution time in benchmark tests. (Your device needs to support it)", False), + BoolVariable("mali", "Enable the collection of Arm® Mali™ hardware counters to measure execution time in benchmark tests. 
(Your device needs to have a Arm® Mali™ driver that supports it)", False), ] # We need a separate set of Variables for the Help message (Otherwise the global variables will get displayed twice) diff --git a/tests/main.cpp b/tests/main.cpp index bc264de378..ba18339e50 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -240,6 +240,7 @@ int main(int argc, char **argv) p->print_entry("cpu_has_fp16", support::cpp11::to_string(cpu_info.has_fp16())); p->print_entry("cpu_has_bf16", support::cpp11::to_string(cpu_info.has_bf16())); p->print_entry("cpu_has_dotprod", support::cpp11::to_string(cpu_info.has_dotprod())); + p->print_entry("cpu_has_svebf16", support::cpp11::to_string(cpu_info.has_svebf16())); for(unsigned int j = 0; j < num_cpus; ++j) { diff --git a/tests/validate_examples/cl_gemm.cpp b/tests/validate_examples/cl_gemm.cpp index 82dfc053db..8189b228c2 100644 --- a/tests/validate_examples/cl_gemm.cpp +++ b/tests/validate_examples/cl_gemm.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -32,19 +32,19 @@ #include "arm_compute/runtime/CL/functions/CLGEMM.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" -#include "src/core/CL/kernels/CLDepthConvertLayerKernel.h" #include "src/core/CL/kernels/CLFillBorderKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.h" -#include "src/core/CL/kernels/CLGEMMLowpReductionKernel.h" -#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h" -#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h" -#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h" -#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h" -#include "src/core/CL/kernels/CLIm2ColKernel.h" -#include "src/core/CL/kernels/CLWeightsReshapeKernel.h" +#include "src/gpu/cl/kernels/ClCastKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyNativeKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpOffsetContributionKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpOffsetContributionOutputStageKernel.h" +#include "src/gpu/cl/kernels/ClGemmLowpReductionKernel.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedKernel.h" +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsKernel.h" +#include "src/gpu/cl/kernels/ClGemmReshapeLhsMatrixKernel.h" +#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" +#include "src/gpu/cl/kernels/ClIm2ColKernel.h" +#include "src/gpu/cl/kernels/ClWeightsReshapeKernel.h" #include "tests/AssetsLibrary.h" #include "tests/CL/CLAccessor.h" #include "tests/Globals.h" @@ -204,7 +204,11 @@ class CLGEMMValidateExample : public ValidateExample mm_gemmlowp.configure(&src0, &src1, nullptr, &tmp_dst); // Configure GEMMlowp output stage - mm_gemmlowp_output_stage.configure(&tmp_dst, add_bias ? 
&biases : nullptr, &dst, dst_multiplier, dst_shift, offset_dst); + GEMMLowpOutputStageInfo gemm_info{}; + gemm_info.gemmlowp_multiplier = dst_multiplier; + gemm_info.gemmlowp_shift = dst_shift; + gemm_info.gemmlowp_offset = offset_dst; + mm_gemmlowp_output_stage.configure(&tmp_dst, add_bias ? &biases : nullptr, &dst, gemm_info); tmp_dst.allocator()->allocate(); biases.allocator()->allocate(); fill(CLAccessor(biases), 3); @@ -392,9 +396,9 @@ class CLGEMMValidateExample : public ValidateExample CLTensor src0{}, src1{}, src2{}, dst{}; CLTensor tmp_dst{}, biases{}; - CLGEMM mm_gemm{}; - CLGEMMLowpMatrixMultiplyCore mm_gemmlowp{}; - CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint mm_gemmlowp_output_stage{}; + CLGEMM mm_gemm{}; + CLGEMMLowpMatrixMultiplyCore mm_gemmlowp{}; + CLGEMMLowpOutputStage mm_gemmlowp_output_stage{}; size_t M{ 7 }, N{ 3 }, K{ 5 }, B{ 1 }; DataType data_type{ DataType::F32 }; diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL.cpp new file mode 100644 index 0000000000..a0d13c3e39 --- /dev/null +++ b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL.cpp @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
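The validate-example change above replaces the removed fixed-point-specific function with the generic `CLGEMMLowpOutputStage`, passing multiplier, shift, and offset through a `GEMMLowpOutputStageInfo` instead of as loose arguments. For intuition about what those three fields mean, here is a sketch of what a QUANTIZE_DOWN_FIXEDPOINT stage computes per accumulator value, following the usual gemmlowp recipe (the library kernel is more general: per-channel multipliers, bias addition, configurable clamp bounds; the INT32_MIN saturation corner case is omitted here):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Requantize one int32 accumulator to uint8: fixed-point multiply in Q31,
// rounding right shift, re-centre on the output zero point, then clamp.
static uint8_t quantize_down(int32_t acc, int32_t multiplier, int shift, int32_t offset)
{
    // Rounding doubling high multiply (the Q31 fixed-point product)
    const int64_t prod  = static_cast<int64_t>(acc) * multiplier;
    const int64_t nudge = (prod >= 0) ? (1ll << 30) : (1 - (1ll << 30));
    int32_t       fixed = static_cast<int32_t>((prod + nudge) >> 31);
    // Rounding right shift by the stage's shift amount
    if(shift > 0)
    {
        const int32_t round = 1 << (shift - 1);
        fixed = (fixed + round) >> shift;
    }
    const int32_t out = fixed + offset; // gemmlowp_offset: the output zero point
    return static_cast<uint8_t>(std::min(255, std::max(0, out)));
}

int main()
{
    // e.g. an accumulator of 1234 with a ~0.5 Q31 multiplier, shift 1, offset 10
    std::printf("%u\n", quantize_down(1234, 1 << 30, 1, 10));
}
```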
+ */ +#include "arm_compute/runtime/CL/functions/CLCast.h" +#include "arm_compute/runtime/CL/functions/CLReductionOperation.h" +#include "src/gpu/cl/kernels/ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel.h" +#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" +#include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/fixtures/GEMMLowpFixture.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::opencl::kernels; + +// Create function for CLGEMMReshapeRHSMatrixKernel +using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator; + +// Create function for CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel +using CLGEMMLowpMatrixMultiplyReshapedOnlyRHS = CLSynthetizeOperator; + +// Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS +using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture = + GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULValidationFixture; + +// Fixture for CLGEMMLowpMatrixMultiplyReshapedOnlyRHS +using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureSigned = + GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture; + +using CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureUnsigned = + GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture; + +namespace +{ +// *INDENT-OFF* +// clang-format off + +/** M values to test */ +const auto m_values = framework::dataset::make("M", {16, 49}); + +/** N values to test */ +const auto n_values = framework::dataset::make("N", {16, 259}); + +/** K values to test */ +const auto k_values = framework::dataset::make("K", {192}); + +/** Batch size values to test */ +const auto b_values = framework::dataset::make("batch_size", {1, 2}); + +/** M0 values to test - Precommit */ +const auto m0 = framework::dataset::make("M0", {1, 2, 4}); + +/** N0 values to test - Precommit */ +const auto n0 = framework::dataset::make("N0", { 1, 4, 8}); + +/** K0 values to test - Precommit */ +const auto k0 = framework::dataset::make("K0", { 4 }); + +/** H0 values to test - Precommit */ +const auto h0 = framework::dataset::make("H0", 1); + +/** Interleave values to test with RHS matrix */ +const auto i_values_rhs = framework::dataset::make("interleave_rhs", { false }); + +/** Transpose values to test with RHS matrix */ +const auto t_values_rhs = framework::dataset::make("transpose_rhs", { true }); + +const auto broadcast_bias = framework::dataset::make("broadcast_bias", {true, false}); + +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL) +FIXTURE_DATA_TEST_CASE(Signed, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0), + n0), + k0), + h0), + i_values_rhs), + t_values_rhs), + framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED }))) +{ + // Validate output + if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device())) + { + validate(CLAccessor(_target), _reference); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. 
TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +FIXTURE_DATA_TEST_CASE(Unsigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0), + n0), + k0), + h0), + i_values_rhs), + t_values_rhs), + framework::dataset::make("DataType", { DataType::QASYMM8}))) +{ + // Validate output + if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device())) + { + validate(CLAccessor(_target), _reference); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +FIXTURE_DATA_TEST_CASE(OutputStageSigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureSigned, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0), + n0), + k0), + h0), + i_values_rhs), + t_values_rhs), + broadcast_bias), + framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED}))) +{ + // Validate output + if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device())) + { + validate(CLAccessor(_target), _reference); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +FIXTURE_DATA_TEST_CASE(OutputStageUnsigned, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageFixtureUnsigned, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0), + n0), + k0), + h0), + i_values_rhs), + t_values_rhs), + broadcast_bias), + framework::dataset::make("DataType", { DataType::QASYMM8}))) +{ + // Validate output + if(arm_matrix_multiply_supported(CLKernelLibrary::get().get_device())) + { + validate(CLAccessor(_target), _reference); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +TEST_SUITE_END() // GEMMLowpMatrixMultiplyReshapedOnlyRhsMMUL +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute \ No newline at end of file diff --git a/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp new file mode 100644 index 0000000000..7808be8529 --- /dev/null +++ b/tests/validation/CL/GEMMMatrixMultiplyReshapedOnlyRhsMMUL.cpp @@ -0,0 +1,231 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/gpu/cl/kernels/ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel.h" +#include "src/gpu/cl/kernels/ClGemmReshapeRhsMatrixKernel.h" +#include "tests/CL/CLAccessor.h" +#include "tests/CL/Helper.h" +#include "tests/framework/Macros.h" +#include "tests/framework/datasets/Datasets.h" +#include "tests/validation/Validation.h" +#include "tests/validation/fixtures/GEMMFixture.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +using namespace arm_compute::opencl::kernels; + +// Create function for ClGemmReshapeRhsMatrixKernel +using CLGEMMReshapeRHSMatrix = CLSynthetizeOperator; + +// Create function for ClGemmMatrixMultiplyReshapedOnlyRhsMMULKernel +using CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL = CLSynthetizeOperator; + +// Fixture for CLGEMMMatrixMultiplyReshapedOnlyRhsMMUL +template +using CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture = GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture; + +namespace +{ +// *INDENT-OFF* +// clang-format off +RelativeTolerance rel_tolerance_f32(0.001f); +constexpr float abs_tolerance_f32(0.0001f); +RelativeTolerance rel_tolerance_f16(half_float::half(0.001f)); +constexpr float abs_tolerance_f16(0.3f); + +/** Alpha values to test - Precommit */ +const auto a_values = framework::dataset::make("alpha", {1.0f, 0.75f} ); + +/** Beta values to test - Precommit */ +const auto beta_values = framework::dataset::make("beta", {0.0f, -0.75f} ); + +/** M values to test */ +const auto m_values = framework::dataset::make("M", {49}); + +/** N values to test */ +const auto n_values = framework::dataset::make("N", {257}); + +/** K values to test */ +/** The test case requires this to be multiple of 4*/ +const auto k_values = framework::dataset::make("K", {192}); + +/** Batch size values to test */ +const auto b_values = framework::dataset::make("batch_size", {1, 2}); + +/** Activation values to test */ +const auto act_values = framework::dataset::make("Activation", +{ + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), +}); + +/** M0 values to test - Precommit */ +const auto m0_values_precommit = framework::dataset::make("M0", { 1, 2, 4 }); + +/** N0 values to test - Precommit */ +const auto n0_values_precommit = framework::dataset::make("N0", { 4, 8 }); + +/** K0 values to test - Precommit */ +const auto k0_values_precommit = framework::dataset::make("K0", { 1 }); + +/** Broadcast bias from vector to matrix */ +const auto broadcast_bias_values = framework::dataset::make("broadcast_bias", { false, true } ); + +} // namespace + +TEST_SUITE(CL) +TEST_SUITE(GEMMMatrixMultiplyReshapedOnlyRhsMMUL) +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", false)), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 
0.f, abs_tolerance_f32); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", false)), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +TEST_SUITE_END() // FP16 + +TEST_SUITE(ExportToCLImage) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", true)), + framework::dataset::make("DataType", DataType::F32)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f32, 0.f, abs_tolerance_f32); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} + +TEST_SUITE_END() // FP32 + +TEST_SUITE(FP16) +FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMMatrixMultiplyReshapedOnlyRhsMMULFixture, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine( + m_values, + n_values), + k_values), + b_values), + m0_values_precommit), + n0_values_precommit), + k0_values_precommit), + framework::dataset::make("ExportToCLImage", true)), + framework::dataset::make("DataType", DataType::F16)), + a_values), + beta_values), + broadcast_bias_values), + act_values)) +{ + // Validate output + if(validate_result) + { + validate(CLAccessor(_target), _reference, rel_tolerance_f16, 0.f, abs_tolerance_f16); + } + else + { + ARM_COMPUTE_TEST_INFO("cl_arm_matrix_multiply not supported. TEST skipped"); + framework::ARM_COMPUTE_PRINT_INFO(); + } +} +TEST_SUITE_END() // FP16 +TEST_SUITE_END() // ExportToCLImage +TEST_SUITE_END() // Float +TEST_SUITE_END() // GEMMMatrixMultiplyReshapedOnlyRhsMMUL +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute diff --git a/tests/validation/CL/UNIT/dynamic_fusion/ArbitraryElementwiseFusion.cpp b/tests/validation/CL/UNIT/dynamic_fusion/ArbitraryElementwiseFusion.cpp new file mode 100644 index 0000000000..1b1e8aa761 --- /dev/null +++ b/tests/validation/CL/UNIT/dynamic_fusion/ArbitraryElementwiseFusion.cpp @@ -0,0 +1,394 @@ +/* + * Copyright (c) 2022 Arm Limited. 
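All of the MMUL test cases above validate results only when the device reports the `cl_arm_matrix_multiply` extension, and otherwise log a skip. A standalone sketch of that capability check against the raw OpenCL API (this is roughly what a helper such as `arm_matrix_multiply_supported()` reduces to; error handling is kept minimal):

```cpp
// Link with -lOpenCL.
#define CL_TARGET_OPENCL_VERSION 200
#include <CL/cl.h>
#include <cstdio>
#include <cstring>
#include <vector>

static bool has_extension(cl_device_id dev, const char *name)
{
    size_t len = 0;
    clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, 0, nullptr, &len);
    if(len == 0)
    {
        return false;
    }
    std::vector<char> ext(len); // CL returns a null-terminated extension list
    clGetDeviceInfo(dev, CL_DEVICE_EXTENSIONS, len, ext.data(), nullptr);
    return std::strstr(ext.data(), name) != nullptr;
}

int main()
{
    cl_platform_id platform;
    cl_device_id   device;
    if(clGetPlatformIDs(1, &platform, nullptr) != CL_SUCCESS ||
       clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &device, nullptr) != CL_SUCCESS)
    {
        std::puts("no OpenCL GPU device");
        return 0;
    }
    std::printf("cl_arm_matrix_multiply %s\n",
                has_extension(device, "cl_arm_matrix_multiply") ? "supported"
                                                                : "not supported; MMUL tests would be skipped");
}
```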
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION + +#include "src/core/experimental/dynamic_fusion/ClKernelBuildingAPI.h" +#include "src/core/utils/helpers/float_ops.h" +#include "tests/CL/CLAccessor.h" +#include "tests/framework/Macros.h" +#include "tests/validation/Validation.h" +#include "tests/validation/reference/ConvolutionLayer.h" +#include "tests/validation/reference/ElementwiseOperations.h" +#include "tests/validation/reference/Permute.h" + +#include "arm_compute/runtime/experimental/ClCompositeOperator.h" +#include "tests/validation/reference/Floor.h" + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/runtime/CL/CLTensor.h" +#include "tests/validation/CL/UNIT/dynamic_fusion/Utils.h" + +using namespace arm_compute::experimental::dynamic_fusion; +using namespace arm_compute::test::validation::utils; + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +TEST_SUITE(CL) +TEST_SUITE(UNIT) +TEST_SUITE(DYNAMIC_FUSION) +TEST_SUITE(ArbitraryFusion) + +TEST_CASE(ElementwiseBroadcasting, framework::DatasetMode::ALL) +{ + // Test elementwise broadcasting + const auto data_type = DataType::F32; + const auto data_layout = DataLayout::NHWC; + + const auto input_shape = TensorShape(7, 9, 5); + const auto rhs_shape = TensorShape(7, 1, 1); + const auto dst_shape = TensorShape(7, 9, 5); + + // Tensor Info + auto input_info = TensorInfo(input_shape, 1, data_type, data_layout); + auto addend_info = TensorInfo(rhs_shape, 1, data_type, data_layout); + auto dst_info = TensorInfo(); + + ElementwiseDescriptor add_desc{ ArithmeticOperation::ADD }; + + CLScheduler::get().default_reinit(); + const auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + OperatorGraph op_graph; + + const auto op_input = add_tensor(op_graph, input_info); + const auto op_addend = add_tensor(op_graph, addend_info); + const auto op_dst = add_tensor(op_graph, dst_info); + + add_op_elementwise_op(op_graph, add_desc, op_input, op_addend, op_dst); + + const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; + ClWorkload workload; + build(workload, op_graph, workload_ctx); + + ClCompositeOperator op; + op.configure(cl_compile_ctx, workload); + + // Construct tensors + CLTensor t_input{}; + CLTensor t_addend{}; + CLTensor t_dst{}; + + // Init tensors + t_input.allocator()->init(input_info); + t_addend.allocator()->init(addend_info); + 
t_dst.allocator()->init(dst_info); + + // Allocate and fill tensors + t_input.allocator()->allocate(); + t_addend.allocator()->allocate(); + t_dst.allocator()->allocate(); + + // Fill + fill(CLAccessor(t_input), 0, library.get()); + fill(CLAccessor(t_addend), 1, library.get()); + + // Pack tensors + OpTensorBinding bp_tensors({ { op_input, &t_input }, + { op_addend, &t_addend }, + { op_dst, &t_dst } + }); + + // Populate prepare and run pack-maps (including allocating aux tensors) + ClAuxTensorData aux_tensor_data{}; + TensorPackMap prepare_pack_map{}; + TensorPackMap run_pack_map{}; + bind_tensors(aux_tensor_data, prepare_pack_map, run_pack_map, workload, bp_tensors); + + op.prepare(prepare_pack_map); + op.run(run_pack_map); + + // Create reference + SimpleTensor ref_input{ input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_addend{ rhs_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + + // Fill reference + fill(ref_input, 0, library.get()); + fill(ref_addend, 1, library.get()); + + auto ref_input_nchw = reference::permute(ref_input, PermutationVector(1U, 2U, 0U)); + auto ref_addend_nchw = reference::permute(ref_addend, PermutationVector(1U, 2U, 0U)); + + auto dst_shape_nchw = dst_shape; + permute(dst_shape_nchw, PermutationVector(1U, 2U, 0U)); + + auto ref_t_dst_nchw = reference::arithmetic_operation( + ArithmeticOperation::ADD, + ref_input_nchw, + ref_addend_nchw, + data_type, + ConvertPolicy{}); + + const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U)); + + RelativeTolerance tolerance_f32(0.001f); + validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32); +} +TEST_CASE(DivFloor, framework::DatasetMode::ALL) +{ + // x = floor(div(input, input2)) + const auto data_type = DataType::F32; + const auto eltwise_info = ElementwiseDescriptor{ ArithmeticOperation::DIV }; + + // Tensor Values + const auto width = 7U; + const auto height = 6U; + + // Shapes + const auto input1_shape = TensorShape(width, height); + const auto input2_shape = TensorShape(width, height); + const auto dst_shape = TensorShape(width, height); + + // Create reference + SimpleTensor ref_src_nhwc{ input1_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_src2_nhwc{ input2_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + + // Fill reference + fill(ref_src_nhwc, 0, library.get()); + fill(ref_src2_nhwc, 1, library.get()); + + auto ref_src = reference::permute(ref_src_nhwc, PermutationVector(1U, 2U, 0U)); + auto ref_src2 = reference::permute(ref_src2_nhwc, PermutationVector(1U, 2U, 0U)); + + TensorShape dst_shape_nchw{ dst_shape }; + permute(dst_shape_nchw, PermutationVector(1U, 2U, 0U)); + + const auto ref_dst_nchw = reference::floor_layer(reference::arithmetic_operation( + ArithmeticOperation::DIV, + ref_src, + ref_src2, + data_type, + ConvertPolicy::SATURATE)); + + const auto ref_t_dst = reference::permute(ref_dst_nchw, PermutationVector(2U, 0U, 1U)); + + // Tensor Info + auto input1_info = TensorInfo(input1_shape, 1, data_type, DataLayout::NHWC); + auto input2_info = TensorInfo(input2_shape, 1, data_type, DataLayout::NHWC); + auto dst_info = TensorInfo(); + auto acc_info = TensorInfo(); // Intermediate tensor for division + + // Initialise Scheduler + CLScheduler::get().default_reinit(); + const auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + OperatorGraph op_graph; + + // add tensors + auto op_input1 = add_tensor(op_graph, input1_info); + auto op_input2 = add_tensor(op_graph, 
input2_info); + auto op_acc = add_tensor(op_graph, acc_info); + auto op_dst = add_tensor(op_graph, dst_info); + + add_op_elementwise_op(op_graph, eltwise_info, op_input1, op_input2, op_acc); + add_op_floor(op_graph, FloorDescriptor(), op_acc, op_dst); + + const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; + ClWorkload workload; + build(workload, op_graph, workload_ctx); + + ClCompositeOperator op; + op.configure(cl_compile_ctx, workload); + + // Configure and add tensors. + CLTensor t_input1{}; + CLTensor t_input2{}; + CLTensor t_dst{}; + + // Init Tensors + t_input1.allocator()->init(input1_info); + t_input2.allocator()->init(input2_info); + t_dst.allocator()->init(dst_info); + + // Allocate and fill tensors + t_input1.allocator()->allocate(); + t_input2.allocator()->allocate(); + t_dst.allocator()->allocate(); + + fill(CLAccessor(t_input1), 0, library.get()); + fill(CLAccessor(t_input2), 1, library.get()); + + // "Pack" tensors + OpTensorBinding bp_tensors({ { op_input1, &t_input1 }, + { op_input2, &t_input2 }, + { op_dst, &t_dst } + }); + + // Populate prepare and run pack-maps (including allocating aux tensors) + ClAuxTensorData aux_tensor_data{}; + TensorPackMap prepare_pack_map{}; + TensorPackMap run_pack_map{}; + bind_tensors(aux_tensor_data, prepare_pack_map, run_pack_map, workload, bp_tensors); + + op.prepare(prepare_pack_map); + op.run(run_pack_map); + + RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ + validate(CLAccessor(t_dst), ref_dst_nchw, tolerance_f32); +} +TEST_CASE(Dconv2dAddDiv, framework::DatasetMode::ALL) +{ + // output = div(divend, add(addend, conv2d1x1(direct_conv)(input, weights, bias))) + const auto data_type = DataType::F32; + const auto data_layout = DataLayout::NHWC; + + const auto input_shape = TensorShape(384, 12, 12); + const auto weight_shape = TensorShape(384, 1, 1, 16); + const auto dst_shape = TensorShape(16, 12, 12); + + // Tensor Info + auto input_info = TensorInfo(input_shape, 1, data_type, data_layout); + auto weight_info = TensorInfo(weight_shape, 1, data_type, data_layout); + auto addend_info = TensorInfo(dst_shape, 1, data_type, data_layout); + auto divend_info = TensorInfo(dst_shape, 1, data_type, data_layout); + auto acc_info = TensorInfo(); // Intermediate tensor for conv + auto acc_1_info = TensorInfo(); + auto dst_info = TensorInfo(); + + Conv2dDescriptor conv2d_desc{}; + ElementwiseDescriptor add_desc{ ArithmeticOperation::ADD }; + ElementwiseDescriptor div_desc{ ArithmeticOperation::DIV }; + + CLScheduler::get().default_reinit(); + const auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + OperatorGraph op_graph; + + const auto op_input = add_tensor(op_graph, input_info); + const auto op_weight = add_tensor(op_graph, weight_info); + const auto op_addend = add_tensor(op_graph, addend_info); + const auto op_divend = add_tensor(op_graph, divend_info); + const auto op_acc = add_tensor(op_graph, acc_info); // temp accumulator; TensorInfo to be inferred + const auto op_acc_1 = add_tensor(op_graph, acc_1_info); // temp accumulator; TensorInfo to be inferred + const auto op_dst = add_tensor(op_graph, dst_info); + + auto conv2d = add_op_conv2d(op_graph, conv2d_desc, op_input, op_weight, op_acc); + force_conv2d_method(op_graph, conv2d, ConvolutionMethod::DIRECT); + add_op_elementwise_op(op_graph, add_desc, op_acc, op_addend, op_acc_1); + add_op_elementwise_op(op_graph, div_desc, op_acc_1, 
op_divend, op_dst); + + const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; + ClWorkload workload; + build(workload, op_graph, workload_ctx); + + ClCompositeOperator op; + op.configure(cl_compile_ctx, workload); + + // Construct tensors + CLTensor t_input{}; + CLTensor t_weight{}; + CLTensor t_addend{}; + CLTensor t_divend{}; + CLTensor t_dst{}; + + // Init tensors + t_input.allocator()->init(input_info); + t_weight.allocator()->init(weight_info); + t_divend.allocator()->init(divend_info); + t_addend.allocator()->init(addend_info); + t_dst.allocator()->init(dst_info); + + // Allocate and fill tensors + t_input.allocator()->allocate(); + t_weight.allocator()->allocate(); + t_divend.allocator()->allocate(); + t_addend.allocator()->allocate(); + t_dst.allocator()->allocate(); + + // Fill + fill(CLAccessor(t_input), 0, library.get()); + fill(CLAccessor(t_weight), 1, library.get()); + fill(CLAccessor(t_addend), 2, library.get()); + fill(CLAccessor(t_divend), 3, library.get()); + + // Pack tensors + OpTensorBinding bp_tensors({ { op_input, &t_input }, + { op_weight, &t_weight }, + { op_addend, &t_addend }, + { op_divend, &t_divend }, + { op_dst, &t_dst } + }); + + // Populate prepare and run pack-maps (including allocating aux tensors) + ClAuxTensorData aux_tensor_data{}; + TensorPackMap prepare_pack_map{}; + TensorPackMap run_pack_map{}; + bind_tensors(aux_tensor_data, prepare_pack_map, run_pack_map, workload, bp_tensors); + + op.prepare(prepare_pack_map); + op.run(run_pack_map); + + // Create reference + SimpleTensor ref_input{ input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_weight{ weight_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_bias_placeholder{ dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_addend{ dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + SimpleTensor ref_divend{ dst_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + + // Fill reference + fill(ref_input, 0, library.get()); + fill(ref_weight, 1, library.get()); + fill(ref_addend, 2, library.get()); + fill(ref_divend, 3, library.get()); + + auto ref_input_nchw = reference::permute(ref_input, PermutationVector(1U, 2U, 0U)); + auto ref_weight_nchw = reference::permute(ref_weight, PermutationVector(1U, 2U, 0U)); + auto ref_bias_placeholder_nchw = reference::permute(ref_bias_placeholder, PermutationVector(1U, 2U, 0U)); + auto ref_addend_nchw = reference::permute(ref_addend, PermutationVector(1U, 2U, 0U)); + auto ref_divend_nchw = reference::permute(ref_divend, PermutationVector(1U, 2U, 0U)); + + auto dst_shape_nchw = dst_shape; + permute(dst_shape_nchw, PermutationVector(1U, 2U, 0U)); + + PadStrideInfo legacy_pad_stride(conv2d_desc.stride.x(), conv2d_desc.stride.y(), conv2d_desc.pad.left, conv2d_desc.pad.right, conv2d_desc.pad.top, conv2d_desc.pad.bottom, DimensionRoundingType{}); + auto ref_acc_nchw = reference::arithmetic_operation( + ArithmeticOperation::ADD, + ref_addend_nchw, + reference::convolution_layer(ref_input_nchw, ref_weight_nchw, ref_bias_placeholder_nchw, dst_shape_nchw, legacy_pad_stride, conv2d_desc.dilation), + data_type, + ConvertPolicy{}); + + auto ref_t_dst_nchw = reference::arithmetic_operation( + ArithmeticOperation::DIV, + ref_acc_nchw, + ref_divend_nchw, + data_type, + ConvertPolicy{}); + + const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U)); + + RelativeTolerance tolerance_f32(0.001f); + 
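// The CL target is NHWC while the reference stays in NCHW; validate()
+ // is assumed to remap coordinates according to the target's data layout,
+ // which is why ref_t_dst_nchw (rather than the permuted ref_t_dst) is
+ // compared below.
+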
validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32); +} + +TEST_SUITE_END() // ArbitraryFusion +TEST_SUITE_END() // DYNAMIC_FUSION +TEST_SUITE_END() // UNIT +TEST_SUITE_END() // CL + +} // namespace validation +} // namespace test +} // namespace arm_compute + +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ diff --git a/tests/validation/CL/UNIT/dynamic_fusion/ClCompositeKernel.cpp b/tests/validation/CL/UNIT/dynamic_fusion/ClCompositeKernel.cpp index 96a845c36e..3ffbc077c6 100644 --- a/tests/validation/CL/UNIT/dynamic_fusion/ClCompositeKernel.cpp +++ b/tests/validation/CL/UNIT/dynamic_fusion/ClCompositeKernel.cpp @@ -74,8 +74,9 @@ TEST_CASE(MoveNet_SubGraph_1_DirectConv2d, framework::DatasetMode::ALL) ClExecutionDescriptor exec_desc{}; Status st{}; - const auto data_type = DataType::F32; - const auto conv_info = Conv2dDescriptor{ Padding2D{ 1U, 1U, 1U, 1U }, { 1U, 1U } /* stride */ }; + const auto data_type = DataType::F32; + const auto conv_info = Conv2dDescriptor{ Padding2D{ 1U, 1U, 1U, 1U }, { 1U, 1U } /* stride */ }; + const auto eltwise_info = ElementwiseDescriptor{ ArithmeticOperation::ADD }; const auto width = 7U; const auto height = 6U; @@ -99,7 +100,7 @@ TEST_CASE(MoveNet_SubGraph_1_DirectConv2d, framework::DatasetMode::ALL) const auto m0 = (OFM > 16) ? ((data_type == DataType::F32) ? 2U : 4U) : 1U; const ClDirectConv2dKernelDescriptor direct_conv2d_desc{ conv_info }; - const ClEltwiseAddKernelDescriptor eltwise_add_desc{}; + const ClElementwiseKernelDescriptor eltwise_add_desc{ eltwise_info }; const TileDescriptor store_tile_info{ Size2D(n0, m0), Size2D(width, height), ClippingStrategy::TOP_LEFT }; ArgumentID src_id{ g_arg_placeholder }; @@ -119,7 +120,7 @@ TEST_CASE(MoveNet_SubGraph_1_DirectConv2d, framework::DatasetMode::ALL) st = add_tensor(bp, &dst_info, dst_id); st = add_kcomp_direct_conv2d(bp, direct_conv2d_desc, src_id, wei_id, bia_id, acc_id); - st = add_kcomp_eltwise_add(bp, eltwise_add_desc, addend_id, acc_id, acc_1_id); + st = add_kcomp_eltwise_op(bp, eltwise_add_desc, addend_id, acc_id, acc_1_id); st = add_kcomp_store(bp, StoreType::TStoreIndirectWidthSelect, acc_1_id, dst_id); exec_desc.skip_sliding_window = true; diff --git a/tests/validation/CL/UNIT/dynamic_fusion/Floor.cpp b/tests/validation/CL/UNIT/dynamic_fusion/Floor.cpp new file mode 100644 index 0000000000..2b8f69e5e7 --- /dev/null +++ b/tests/validation/CL/UNIT/dynamic_fusion/Floor.cpp @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2022 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifdef ENABLE_EXPERIMENTAL_DYNAMIC_FUSION +#include "arm_compute/core/TensorInfo.h" + +#include "arm_compute/core/CL/CLKernelLibrary.h" +#include "arm_compute/core/experimental/ClWorkload.h" +#include "arm_compute/runtime/CL/CLScheduler.h" +#include "arm_compute/runtime/experimental/ClCompositeOperator.h" +#include "src/core/experimental/dynamic_fusion/WorkloadImpl/ClKernelDescriptors.h" +#include "tests/CL/CLAccessor.h" +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" +#include "tests/validation/CL/UNIT/dynamic_fusion/Utils.h" +#include "tests/validation/Validation.h" + +#include "tests/validation/reference/Floor.h" +#include "tests/validation/reference/Permute.h" + +#ifdef ARM_COMPUTE_ASSERTS_ENABLED +#include "tests/SimpleTensorPrinter.h" +#endif /* ARM_COMPUTE_ASSERTS_ENABLED */ + +using namespace arm_compute::experimental::dynamic_fusion; +using namespace arm_compute::test::validation::utils; + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +TEST_SUITE(CL) +TEST_SUITE(UNIT) +TEST_SUITE(DYNAMIC_FUSION) +TEST_CASE(Operator_Floor_1_F32, framework::DatasetMode::ALL) +{ + /* Computation: + * out = floor(input) + */ + const auto data_type = DataType::F32; + const auto data_layout = DataLayout::NHWC; + const auto t_shape = TensorShape(32, 16); + auto t_input_info = TensorInfo(t_shape, 1, data_type, data_layout); + auto t_dst_info = TensorInfo(); + + FloorDescriptor floor_desc{}; + + // Create reference + SimpleTensor ref_t_input{ t_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; + + // Fill reference + fill(ref_t_input, 0, library.get()); + + auto ref_t_input_nchw = reference::permute(ref_t_input, PermutationVector(1U, 2U, 0U)); + auto t_dst_shape_nchw = t_shape; + permute(t_dst_shape_nchw, PermutationVector(1U, 2U, 0U)); + + auto ref_t_dst_nchw = reference::floor_layer(ref_t_input_nchw); + const auto ref_t_dst = reference::permute(ref_t_dst_nchw, PermutationVector(2U, 0U, 1U)); + + CLScheduler::get().default_reinit(); + const auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context(); + OperatorGraph op_graph; + + const auto op_t_input = add_tensor(op_graph, t_input_info); + const auto op_t_dst = add_tensor(op_graph, t_dst_info); + + add_op_floor(op_graph, floor_desc, op_t_input, op_t_dst); + + const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; + ClWorkload workload; + build(workload, op_graph, workload_ctx); + + ClCompositeOperator op; + op.configure(cl_compile_ctx, workload); + + // Construct tensors + CLTensor t_input{}; + CLTensor t_dst{}; + + // Init tensors + t_input.allocator()->init(t_input_info); + t_dst.allocator()->init(t_dst_info); + + // Allocate and fill tensors + t_input.allocator()->allocate(); + t_dst.allocator()->allocate(); + fill(CLAccessor(t_input), 0, library.get()); + // "Pack" tensors + OpTensorBinding bp_tensors({ { op_t_input, &t_input }, + { op_t_dst, &t_dst } + }); + + // Populate prepare and run pack-maps (including allocating aux tensors) + ClAuxTensorData aux_tensor_data{}; + TensorPackMap prepare_pack_map{}; + TensorPackMap run_pack_map{}; + bind_tensors(aux_tensor_data, prepare_pack_map, run_pack_map, workload, bp_tensors); + + op.prepare(prepare_pack_map); + op.run(run_pack_map); + 
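// prepare() executes any one-shot preparation workloads (e.g. filling
+ // auxiliary tensors) via the prepare pack-map; run() then launches the
+ // fused kernels via the run pack-map.
+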
RelativeTolerance tolerance_f32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for floating point data types */ + validate(CLAccessor(t_dst), ref_t_dst_nchw, tolerance_f32); +} + +TEST_SUITE_END() // DYNAMIC_FUSION +TEST_SUITE_END() // UNIT +TEST_SUITE_END() // CL +} // namespace validation +} // namespace test +} // namespace arm_compute +#endif /* ENABLE_EXPERIMENTAL_DYNAMIC_FUSION */ \ No newline at end of file diff --git a/tests/validation/CL/UNIT/dynamic_fusion/Integration_OperatorFuseMovenetSubGraph1.cpp b/tests/validation/CL/UNIT/dynamic_fusion/Integration_OperatorFuseMovenetSubGraph1.cpp index fe8d23ef15..3a8b7c8ce8 100644 --- a/tests/validation/CL/UNIT/dynamic_fusion/Integration_OperatorFuseMovenetSubGraph1.cpp +++ b/tests/validation/CL/UNIT/dynamic_fusion/Integration_OperatorFuseMovenetSubGraph1.cpp @@ -77,8 +77,8 @@ TEST_CASE(Operator_Fuse_Movenet_SubGraph_1_F32, framework::DatasetMode::ALL) auto t_acc_info = TensorInfo(); // Intermediate tensor for cond3 auto t_dst_info = TensorInfo(); - Conv2dDescriptor conv2d_desc{}; - AddDescriptor add_desc{}; + Conv2dDescriptor conv2d_desc{}; + ElementwiseDescriptor add_desc{ ArithmeticOperation::ADD }; // Create reference SimpleTensor ref_t_input{ t_input_shape, data_type, 1, QuantizationInfo(), DataLayout::NHWC }; @@ -119,7 +119,7 @@ TEST_CASE(Operator_Fuse_Movenet_SubGraph_1_F32, framework::DatasetMode::ALL) auto conv2d = add_op_conv2d(op_graph, conv2d_desc, op_t_input, op_t_weight, op_t_acc); force_conv2d_method(op_graph, conv2d, ConvolutionMethod::DIRECT); - add_op_elementwise_add(op_graph, add_desc, op_t_acc, op_t_l1_addend, op_t_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_acc, op_t_l1_addend, op_t_dst); const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; ClWorkload workload; @@ -180,8 +180,8 @@ TEST_CASE(DataType_QASYMM8, framework::DatasetMode::ALL) auto t_acc_info = TensorInfo(t_dst_shape, 1, data_type, data_layout); auto t_dst_info = TensorInfo(t_dst_shape, 1, data_type, data_layout); - Conv2dDescriptor conv2d_desc{}; - AddDescriptor add_desc{}; + Conv2dDescriptor conv2d_desc{}; + ElementwiseDescriptor add_desc{}; OperatorGraph op_graph; @@ -192,7 +192,7 @@ TEST_CASE(DataType_QASYMM8, framework::DatasetMode::ALL) const auto op_t_dst = add_tensor(op_graph, t_dst_info); auto conv2d = add_op_conv2d(op_graph, conv2d_desc, op_t_input, op_t_weight, op_t_acc); - add_op_elementwise_add(op_graph, add_desc, op_t_acc, op_t_l1_addend, op_t_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_acc, op_t_l1_addend, op_t_dst); force_conv2d_method(op_graph, conv2d, ConvolutionMethod::DIRECT); const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; @@ -290,7 +290,7 @@ TEST_CASE(Enlarging_Execution_Space, framework::DatasetMode::ALL) auto t_dst_info = TensorInfo(); OperatorGraph op_graph; - const auto add_desc = AddDescriptor{}; + const auto add_desc = ElementwiseDescriptor{}; const auto op_t_l0_lhs = add_tensor(op_graph, t_l0_lhs_info); const auto op_t_l0_rhs = add_tensor(op_graph, t_l0_rhs_info); @@ -300,9 +300,9 @@ TEST_CASE(Enlarging_Execution_Space, framework::DatasetMode::ALL) const auto op_t_l1_dst = add_tensor(op_graph, t_l1_dst_info); // temp accumulator; TensorInfo to be inferred const auto op_t_dst = add_tensor(op_graph, t_dst_info); - add_op_elementwise_add(op_graph, add_desc, op_t_l0_lhs, op_t_l0_rhs, op_t_l0_dst); - add_op_elementwise_add(op_graph, add_desc, op_t_l0_dst, op_t_l1_rhs, op_t_l1_dst); - 
add_op_elementwise_add(op_graph, add_desc, op_t_l1_dst, op_t_l2_lhs, op_t_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_l0_lhs, op_t_l0_rhs, op_t_l0_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_l0_dst, op_t_l1_rhs, op_t_l1_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_l1_dst, op_t_l2_lhs, op_t_dst); const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; ClWorkload workload; @@ -334,7 +334,7 @@ TEST_CASE(Root_Simple_And_Complex, framework::DatasetMode::ALL) OperatorGraph op_graph; const auto conv2d_desc = Conv2dDescriptor{}; - const auto add_desc = AddDescriptor{}; + const auto add_desc = ElementwiseDescriptor{}; const auto op_t_l0_0_input = add_tensor(op_graph, t_l0_0_input_info); const auto op_t_l0_0_weight = add_tensor(op_graph, t_l0_0_weight_info); @@ -345,8 +345,8 @@ TEST_CASE(Root_Simple_And_Complex, framework::DatasetMode::ALL) const auto op_t_dst = add_tensor(op_graph, t_dst_info); add_op_conv2d(op_graph, conv2d_desc, op_t_l0_0_input, op_t_l0_0_weight, op_t_l0_0_dst); - add_op_elementwise_add(op_graph, add_desc, op_t_l0_1_lhs, op_t_l0_1_rhs, op_t_l0_1_dst); - add_op_elementwise_add(op_graph, add_desc, op_t_l0_0_dst, op_t_l0_1_dst, op_t_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_l0_1_lhs, op_t_l0_1_rhs, op_t_l0_1_dst); + add_op_elementwise_op(op_graph, add_desc, op_t_l0_0_dst, op_t_l0_1_dst, op_t_dst); const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; ClWorkload workload; @@ -374,7 +374,7 @@ TEST_CASE(Loop, framework::DatasetMode::ALL) OperatorGraph op_graph; const auto conv2d_desc = Conv2dDescriptor{}; - const auto add_desc = AddDescriptor{}; + const auto add_desc = ElementwiseDescriptor{}; const auto op_t_l0_lhs = add_tensor(op_graph, t_l0_lhs_info); const auto op_t_l1_lhs = add_tensor(op_graph, t_l1_lhs_info); @@ -382,7 +382,7 @@ TEST_CASE(Loop, framework::DatasetMode::ALL) const auto op_t_state1 = add_tensor(op_graph, state1_info); add_op_conv2d(op_graph, conv2d_desc, op_t_l0_lhs, op_t_state0, op_t_state1); - add_op_elementwise_add(op_graph, add_desc, op_t_l1_lhs, op_t_state1, op_t_state0); + add_op_elementwise_op(op_graph, add_desc, op_t_l1_lhs, op_t_state1, op_t_state0); const ClWorkloadContext workload_ctx{ GpuInfo{ CLScheduler::get().target() } }; ClWorkload workload; diff --git a/tests/validation/NEON/ActivationLayer.cpp b/tests/validation/NEON/ActivationLayer.cpp index 1f43de49d2..e45b7fa5ad 100644 --- a/tests/validation/NEON/ActivationLayer.cpp +++ b/tests/validation/NEON/ActivationLayer.cpp @@ -309,7 +309,7 @@ DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL, concat(concat( cpu_isa.sve2 = (cpu_ext == "SVE2"); cpu_isa.fp16 = (data_type == DataType::F16); - const auto *selected_impl = CpuActivationKernel::get_implementation(DataTypeISASelectorData{data_type, cpu_isa}, cpu::KernelSelectionType::Preferred); + const auto *selected_impl = CpuActivationKernel::get_implementation(ActivationDataTypeISASelectorData{data_type, cpu_isa,ActivationLayerInfo::ActivationFunction::BOUNDED_RELU}, cpu::KernelSelectionType::Preferred); ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl); diff --git a/tests/validation/NEON/ArithmeticAddition.cpp b/tests/validation/NEON/ArithmeticAddition.cpp index c72e082a74..f94e329c9c 100644 --- a/tests/validation/NEON/ArithmeticAddition.cpp +++ b/tests/validation/NEON/ArithmeticAddition.cpp @@ -89,7 +89,7 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( } DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL, concat(concat( - 
combine(framework::dataset::make("CpuExt", std::string("NEON")), + combine(combine(framework::dataset::make("CpuExt", std::string("NEON")), framework::dataset::make("DataType", { DataType::F32, DataType::F16, DataType::U8, @@ -99,19 +99,22 @@ DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL, concat(concat( DataType::QASYMM8_SIGNED, DataType::QSYMM16 })), - combine(framework::dataset::make("CpuExt", std::string("SVE")), + framework::dataset::make("CanInterpretAs1D", {true, false})), + combine(combine(framework::dataset::make("CpuExt", std::string("SVE")), framework::dataset::make("DataType", { DataType::F32, DataType::F16, DataType::U8, DataType::S16, DataType::S32 - }))), - combine(framework::dataset::make("CpuExt", std::string("SVE2")), + })), + framework::dataset::make("CanInterpretAs1D", {true, false}))), + combine(combine(framework::dataset::make("CpuExt", std::string("SVE2")), framework::dataset::make("DataType", { DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16 - }))), - cpu_ext, data_type) + })), + framework::dataset::make("CanInterpretAs1D", {false}))), + cpu_ext, data_type, can_interpret_inputs_as_1d_array) { using namespace cpu::kernels; @@ -121,11 +124,23 @@ DATA_TEST_CASE(KernelSelection, framework::DatasetMode::ALL, concat(concat( cpu_isa.sve2 = (cpu_ext == "SVE2"); cpu_isa.fp16 = (data_type == DataType::F16); - const auto *selected_impl = CpuAddKernel::get_implementation(DataTypeISASelectorData{data_type, cpu_isa}, cpu::KernelSelectionType::Preferred); + const auto *selected_impl = CpuAddKernel::get_implementation(CpuAddKernelDataTypeISASelectorData{data_type, cpu_isa, can_interpret_inputs_as_1d_array}, cpu::KernelSelectionType::Preferred); ARM_COMPUTE_ERROR_ON_NULLPTR(selected_impl); - std::string expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_add"; + bool float_or_integer = (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || + data_type == DataType::S16 || data_type == DataType::S32); + + std::string expected; + if(can_interpret_inputs_as_1d_array && float_or_integer) + { + expected = "neon_" + cpu_impl_dt(data_type) + "_add_as_1d_array"; + } + else + { + expected = lower_string(cpu_ext) + "_" + cpu_impl_dt(data_type) + "_add"; + } + std::string actual = selected_impl->name; ARM_COMPUTE_EXPECT_EQUAL(expected, actual, framework::LogLevel::ERRORS); diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 578921bddd..67f7c8896f 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -504,6 +504,320 @@ TEST_SUITE_END() // FP16 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ TEST_SUITE_END() // WinogradLayer +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +TEST_SUITE(VariableWeightUtils) + +// UC2_1_* tests: the user requests a specific fixed format, but there is no kernel that supports it. 
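+//
+// A note on the WeightFormat mnemonics used throughout these tests (from
+// the enum's naming scheme): in OHWIo4 the trailing "o4" means the O
+// (output channel) dimension is padded and interleaved in blocks of 4;
+// in OHWIo8i4_bf16, "o8" and "i4" are the interleave and block factors on
+// the O and I dimensions, and the "_bf16" suffix denotes a fast-math
+// variant whose weights are repacked as bfloat16.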
+
+template <typename ConvolutionClass>
+using HasOptImplFixtureNoFastMath = HasOptImplFixture<ConvolutionClass, /*enable_fast_math*/ false>;
+
+template <typename ConvolutionClass>
+using HasOptImplFixtureFastMath = HasOptImplFixture<ConvolutionClass, /*enable_fast_math*/ true>;
+
+// UC2_1
+
+FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d, HasOptImplFixtureNoFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_1_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_1_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo2 })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC2_2_* tests: the user requests a specific fixed format, and a
+// kernel that supports that fixed format is found.
+
+FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d, HasOptImplFixtureNoFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo4 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo4, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT_EQUAL(_computed_weight_format, arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC2_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::F32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::OHWIo8i4_bf16 })))
+{
+    ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS);
+    ARM_COMPUTE_EXPECT(_computed_weight_format == arm_compute::WeightFormat::OHWIo8i4_bf16, framework::LogLevel::ERRORS);
+}
+
+// UC3_1_* tests: the user queries for ANY fixed format, but there is
+// no kernel that supports the use case specified by the user (for
+// example, there is no fixed format kernel for the data type of the
+// problem).
+
+FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d, HasOptImplFixtureNoFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_1_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath<cpu::CpuGemmConv2d>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+FIXTURE_DATA_TEST_CASE(UC3_1_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath<NEGEMMConvolutionLayer>, framework::DatasetMode::ALL,
+                       combine(framework::dataset::make("DataType", { DataType::S32 }),
+                               framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY })))
+{
+    ARM_COMPUTE_EXPECT(!_kernel_found, framework::LogLevel::ERRORS);
+}
+
+// UC3_2_* tests: the user queries for ANY fixed format. The search
+// succeeds and the fixed format found is reported back for
+// consumption by the user. Note that we only check that
+// _computed_weight_format is a concrete fixed format, i.e. anything
+// but the non-fixed formats (ANY and UNSPECIFIED). This is because the
+// weight format that the runtime produces depends on the size of the
+// vector units of the hardware where the test is executed. For
+// example, a format like OHWIo4 for FP32 data returned for 128-bit
+// NEON hardware is replaced by OHWIo8 when running on 256-bit SVE.
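+//
+// An illustrative sketch of the query these fixtures exercise (not part of
+// the test; argument names are placeholders and the argument list is
+// abbreviated):
+//
+//     arm_compute::WeightFormat wf = arm_compute::WeightFormat::ANY;
+//     Status st = NEGEMMConvolutionLayer::has_opt_impl(wf, &src, &weights, &biases, &dst,
+//                                                      conv_info, weights_info);
+//     if(bool(st))
+//     {
+//         // wf now names a concrete layout such as OHWIo4 or OHWIo8,
+//         // matched to the vector length of the machine running the test.
+//     }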
+ +FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d, HasOptImplFixtureNoFastMath, framework::DatasetMode::ALL, + combine(framework::dataset::make("DataType", { DataType::F32 }), + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) +{ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); +} + +FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer, HasOptImplFixtureNoFastMath, framework::DatasetMode::ALL, + combine(framework::dataset::make("DataType", { DataType::F32 }), + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) +{ + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); +} + +FIXTURE_DATA_TEST_CASE(UC3_2_CpuGemmConv2d_FastMath, HasOptImplFixtureFastMath, framework::DatasetMode::ALL, + combine(framework::dataset::make("DataType", { DataType::F32 }), + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) +{ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); +} + +FIXTURE_DATA_TEST_CASE(UC3_2_NEGEMMConvolutionLayer_FastMath, HasOptImplFixtureFastMath, framework::DatasetMode::ALL, + combine(framework::dataset::make("DataType", { DataType::F32 }), + framework::dataset::make("QueryWeightFormat", { arm_compute::WeightFormat::ANY }))) +{ + ARM_COMPUTE_EXPECT(_kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::ANY, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(_computed_weight_format != arm_compute::WeightFormat::UNSPECIFIED, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format_fast_math(_computed_weight_format), framework::LogLevel::ERRORS); +} + +namespace +{ +using TestCaseType = std::tuple; +auto prepare_weights_shapes = framework::dataset::make("TensorShape", +{ + // OHWIoi + // + // OHWI --> O'HWI', where: + // + // O'= smallest multiple of such that O<=O' + // I'= smallest multiple of such that I<=I' + // + + // Change N for OHWIo4 + TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 4U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 12U }, 
arm_compute::WeightFormat::OHWIo4 }), + // // Change N for OHWIo8 + TestCaseType({ { 1U, 1U, 1U, 1U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 2U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 3U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 4U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 5U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 6U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 7U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 8U }, { 1U, 1U, 1U, 8U }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1U, 1U, 1U, 9U }, { 1U, 1U, 1U, 16U }, arm_compute::WeightFormat::OHWIo8 }), + // // Change N for OHWIo4 when H, W and C are not 1 + TestCaseType({ { 3U, 4U, 2U, 1U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 2U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 3U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 4U }, { 3, 4, 2, 4 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 6U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 7U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 8U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 12 }, arm_compute::WeightFormat::OHWIo4 }), + + // // Fix N and move HWI around, with different data layouts and formats + TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 3U, 4U, 2U, 5U }, { 3, 4, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 2U, 4U, 3U, 9U }, { 2, 4, 3, 16 }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 3U, 4U, 2U, 9U }, { 3, 4, 2, 16 }, arm_compute::WeightFormat::OHWIo8 }), + TestCaseType({ { 1024U, 1U, 1U, 1001U }, { 1024, 1, 1, 1008 }, arm_compute::WeightFormat::OHWIo8 }), + + // // Adding on I (=C) + TestCaseType({ { 1U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), + TestCaseType({ { 2U, 4U, 3U, 5U }, { 2, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), + TestCaseType({ { 3U, 4U, 3U, 5U }, { 4, 4, 3, 8 }, arm_compute::WeightFormat::OHWIo4i2 }), + + // --------- + TestCaseType({ { 2, 2, 1, 5 }, { 2, 2, 1, 8 }, arm_compute::WeightFormat::OHWIo4 }), + TestCaseType({ { 1, 2, 2, 5 }, { 1, 2, 2, 8 }, arm_compute::WeightFormat::OHWIo4 }), + +}); +} // unnamed namespace + +DATA_TEST_CASE(PrepareWeightShape, framework::DatasetMode::ALL, + prepare_weights_shapes, shapes) +{ + const TensorShape input_shape = std::get<0>(shapes); + const TensorShape expected_shape = std::get<1>(shapes); + const arm_compute::WeightFormat wf = std::get<2>(shapes); + const DataType DT = DataType::F32; + const DataLayout DL = DataLayout::NHWC; + const auto TI = TensorInfo(input_shape, 1 /*num_channels, deprecated*/, DT, DL); + const TensorInfo computed = ::arm_compute::test::validation::prepare_weights(TI, wf); + const TensorInfo expected = TensorInfo(expected_shape, 1 /*num_channels, deprecated*/, DT, DL); + ARM_COMPUTE_EXPECT_EQUAL(computed, expected, 
framework::LogLevel::ERRORS); +} + +TEST_SUITE_END() // VariableWeightUtils + +TEST_SUITE(ExperimentalCpuAPIVariableWeightWithFixtures) + +template +using VarWidth = VariableWeightsFixture; + +FIXTURE_DATA_TEST_CASE(RunSmallFloat, VarWidth, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F32 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); +} + +FIXTURE_DATA_TEST_CASE(RunSmallHalf, VarWidth, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F16 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16)); +} + +#if defined(ARM_COMPUTE_ENABLE_BF16) +template +using VarWidthFastMath = VariableWeightsFixture; + +FIXTURE_DATA_TEST_CASE(RunSmallFloatFastMath, VarWidthFastMath, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F32 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); +} +#endif // ARM_COMPUTE_ENABLE_BF16 + +TEST_SUITE_END() // ExperimentalCpuAPIVariableWeightWithFixtures + +TEST_SUITE(ExperimentalNEAPIVariableWeightWithFixtures) + +template +using NEGEMMVarWidth = VariableWeightsFixtureNEInterface; + +FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallFloat, NEGEMMVarWidth, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F32 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); +} + +FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallHalf, NEGEMMVarWidth, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F16 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f16, 0.f, half(abs_tolerance_f16)); +} + +#if defined(ARM_COMPUTE_ENABLE_BF16) +template +using NEGEMMVarWidthFastMath = VariableWeightsFixtureNEInterface; + +FIXTURE_DATA_TEST_CASE(NEGEMMRunSmallFloatFastMath, NEGEMMVarWidthFastMath, framework::DatasetMode::ALL, + combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ACL Scalar type", { DataType::F32 }))) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); +} +#endif // ARM_COMPUTE_ENABLE_BF16 + +TEST_SUITE_END() // ExperimentalNEAPIVariableWeightWithFixtures + +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS + TEST_SUITE(GEMMConvolutionLayer) template using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture; @@ -606,27 +920,22 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL) } TEST_SUITE(Float) -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if 
defined(ARM_COMPUTE_ENABLE_BF16) TEST_SUITE(BFLOAT16) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::BFLOAT16)), - framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::BFLOAT16)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); } TEST_SUITE_END() // BFLOAT16 -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataLayout", { DataLayout::NCHW })), - ActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataLayout", { DataLayout::NCHW })), ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f16, tolerance_num, abs_tolerance_f16); @@ -636,9 +945,7 @@ TEST_SUITE_END() // FP16 TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output @@ -680,11 +987,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), - QuantizedActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); @@ -710,11 +1014,8 @@ TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedFixture, 
framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), - QuantizedActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); @@ -868,10 +1169,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL) TEST_SUITE(Float) TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataLayout", { DataLayout::NHWC })), - ActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), ActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); @@ -895,11 +1193,8 @@ const auto QuantizedActivationFunctionsDataset = framework::dataset::make("Activ TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataLayout", { DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), - QuantizedActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); @@ -908,11 +1203,8 @@ TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), - framework::dataset::make("ReshapeWeights", { true })), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataLayout", { DataLayout::NHWC })), - framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), - QuantizedActivationFunctionsDataset)) + framework::dataset::make("ReshapeWeights", { true })), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), 
QuantizedActivationFunctionsDataset)) { // Validate output validate(Accessor(_target), _reference, tolerance_qasymm8); diff --git a/tests/validation/NEON/DepthConvertLayer.cpp b/tests/validation/NEON/DepthConvertLayer.cpp index 5649e5a556..378652c24f 100644 --- a/tests/validation/NEON/DepthConvertLayer.cpp +++ b/tests/validation/NEON/DepthConvertLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -342,7 +342,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEDepthConvertLayerToS32Fixture, frame } TEST_SUITE_END() // S16_to_S32 -#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) +#if defined(ARM_COMPUTE_ENABLE_BF16) TEST_SUITE(BFLOAT16_to_F32) FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthConvertLayerToF32Fixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), DepthConvertLayerBF16toF32Dataset), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), @@ -362,7 +362,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthConvertLayerToBF16Fixture, framew validate(Accessor(_target), _reference); } TEST_SUITE_END() // F32_to_BFLOAT16 -#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */ +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(F16_to_QASYMM8) diff --git a/tests/validation/NEON/FillBorder.cpp b/tests/validation/NEON/FillBorder.cpp index 343ad831e4..928990b2b4 100644 --- a/tests/validation/NEON/FillBorder.cpp +++ b/tests/validation/NEON/FillBorder.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -60,10 +60,10 @@ DATA_TEST_CASE(FillBorder, framework::DatasetMode::ALL, combine(combine(combine( { BorderSize border_size{ static_cast(size) }; - std::mt19937 generator(library->seed()); - std::uniform_int_distribution distribution_u8(0, 255); - const uint8_t border_value = distribution_u8(generator); - const uint8_t tensor_value = distribution_u8(generator); + std::mt19937 generator(library->seed()); + std::uniform_int_distribution distribution_u8(0, 255); + const uint8_t border_value = distribution_u8(generator); + const uint8_t tensor_value = distribution_u8(generator); // Create tensors Tensor src = create_tensor(shape, data_type); @@ -77,7 +77,7 @@ DATA_TEST_CASE(FillBorder, framework::DatasetMode::ALL, combine(combine(combine( validate(src.info()->padding(), padding); // Fill tensor with constant value - std::uniform_int_distribution distribution{ tensor_value, tensor_value }; + std::uniform_int_distribution distribution{ tensor_value, tensor_value }; library->fill(Accessor(src), distribution, 0); // Create and configure kernel diff --git a/tests/validation/NEON/Gather.cpp b/tests/validation/NEON/Gather.cpp index ca1e166bd1..0aea19939e 100644 --- a/tests/validation/NEON/Gather.cpp +++ b/tests/validation/NEON/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Arm Limited. + * Copyright (c) 2019-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -100,12 +100,14 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip( template using NEGatherFixture = GatherFixture; +const auto gather_small_shapes = arm_compute::test::framework::dataset::concat(datasets::SmallGatherDataset(), datasets::SmallGatherMultiDimIndicesDataset()); + TEST_SUITE(Float) TEST_SUITE(FP16) FIXTURE_DATA_TEST_CASE(RunSmall, NEGatherFixture, framework::DatasetMode::PRECOMMIT, - combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::F16))) + combine(gather_small_shapes, framework::dataset::make("DataType", DataType::F16))) { // Validate output validate(Accessor(_target), _reference); @@ -125,7 +127,7 @@ TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, NEGatherFixture, framework::DatasetMode::PRECOMMIT, - combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::F32))) + combine(gather_small_shapes, framework::dataset::make("DataType", DataType::F32))) { // Validate output validate(Accessor(_target), _reference); @@ -146,7 +148,7 @@ TEST_SUITE(U8) FIXTURE_DATA_TEST_CASE(RunSmall, NEGatherFixture, framework::DatasetMode::PRECOMMIT, - combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::U8))) + combine(gather_small_shapes, framework::dataset::make("DataType", DataType::U8))) { // Validate output validate(Accessor(_target), _reference); @@ -166,7 +168,7 @@ TEST_SUITE(U16) FIXTURE_DATA_TEST_CASE(RunSmall, NEGatherFixture, framework::DatasetMode::PRECOMMIT, - combine(datasets::SmallGatherDataset(), framework::dataset::make("DataType", DataType::U16))) + combine(gather_small_shapes, framework::dataset::make("DataType", DataType::U16))) { // Validate output validate(Accessor(_target), _reference); diff --git a/tests/validation/NEON/UNIT/TensorAllocator.cpp b/tests/validation/NEON/UNIT/TensorAllocator.cpp index d84bcd4a20..0aab9ef9b5 100644 --- a/tests/validation/NEON/UNIT/TensorAllocator.cpp +++ b/tests/validation/NEON/UNIT/TensorAllocator.cpp @@ -193,7 +193,7 @@ TEST_CASE(ImportMemoryMallocPadded, framework::DatasetMode::ALL) ARM_COMPUTE_ASSERT(tensor.info()->is_resizable()); } -#if !defined(BARE_METAL) +#if !defined(_WIN64) && !defined(BARE_METAL) TEST_CASE(ImportMemoryMappedFile, framework::DatasetMode::ALL) { const ActivationLayerInfo act_info(ActivationLayerInfo::ActivationFunction::RELU); @@ -250,7 +250,7 @@ TEST_CASE(ImportMemoryMappedFile, framework::DatasetMode::ALL) tensor.allocator()->free(); ARM_COMPUTE_ASSERT(tensor.info()->is_resizable()); } -#endif // !defined(BARE_METAL) +#endif // !defined(_WIN64) && !defined(BARE_METAL) TEST_CASE(AlignedAlloc, framework::DatasetMode::ALL) { diff --git a/tests/validation/UNIT/GPUTarget.cpp b/tests/validation/UNIT/GPUTarget.cpp index b5eccf6239..5ec2592f00 100644 --- a/tests/validation/UNIT/GPUTarget.cpp +++ b/tests/validation/UNIT/GPUTarget.cpp @@ -37,6 +37,7 @@ TEST_SUITE(GPUTarget) TEST_CASE(GetGPUTargetFromName, framework::DatasetMode::ALL) { + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-T000") == GPUTarget::MIDGARD, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-T600") == GPUTarget::T600, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-T700") == GPUTarget::T700, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-T800") == GPUTarget::T800, framework::LogLevel::ERRORS); @@ -45,15 +46,22 @@ TEST_CASE(GetGPUTargetFromName, framework::DatasetMode::ALL) 
ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G51") == GPUTarget::G51, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G51BIG") == GPUTarget::G51BIG, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G51LIT") == GPUTarget::G51LIT, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G52") == GPUTarget::G52, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G52LIT") == GPUTarget::G52LIT, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G31") == GPUTarget::G31, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G76") == GPUTarget::G76, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G76 r0p0") == GPUTarget::G76, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G52") == GPUTarget::G52, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G52LIT") == GPUTarget::G52LIT, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G77") == GPUTarget::G77, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G57") == GPUTarget::G57, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G78") == GPUTarget::G78, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G78AE") == GPUTarget::G78, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G68") == GPUTarget::G68, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G78AE") == GPUTarget::G78AE, framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G710") == GPUTarget::G710, framework::LogLevel::ERRORS); - ARM_COMPUTE_EXPECT(get_target_from_name("Mali-T000") == GPUTarget::MIDGARD, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G610") == GPUTarget::G610, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G510") == GPUTarget::G510, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G310") == GPUTarget::G310, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G715") == GPUTarget::G715, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(get_target_from_name("Mali-G615") == GPUTarget::G615, framework::LogLevel::ERRORS); } TEST_CASE(GPUTargetIsIn, framework::DatasetMode::ALL) diff --git a/tests/validation/fixtures/ArgMinMaxFixture.h b/tests/validation/fixtures/ArgMinMaxFixture.h index caa6bb8d9c..2bbce4077e 100644 --- a/tests/validation/fixtures/ArgMinMaxFixture.h +++ b/tests/validation/fixtures/ArgMinMaxFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -80,7 +80,7 @@ class ArgMinMaxValidationBaseFixture : public framework::Fixture case DataType::QASYMM8: { std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); break; @@ -88,7 +88,7 @@ class ArgMinMaxValidationBaseFixture : public framework::Fixture case DataType::QASYMM8_SIGNED: { std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); break; diff --git a/tests/validation/fixtures/ConvertFullyConnectedWeightsFixture.h b/tests/validation/fixtures/ConvertFullyConnectedWeightsFixture.h index ae844332c3..38088b4000 100644 --- a/tests/validation/fixtures/ConvertFullyConnectedWeightsFixture.h +++ b/tests/validation/fixtures/ConvertFullyConnectedWeightsFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2021 Arm Limited. + * Copyright (c) 2018-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,7 +61,7 @@ class ConvertFullyConnectedWeightsValidationFixture : public framework::Fixture { case DataType::QASYMM8: { - std::uniform_int_distribution distribution(0, 10); + std::uniform_int_distribution distribution(0, 10); library->fill(tensor, distribution, i); break; } diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index 0b3f070e9c..63e6dc9377 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -28,6 +28,7 @@ #include "arm_compute/core/Types.h" #include "arm_compute/graph/Utils.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "src/core/NEON/kernels/arm_gemm/utils.hpp" #include "src/graph/mutators/MutatorUtils.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" @@ -121,15 +122,15 @@ class ConvolutionValidationGenericFixture : public framework::Fixture { case DataType::QASYMM8: { - std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: { - std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } @@ -149,7 +150,7 @@ class ConvolutionValidationGenericFixture : public framework::Fixture max_bound = bounds.second; } } - std::uniform_int_distribution distribution(min_bound, max_bound); + std::uniform_int_distribution distribution(min_bound, max_bound); library->fill(tensor, distribution, i); break; } @@ -397,6 +398,297 @@ class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValida quantization_info, QuantizationInfo(weights_scales), act_info); } }; + +#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS +inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_compute::WeightFormat weight_format) +{ + const DataLayout data_layout = tensor_info.data_layout(); + ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS); + const DataType data_type = tensor_info.data_type(); + const TensorShape tensor_shape = tensor_info.tensor_shape(); + const int N = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O + const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I + + const int interleave_by = arm_compute::interleave_by(weight_format); + const int block_by = arm_compute::block_by(weight_format); + const int Ip = arm_gemm::roundup(C, block_by); // C'=I' + const int Op = arm_gemm::roundup(N, interleave_by); // O'=N' + + const TensorShape TS(Ip, W, H, Op); + return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout); +} + +template +inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_compute::WeightFormat weight_format) +{ + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(weight_format), framework::LogLevel::ERRORS); + // Data Layout: OHWIoi + const int interleave_by = arm_compute::interleave_by(weight_format); + const int block_by = arm_compute::block_by(weight_format); + const TensorShape src_tensor_shape = src.shape(); + const DataLayout data_layout = src.data_layout(); + ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS); + const unsigned int O = 
src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O + const unsigned int H = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const unsigned int W = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + const unsigned int I = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I + const unsigned int Ip = arm_gemm::roundup(I, block_by); // C'=I' + const unsigned int Op = arm_gemm::roundup(O, interleave_by); // N'=O' + + ARM_COMPUTE_EXPECT_EQUAL(Op * H * W * Ip, (unsigned)dst.num_elements(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS); + + const ScalarType *src_ptr = reinterpret_cast(src.data()); + ScalarType *dst_ptr = reinterpret_cast(dst.data()); + for(unsigned i = 0; i < I; ++i) + for(unsigned w = 0; w < W; ++w) + for(unsigned h = 0; h < H; ++h) + for(unsigned o = 0; o < O; ++o) + { + ScalarType src_element; + switch(data_layout) + { + case DataLayout::NHWC: + { + src_element = src_ptr[o * H * W * I + h * W * I + w * I + i]; + } + break; + default: + { + ARM_COMPUTE_ERROR("Unsupported memory layout."); + } + } + const int x5 = std::floor(((float)o) / interleave_by); + const int x4 = h; + const int x3 = w; + const int x2 = std::floor((float)i / block_by); + const int x1 = o % interleave_by; + const int x0 = i % block_by; + unsigned dst_idx = x5 * H * W * Ip * interleave_by + + x4 * W * Ip * interleave_by + + x3 * Ip * interleave_by + + x2 * interleave_by * block_by + + x1 * block_by + + x0; + dst_ptr[dst_idx] = src_element; + } +} + +template +class VariableWeightsFixtureBaseClass : public framework::Fixture +{ +public: + template + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataLayout data_layout, + const DataType data_type) + { + conv = std::make_unique(); + // prepare data + _data_layout = data_layout; + // Fixed format kernels for variable weights can work only with NHWC format. 
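// [Editor's sketch, not part of the original patch] rearrange_data() above packs
// NHWC-ordered weights into the blocked OHWIoi fixed format. With
// Ip = roundup(I, block_by), the destination offset of source element (o, h, w, i)
// decomposes exactly as in the loop above:
//
//   dst_idx = (o / interleave_by) * H * W * Ip * interleave_by
//           + h * W * Ip * interleave_by
//           + w * Ip * interleave_by
//           + (i / block_by) * interleave_by * block_by
//           + (o % interleave_by) * block_by
//           + (i % block_by);
//
// e.g. WeightFormat::OHWIo8i2 (interleave_by = 8, block_by = 2) keeps an 8x2 tile
// of (output, input) channel pairs contiguous, matching the blocked access
// pattern of the fixed-format kernels.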
+ ARM_COMPUTE_EXPECT_EQUAL(_data_layout, DataLayout::NHWC, framework::LogLevel::ERRORS); + _data_type = data_type; + // run the code + compute_target(input_shape, weights_shape, bias_shape, output_shape, info, dilation); + compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation); + } + void teardown() + { + _target.allocator()->free(); + } + +protected: + template + void fill(U &&tensor, int i) + { + switch(tensor.data_type()) + { + case DataType::F16: + { + arm_compute::utils::uniform_real_distribution_16bit distribution{ -1.0f, 1.0f }; + library->fill(tensor, distribution, i); + break; + } + case DataType::F32: + { + std::uniform_real_distribution distribution(-1.0f, 1.0f); + library->fill(tensor, distribution, i); + break; + } + default: + library->fill_tensor_uniform(tensor, i); + } + } + +private: + virtual void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info, + const PadStrideInfo &conv_info, + const Size2D &dilation) = 0; + + void compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &conv_info, + const Size2D &dilation) + { + // The dataset is always in NCHW format - we need to make C the + // innermost dimension because the fixed-format kernels work only + // with NHWC layout. + permute(input_shape, PermutationVector(2U, 0U, 1U)); + permute(weights_shape, PermutationVector(2U, 0U, 1U)); + permute(output_shape, PermutationVector(2U, 0U, 1U)); + const auto src_tensor_info = TensorInfo(input_shape, 1, _data_type, _data_layout); + const auto weight_tensor_info = TensorInfo(weights_shape, 1, _data_type, _data_layout); + const auto bias_tensor_info = TensorInfo(bias_shape, 1, _data_type, _data_layout); + auto dst_tensor_info = TensorInfo(output_shape, 1, _data_type, _data_layout); + + const int kernel_height = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT)]; + const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)]; + const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)]; + + const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_compute::WeightFormat::ANY); + const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info, + &bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info)); + // Make sure that the setup finds a fixed-format kernel as requested by the test case.
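// [Editor's illustration; kw/kh/nk are hypothetical shorthand for the kernel
// dimensions computed above] Querying with WeightFormat::ANY asks has_opt_impl()
// to report, through its first out-parameter, the concrete fixed format that the
// selected kernel expects; the caller then rebuilds WeightsInfo around that value
// before configure():
//
//   arm_compute::WeightFormat wf = arm_compute::WeightFormat::UNSPECIFIED;
//   const WeightsInfo query(false, kw, kh, nk, false, arm_compute::WeightFormat::ANY);
//   if(bool(ConvolutionFunction::has_opt_impl(wf, &src, &wei, &bia, &dst, conv_info, query)))
//   {
//       // wf now names a concrete format, e.g. OHWIo8; lay the weight tensor out
//       // accordingly and pass WeightsInfo(false, kw, kh, nk, false, wf) to configure().
//   }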
+ ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(arm_compute::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS); + + const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format); + configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info, + dilation); + } + void compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info, + const Size2D &dilation) + { + ARM_COMPUTE_UNUSED(input_shape, weights_shape, bias_shape, output_shape, info, + dilation); + + // Create reference + SimpleTensor src{ input_shape, _data_type }; + SimpleTensor weights{ weights_shape, _data_type }; + SimpleTensor bias{ bias_shape, _data_type }; + fill(src, 0); + fill(bias, 1); + fill(weights, 3); + _reference = reference::convolution_layer(src, weights, bias, output_shape, info, dilation, 1 /*num_groups*/); + } + DataLayout _data_layout{}; + DataType _data_type{}; + +protected: + std::unique_ptr conv{}; + arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; + TensorClass _target{}; + SimpleTensor _reference{}; +}; + +template +class VariableWeightsFixture : public VariableWeightsFixtureBaseClass +{ + void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info, + const PadStrideInfo &conv_info, + const Size2D &dilation) + { + this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation, ActivationLayerInfo(), enable_fast_math); + + // Allocate input tensors + auto src = create_tensor(src_tensor_info); + auto weights_original = create_tensor(weight_tensor_info); + const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format); + auto weights_transformed = create_tensor(new_tensor_info); + auto bias = create_tensor(bias_tensor_info); + src.allocator()->allocate(); + weights_original.allocator()->allocate(); + weights_transformed.allocator()->allocate(); + bias.allocator()->allocate(); + // Allocate destination tensor + this->_target = create_tensor(dst_tensor_info); + this->_target.allocator()->allocate(); + + // Prepare source and biases that are left unchanged. 
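// [Editor's note] The two runs below are the point of this fixture: the operator
// is configured once, and between runs only the weight tensor contents are
// refilled and re-rearranged into the computed fixed format. In sketch form,
// with hypothetical helpers:
//
//   conv.configure(...);                                       // once, fixed WeightFormat
//   fill_and_rearrange(weights, /*seed*/ 2); conv.run(pack);   // first run
//   fill_and_rearrange(weights, /*seed*/ 3); conv.run(pack);   // second run, no reconfigure
//
// i.e. it checks that a fixed-format kernel reads the weights on every run instead
// of caching a transformed copy at configure() time.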
+ this->fill(AccessorType(src), 0); + this->fill(AccessorType(bias), 1); + + // First run + this->fill(AccessorType(weights_original), 2); + rearrange_data(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format); + ITensorPack run_pack{ { TensorType::ACL_SRC_0, &src }, { TensorType::ACL_SRC_1, &weights_transformed }, { TensorType::ACL_SRC_2, &bias }, { TensorType::ACL_DST, &(this->_target) } }; + this->conv->run(run_pack); + // Second run, with new weights + this->fill(AccessorType(weights_original), 3); + rearrange_data(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format); + this->conv->run(run_pack); + src.allocator()->free(); + weights_original.allocator()->free(); + weights_transformed.allocator()->free(); + bias.allocator()->free(); + } +}; + +template +class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass +{ + void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info, + const PadStrideInfo &conv_info, + const Size2D &dilation) + { + // Allocate input tensors + auto src = create_tensor(src_tensor_info); + auto weights_original = create_tensor(weight_tensor_info); + const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format); + auto weights_transformed = create_tensor(new_tensor_info); + auto bias = create_tensor(bias_tensor_info); + src.allocator()->allocate(); + weights_original.allocator()->allocate(); + weights_transformed.allocator()->allocate(); + bias.allocator()->allocate(); + // Allocate destination tensor + this->_target = create_tensor(dst_tensor_info); + this->_target.allocator()->allocate(); + this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation, ActivationLayerInfo(), enable_fast_math); + // Prepare source and biases that are left unchanged. + this->fill(AccessorType(src), 0); + this->fill(AccessorType(bias), 1); + + // First run + this->fill(AccessorType(weights_original), 2); + rearrange_data(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format); + this->conv->run(); + // Second run, with new weights + this->fill(AccessorType(weights_original), 3); + rearrange_data(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format); + this->conv->run(); + src.allocator()->free(); + weights_original.allocator()->free(); + weights_transformed.allocator()->free(); + bias.allocator()->free(); + } +}; + +template +class HasOptImplFixture : public framework::Fixture +{ +public: + template + void setup(DataType data_type, arm_compute::WeightFormat query_weight_format) + { + auto conv = std::make_unique(); + const auto src_info = TensorInfo(TensorShape(56U, 56U, 64U), 1, data_type, DataLayout::NHWC); + const auto weight_info = TensorInfo(TensorShape(64, 3U, 3U, 64U), 1, enable_fast_math ? 
DataType::BFLOAT16 : data_type, DataLayout::NHWC); + const auto bias_info = TensorInfo(TensorShape(64U), 1, data_type, DataLayout::NHWC); + auto dst_info = TensorInfo(TensorShape(56U, 56U, 64U), 1, data_type, DataLayout::NHWC); + const auto conv_info = PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR); + const WeightsInfo weights_info(false, 3U, 3U, 64U, false, query_weight_format); + _kernel_found = bool(ConvolutionClass::has_opt_impl(_computed_weight_format, &src_info, &weight_info, + &bias_info, &dst_info, conv_info, weights_info, + /*dilation*/ Size2D(1U, 1U), /*act_info*/ ActivationLayerInfo(), enable_fast_math)); + } + +protected: + bool _kernel_found{ false }; + arm_compute::WeightFormat _computed_weight_format{ arm_compute::WeightFormat::UNSPECIFIED }; +}; +#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS + } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/fixtures/DeconvolutionLayerFixture.h b/tests/validation/fixtures/DeconvolutionLayerFixture.h index 14f071eed0..d13eab2f54 100644 --- a/tests/validation/fixtures/DeconvolutionLayerFixture.h +++ b/tests/validation/fixtures/DeconvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -75,14 +75,14 @@ class DeconvolutionLayerFixtureBase : public framework::Fixture case DataType::QASYMM8: { std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: { std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } @@ -102,7 +102,7 @@ class DeconvolutionLayerFixtureBase : public framework::Fixture max_bound = bounds.second; } } - std::uniform_int_distribution distribution(min_bound, max_bound); + std::uniform_int_distribution distribution(min_bound, max_bound); library->fill(tensor, distribution, i); break; } diff --git a/tests/validation/fixtures/DepthConvertLayerFixture.h b/tests/validation/fixtures/DepthConvertLayerFixture.h index 130b583dc1..53d29b44ba 100644 --- a/tests/validation/fixtures/DepthConvertLayerFixture.h +++ b/tests/validation/fixtures/DepthConvertLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -61,7 +61,7 @@ class DepthConvertLayerValidationBaseFixture : public framework::Fixture if(is_data_type_quantized(tensor.data_type())) { std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); } diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h index cecccc87bb..9fd973ad20 100644 --- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. 
+ * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -199,14 +199,14 @@ class DepthwiseConvolutionLayerValidationGenericFixture : public framework::Fixt { case DataType::QASYMM8: { - std::uniform_int_distribution distribution(0, 15); + std::uniform_int_distribution distribution(0, 15); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: { - std::uniform_int_distribution distribution(-10, 10); + std::uniform_int_distribution distribution(-10, 10); library->fill(tensor, distribution, i); break; } diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h index 614aa20753..31186e2b1d 100644 --- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h +++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -112,14 +112,14 @@ class DirectConvolutionValidationGenericFixture : public framework::Fixture { case DataType::QASYMM8: { - std::uniform_int_distribution distribution(0, 50); + std::uniform_int_distribution distribution(0, 50); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: { // Use small input range to avoid all the test results being saturated at the end. - std::uniform_int_distribution distribution(-25, 25); + std::uniform_int_distribution distribution(-25, 25); library->fill(tensor, distribution, i); break; } diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 3048c56f6b..b5efccdf70 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -92,12 +92,12 @@ class FullyConnectedLayerValidationGenericFixture : public framework::Fixture { if(_data_type == DataType::QASYMM8) { - std::uniform_int_distribution distribution(0, 30); + std::uniform_int_distribution distribution(0, 30); library->fill(tensor, distribution, i); } else if(_data_type == DataType::QASYMM8_SIGNED) { - std::uniform_int_distribution distribution(-15, 15); + std::uniform_int_distribution distribution(-15, 15); library->fill(tensor, distribution, i); } else if(_data_type == DataType::S32) @@ -291,7 +291,7 @@ class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture } else if(_data_type == DataType::QASYMM8) { - std::uniform_int_distribution distribution(0, 30); + std::uniform_int_distribution distribution(0, 30); library->fill(tensor, distribution, i); } else if(_data_type == DataType::S32) @@ -336,7 +336,7 @@ class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture } else if(_data_type == DataType::QASYMM8) { - constexpr AbsoluteTolerance tolerance_qasymm8(1); + constexpr AbsoluteTolerance tolerance_qasymm8(1); validate(AccessorType(target), ref, tolerance_qasymm8); } else diff --git a/tests/validation/fixtures/GEMMFixture.h b/tests/validation/fixtures/GEMMFixture.h index 884b13da80..55bbbdaf80 100644 --- a/tests/validation/fixtures/GEMMFixture.h +++ b/tests/validation/fixtures/GEMMFixture.h @@ -163,18 +163,18 @@ class GEMMValidationFixture : public framework::Fixture const int m = reinterpret_output_as_3d ? 
output_shape[1] * output_shape[2] : output_shape[1]; const int batch_size = reinterpret_output_as_3d ? output_shape[3] : output_shape[2]; - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(c.data() + i * n, c.data(), n * sizeof(T)); } } - + /* Note: Assuming the usual batch matmul dimensions A = (B x M x K), B = (B x K x N), if pretranspose_A is set to true, then A is assumed to be (B x K x M), therefore, A must be pre-transposed before passing it to the fixture. And, we transpose A again in the fixture to make it (B x M x K) in order to be able to call reference implementation that works with (B x M x K) input. Similarly, if pretranspose_B is set to true, then B is assumed to be (B x N x K), B must be pre-transposed before passing it to the fixture. */ - + // Define transposed shapes TensorShape a_transposed_shape(a.shape().y(), a.shape().x()); TensorShape b_transposed_shape(b.shape().y(), b.shape().x()); @@ -315,7 +315,7 @@ class GEMMMatrixMultiplyValidationFixture : public framework::Fixture if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -438,7 +438,7 @@ class GEMMMatrixMultiply3DValidationFixture : public framework::Fixture fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -593,7 +593,7 @@ class GEMMMatrixMultiplyInterleavedTransposedValidationFixture : public framewor if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -748,7 +748,7 @@ class GEMMMatrixMultiplyInterleavedTransposed3DValidationFixture : public framew fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -923,7 +923,7 @@ class GEMMMatrixMultiplyReshapedValidationFixture : public framework::Fixture if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1169,7 +1169,7 @@ class GEMMMatrixMultiplyReshapedWithPostOpsValidationFixture : public framework: if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1361,7 +1361,7 @@ class GEMMMatrixMultiplyReshaped3DValidationFixture : public framework::Fixture 
fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1533,7 +1533,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework::Fix if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1759,7 +1759,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHSWithPostOpsValidationFixture : public fra if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -1941,7 +1941,7 @@ class GEMMMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::F fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2078,7 +2078,7 @@ class GEMMMatrixMultiplyNativeValidationFixture : public framework::Fixture if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2274,7 +2274,7 @@ class GEMMMatrixMultiplyNativeWithPostOpsValidationFixture : public framework::F if(broadcast_bias) { - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2421,7 +2421,7 @@ class GEMMMatrixMultiplyNative3DValidationFixture : public framework::Fixture fill(rhs, 1); fill(bias, 2); - // In case of broadcast, we need simply copy the first into the following "M" ones + // In case of broadcast, we need to simply copy the first into the following "M" ones for(int i = 1; i < m * batch_size; i++) { memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); @@ -2434,6 +2434,171 @@ class GEMMMatrixMultiplyNative3DValidationFixture : public framework::Fixture SimpleTensor _reference{}; }; +template +class GEMMMatrixMultiplyReshapedOnlyRhsMMULValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, bool export_to_cl_image, DataType data_type, float alpha, + float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = m0; + lhs_info.k0 = k0; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = n0; + rhs_info.k0 = k0; + rhs_info.interleave = true; + rhs_info.transpose = false; + rhs_info.h0 = 4; + rhs_info.export_to_cl_image = export_to_cl_image; + + // Set the tensor shapes for LHS and RHS matrices + const 
TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, + broadcast_bias ? 1 : m, + broadcast_bias ? 1 : batch_size); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, alpha, beta, broadcast_bias, act_info); + _reference = compute_reference(lhs_shape, rhs_shape, data_type, alpha, beta, broadcast_bias, act_info); + } + +protected: + template + void fill(U &&tensor, int i) + { + static_assert(std::is_floating_point::value || std::is_same::value, "Only floating point data types supported."); + using DistributionType = typename std::conditional::value, arm_compute::utils::uniform_real_distribution_16bit, std::uniform_real_distribution>::type; + + DistributionType distribution{ T(-1.0f), T(1.0f) }; + library->fill(tensor, distribution, i); + + // Fill border with infinity in order to check the presence of NaN values (i.e. inf * 0) + DistributionType distribution_inf{ T(std::numeric_limits::infinity()), T(std::numeric_limits::infinity()) }; + library->fill_borders_with_garbage(tensor, distribution_inf, i); + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, + DataType data_type, float alpha, float beta, bool broadcast_bias, const ActivationLayerInfo &act_info) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType bias = create_tensor(bias_shape, data_type, 1); + TensorType rhs_reshaped; + TensorType dst; + + const unsigned int M = lhs_shape[1]; + const unsigned int N = rhs_shape[0]; + const unsigned int K = lhs_shape[0]; + GEMMKernelInfo kernel_info; + kernel_info.m = M; + kernel_info.n = N; + kernel_info.k = K; + kernel_info.depth_output_gemm3d = 0; + kernel_info.reinterpret_input_as_3d = false; + kernel_info.broadcast_bias = broadcast_bias; + kernel_info.activation_info = act_info; + + // Create and configure function + ReshapeRHSOperatorType reshape_rhs; + GEMMOperatorType gemm; + + validate_result = bool(reshape_rhs.validate(rhs.info(), rhs_reshaped.info(), rhs_info)); + if(!validate_result) + { + return nullptr; + } + + reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); + + validate_result = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info)); + if(!validate_result) + { + return nullptr; + } + + gemm.configure(lhs.info(), rhs_reshaped.info(), bias.info(), dst.info(), alpha, beta, lhs_info, rhs_info, kernel_info); + + ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(bias.info()->is_resizable()); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + // Compute GEMM + ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; + 
reshape_rhs.run(reshape_rhs_pack); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, + { ACL_SRC_1, &rhs_reshaped }, + { ACL_SRC_2, &bias }, + { ACL_DST, &dst } + }); + gemm.run(gemm_pack); + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type, float alpha, float beta, bool broadcast_bias, + const ActivationLayerInfo &act_info) + { + if(!validate_result) + return SimpleTensor(); + + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor bias{ dst_shape, data_type, 1 }; + + const int n = rhs_shape[0]; + const int m = lhs_shape[1]; + const int batch_size = lhs_shape[2]; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + if(broadcast_bias) + { + // In case of broadcast, we need to simply copy the first into the following "M" ones + for(int i = 1; i < m * batch_size; i++) + { + memcpy(bias.data() + i * n, bias.data(), n * sizeof(T)); + } + } + + return reference::activation_layer(reference::gemm(lhs, rhs, bias, alpha, beta), act_info); + } + + bool validate_result = true; + TensorType _target{}; + SimpleTensor _reference{}; +}; + } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h index 3da4c02f6d..6d073cd361 100644 --- a/tests/validation/fixtures/GEMMLowpFixture.h +++ b/tests/validation/fixtures/GEMMLowpFixture.h @@ -24,19 +24,10 @@ #ifndef ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE #define ARM_COMPUTE_TEST_GEMMLOWP_FIXTURE -#include "arm_compute/core/KernelDescriptors.h" -#include "arm_compute/core/TensorShape.h" -#include "arm_compute/core/Types.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" -#include "tests/AssetsLibrary.h" -#include "tests/Globals.h" -#include "tests/IAccessor.h" -#include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" -#include "tests/validation/Helpers.h" #include "tests/validation/reference/GEMMLowp.h" - -#include +#include "tests/validation/Validation.h" namespace arm_compute { @@ -67,13 +58,13 @@ void fill(U &&tensor, int i) max_bound = bounds.second; } } - std::uniform_int_distribution distribution(min_bound, max_bound); + std::uniform_int_distribution distribution(min_bound, max_bound); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8: { - std::uniform_int_distribution distribution(1, 254); + std::uniform_int_distribution distribution(1, 254); library->fill(tensor, distribution, i); break; } @@ -1362,6 +1353,370 @@ class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework: SimpleTensor _reference{}; }; +template +class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULOutputStageValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, + unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, bool broadcast_bias, DataType data_type) + { + GEMMLowpOutputStageInfo output_stage; + output_stage.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT; + output_stage.output_data_type = data_type; + output_stage.gemmlowp_multipliers = std::vector { 1 }; + output_stage.gemmlowp_shifts = std::vector { 1 }; + output_stage.gemmlowp_multipliers[0] = 1; + 
output_stage.gemmlowp_shifts[0] = 1; + output_stage.gemmlowp_offset = 0; + constexpr float scale = 0.001f; + quantization::calculate_quantized_multiplier(scale, &output_stage.gemmlowp_multipliers[0], &output_stage.gemmlowp_shifts[0]); + output_stage.gemmlowp_min_bound = -100; + output_stage.gemmlowp_max_bound = 100; + + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = m0; + lhs_info.k0 = k0; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = n0; + rhs_info.k0 = k0; + rhs_info.h0 = h0; + rhs_info.interleave = interleave_rhs; + rhs_info.transpose = transpose_rhs; + + int a_offset = 1; + int b_offset = 1; + + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + const TensorShape bias_shape(n, + broadcast_bias ? 1 : m, + broadcast_bias ? 1 : batch_size); + + _target = compute_target(lhs_shape, rhs_shape, bias_shape, lhs_info, rhs_info, data_type, output_stage, a_offset, b_offset); + if(gemm_validated == true) + { + _reference = compute_reference(lhs_shape, rhs_shape, bias_shape, data_type, output_stage, a_offset, b_offset); + } + } + +protected: + template + void fill(U &&tensor, int i) + { + switch(tensor.data_type()) + { + case DataType::QASYMM8: + { + // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path + std::uniform_int_distribution<> distribution(1, 254); + library->fill(tensor, distribution, i); + } + break; + case DataType::QASYMM8_SIGNED: + { + std::uniform_int_distribution<> distribution(-127, 126); + library->fill(tensor, distribution, i); + } + break; + case DataType::S32: + { + std::uniform_int_distribution<> distribution(-10000, 10000); + library->fill(tensor, distribution, i); + } + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type"); + } + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, DataType data_type, GEMMLowpOutputStageInfo output_stage, const int a_offset, const int b_offset) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset)); + TensorType rhs = create_tensor(rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset)); + TensorType bias = create_tensor(bias_shape, DataType::S32, 1); + TensorType dst; + TensorType rhs_reshaped; + + const unsigned int M = lhs_shape[1]; + const unsigned int N = rhs_shape[0]; + const unsigned int K = lhs_shape[0]; + + // Tensors for precomputing sum of lhs rows / rhs columns + TensorType vec_sum_rows = create_tensor(TensorShape(M, 1, lhs_shape[2]), DataType::S32, 1); + TensorType vec_sum_cols = create_tensor(TensorShape(N, 1, rhs_shape[2]), DataType::S32, 1); + + GEMMKernelInfo gemm_info; + gemm_info.m = M; + gemm_info.n = N; + gemm_info.k = K; + gemm_info.lhs_info = lhs_info; + gemm_info.rhs_info = rhs_info; + gemm_info.output_stage = output_stage; + gemm_info.a_offset = a_offset; + gemm_info.b_offset = b_offset; + // The output tensor will be auto-initialized within the function + + // Create and configure function + ReshapeRHSOperatorType reshape_rhs; + GEMMFunctionType gemm; + reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); + + // If GEMM is not validated, do not try to run. The validation will check + // if the technology supports this extension. If not, the test will be skipped. 
+ // If it supports, the test will fail anyway because target and reference + // will not match. + gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info())); + if(gemm_validated == true) + { + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, vec_sum_cols.info(), vec_sum_rows.info(), bias.info()); + + ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(bias.info()->is_resizable()); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + bias.allocator()->allocate(); + vec_sum_cols.allocator()->allocate(); + vec_sum_rows.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!vec_sum_cols.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!vec_sum_rows.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + fill(AccessorType(bias), 2); + + TensorType lhs_32 = create_tensor(lhs_shape, DataType::S32, 1); + TensorType rhs_32 = create_tensor(rhs_shape, DataType::S32, 1); + CastOperation cast_lhs; + CastOperation cast_rhs; + cast_lhs.configure(&lhs, &lhs_32, ConvertPolicy::SATURATE); + cast_rhs.configure(&rhs, &rhs_32, ConvertPolicy::SATURATE); + lhs_32.allocator()->allocate(); + rhs_32.allocator()->allocate(); + cast_lhs.run(); + cast_rhs.run(); + + ReduceOperation lhs_sum_rows; + ReduceOperation rhs_sum_cols; + + lhs_sum_rows.configure(&lhs_32, &vec_sum_rows, 0, ReductionOperation::SUM, false); + rhs_sum_cols.configure(&rhs_32, &vec_sum_cols, 1, ReductionOperation::SUM); + + lhs_sum_rows.run(); + rhs_sum_cols.run(); + + // Compute GEMM + ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; + reshape_rhs.run(reshape_rhs_pack); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_SRC_2, &bias }, { ACL_DST, &dst }, { ACL_VEC_COL_SUM, &vec_sum_cols }, { ACL_VEC_ROW_SUM, &vec_sum_rows } }); + gemm.run(gemm_pack); + } + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const TensorShape &bias_shape, DataType data_type, GEMMLowpOutputStageInfo output_stage, + const int a_offset, const int b_offset) + { + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, a_offset) }; + SimpleTensor rhs{ rhs_shape, data_type, 1, QuantizationInfo(1.0f / 255, b_offset) }; + SimpleTensor bias{ bias_shape, DataType::S32, 1 }; + SimpleTensor dst{ dst_shape, DataType::S32, 1 }; + SimpleTensor dst_final{ dst_shape, data_type, 1 }; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + fill(bias, 2); + + dst = reference::gemmlowp_matrix_multiply_core(lhs, rhs, dst_shape, a_offset, b_offset); + dst_final = reference::gemmlowp_quantize_down_scale_by_fixedpoint(dst, bias, + output_stage.gemmlowp_multipliers, output_stage.gemmlowp_shifts, output_stage.gemmlowp_offset, output_stage.gemmlowp_min_bound, output_stage.gemmlowp_max_bound); + return dst_final; + } + + bool gemm_validated = true; + 
TensorType _target{}; + SimpleTensor _reference{}; +}; + +template +class GEMMLowpMatrixMultiplyReshapedOnlyRHSMMULValidationFixture : public framework::Fixture +{ +public: + template + void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, + unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type) + { + GEMMLHSMatrixInfo lhs_info; + lhs_info.m0 = m0; + lhs_info.k0 = k0; + + GEMMRHSMatrixInfo rhs_info; + rhs_info.n0 = n0; + rhs_info.k0 = k0; + rhs_info.h0 = h0; + rhs_info.interleave = interleave_rhs; + rhs_info.transpose = transpose_rhs; + + // Set the tensor shapes for LHS and RHS matrices + const TensorShape lhs_shape(k, m, batch_size); + const TensorShape rhs_shape(n, k, batch_size); + + _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type); + if(gemm_validated == true) + { + _reference = compute_reference(lhs_shape, rhs_shape, data_type); + } + } + +protected: + template + void fill(U &&tensor, int i) + { + switch(tensor.data_type()) + { + case DataType::QASYMM8: + { + // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path + std::uniform_int_distribution<> distribution(1, 254); + library->fill(tensor, distribution, i); + } + break; + case DataType::QASYMM8_SIGNED: + { + std::uniform_int_distribution<> distribution(-127, 126); + library->fill(tensor, distribution, i); + } + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type"); + } + } + + TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, + const GEMMRHSMatrixInfo &rhs_info, DataType data_type) + { + // Create tensors + TensorType lhs = create_tensor(lhs_shape, data_type, 1); + TensorType rhs = create_tensor(rhs_shape, data_type, 1); + TensorType rhs_reshaped; + TensorType dst; + + const unsigned int M = lhs_shape[1]; + const unsigned int N = rhs_shape[0]; + const unsigned int K = lhs_shape[0]; + + GEMMKernelInfo gemm_info; + gemm_info.m = M; + gemm_info.n = N; + gemm_info.k = K; + gemm_info.lhs_info = lhs_info; + gemm_info.rhs_info = rhs_info; + // The output tensor will be auto-initialized within the function + + // Create and configure function + ReshapeRHSOperatorType reshape_rhs; + GEMMFunctionType gemm; + reshape_rhs.configure(rhs.info(), rhs_reshaped.info(), rhs_info); + + // If GEMM is not validated, do not try to run. The validation will check + // if the technology supports this extension. If not, the test will be skipped. + // If it supports, the test will fail anyway because target and reference + // will not match. 
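// [Editor's note] arm_compute::Status converts explicitly to bool (true on OK),
// which is what turns validate() into this skip-or-run gate; the same idiom in
// isolation, with hypothetical names:
//
//   const arm_compute::Status status = GemmFn::validate(lhs.info(), rhs.info(), dst.info(), info);
//   if(bool(status)) { /* configure + run */ }
//   else             { /* device lacks the MMUL extension: skip the test */ }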
+ gemm_validated = bool(gemm.validate(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr)); + if(gemm_validated == true) + { + gemm.configure(lhs.info(), rhs_reshaped.info(), dst.info(), gemm_info, nullptr, nullptr, nullptr); + + ARM_COMPUTE_ASSERT(lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(rhs.info()->is_resizable()); + + // Allocate tensors + lhs.allocator()->allocate(); + rhs.allocator()->allocate(); + rhs_reshaped.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!lhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!rhs_reshaped.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + + // Fill tensors + fill(AccessorType(lhs), 0); + fill(AccessorType(rhs), 1); + + // Compute GEMM + ITensorPack reshape_rhs_pack = { { ACL_SRC, &rhs }, { ACL_DST, &rhs_reshaped } }; + reshape_rhs.run(reshape_rhs_pack); + ITensorPack gemm_pack({ { ACL_SRC_0, &lhs }, { ACL_SRC_1, &rhs_reshaped }, { ACL_DST, &dst } }); + gemm.run(gemm_pack); + } + + return dst; + } + + SimpleTensor compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type) + { + TensorShape dst_shape = lhs_shape; + dst_shape[0] = rhs_shape[0]; + dst_shape[1] = lhs_shape[1]; + + if(data_type == DataType::QASYMM8) + { + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor dst{ dst_shape, DataType::S32, 1 }; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + + return reference::gemmlowp_matrix_multiply_core(lhs, rhs, dst_shape, 0, 0); + } + else + { + // Create reference + SimpleTensor lhs{ lhs_shape, data_type, 1 }; + SimpleTensor rhs{ rhs_shape, data_type, 1 }; + SimpleTensor dst{ dst_shape, DataType::S32, 1 }; + + // Fill reference + fill(lhs, 0); + fill(rhs, 1); + + return reference::gemmlowp_matrix_multiply_core(lhs, rhs, dst_shape, 0, 0); + } + } + + bool gemm_validated = true; + TensorType _target{}; + SimpleTensor _reference{}; +}; + template class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framework::Fixture { diff --git a/tests/validation/fixtures/ReductionOperationFixture.h b/tests/validation/fixtures/ReductionOperationFixture.h index fc422ad35b..c333f1391f 100644 --- a/tests/validation/fixtures/ReductionOperationFixture.h +++ b/tests/validation/fixtures/ReductionOperationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -76,14 +76,14 @@ class ReductionOperationValidationFixture : public framework::Fixture if(tensor.data_type() == DataType::QASYMM8) { std::pair bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); } else if(tensor.data_type() == DataType::QASYMM8_SIGNED) { std::pair bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); - std::uniform_int_distribution distribution(bounds.first, bounds.second); + std::uniform_int_distribution distribution(bounds.first, bounds.second); library->fill(tensor, distribution, 0); } diff --git a/tests/validation/fixtures/ScaleFixture.h b/tests/validation/fixtures/ScaleFixture.h index b719a22fdf..c0b44bcb5f 100644 --- a/tests/validation/fixtures/ScaleFixture.h +++ b/tests/validation/fixtures/ScaleFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2021 Arm Limited. + * Copyright (c) 2017-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -59,8 +59,8 @@ class ScaleValidationGenericFixture : public framework::Fixture generate_scale(shape); - std::mt19937 generator(library->seed()); - std::uniform_int_distribution distribution_u8(0, 255); + std::mt19937 generator(library->seed()); + std::uniform_int_distribution distribution_u8(0, 255); _constant_border_value = static_cast(distribution_u8(generator)); _target = compute_target(shape, data_layout); diff --git a/tests/validation/reference/Gather.cpp b/tests/validation/reference/Gather.cpp index 93ac09cf95..8de1a473eb 100644 --- a/tests/validation/reference/Gather.cpp +++ b/tests/validation/reference/Gather.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 Arm Limited. + * Copyright (c) 2018-2019, 2022 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -45,22 +45,55 @@ SimpleTensor gather(const SimpleTensor &src, const SimpleTensor Window win; win.use_tensor_dimensions(dst_shape); - execute_window_loop(win, [&](const Coordinates & id) + if(indices.shape().num_dimensions() == 1u) { - Coordinates offset; - for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) + execute_window_loop(win, [&](const Coordinates & id) { - if(dim == actual_axis) + Coordinates offset; + for(unsigned int dim = 0; dim < id.num_dimensions(); ++dim) { - offset.set(dim, indices_ptr[id[dim]]); + if(dim == actual_axis) + { + offset.set(dim, indices_ptr[id[dim]]); + } + else + { + offset.set(dim, id[dim]); + } } - else + *reinterpret_cast(dst(id)) = *reinterpret_cast(src(offset)); + }); + } + else + { + if(actual_axis == 1) + { + win.set(Window::DimX, Window::Dimension(0, 1, 1)); + execute_window_loop(win, [&](const Coordinates & id) { - offset.set(dim, id[dim]); - } + auto *dst_ptr = dst(id); + Coordinates index_offset; + for(uint32_t k = 0; k < indices.shape().num_dimensions(); ++k) + { + index_offset.set(k, id[k + 1]); + } + const uint32_t row = *reinterpret_cast(indices(index_offset)); + Coordinates src_offset; + src_offset.set(0, 0); + src_offset.set(1, row); + for(uint32_t j = 2; j < src.shape().num_dimensions(); ++j) + { + src_offset.set(j, id[1 + indices.shape().num_dimensions() + (j - 2)]); + } + const auto in_ptr_row = src(src_offset); + memcpy(dst_ptr, in_ptr_row, src.shape()[0] * src.element_size()); + }); + } + else + { + ARM_COMPUTE_ERROR("Not implemented."); } - *reinterpret_cast(dst(id)) = *reinterpret_cast(src(offset)); - }); + } return dst; } @@ -72,4 +105,4 @@ template SimpleTensor gather(const SimpleTensor &src, const Si } // namespace reference } // namespace validation } // namespace test -} // namespace arm_compute \ No newline at end of file +} // namespace arm_compute diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index dae81e4a5a..fe7f13a19e 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -472,6 +472,16 @@ inline ::std::ostream &operator<<(::std::ostream &os, const BoundingBoxTransform return os; } +#if defined(ARM_COMPUTE_ENABLE_BF16) +inline ::std::ostream &operator<<(::std::ostream &os, const bfloat16 &v) +{ + std::stringstream str; + str << v; + os << str.str(); + return os; +} +#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */ + /** Formatted output of the BoundingBoxTransformInfo type. * * @param[in] bbox_info Type to output. 
@@ -2310,6 +2320,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const GPUTarget &gpu_targe case GPUTarget::GPU_ARCH_MASK: os << "GPU_ARCH_MASK"; break; + case GPUTarget::GPU_GENERATION_MASK: + os << "GPU_GENERATION_MASK"; + break; case GPUTarget::MIDGARD: os << "MIDGARD"; break; @@ -2343,21 +2356,51 @@ inline ::std::ostream &operator<<(::std::ostream &os, const GPUTarget &gpu_targe case GPUTarget::G51LIT: os << "G51LIT"; break; + case GPUTarget::G31: + os << "G31"; + break; case GPUTarget::G76: os << "G76"; break; + case GPUTarget::G52: + os << "G52"; + break; + case GPUTarget::G52LIT: + os << "G52LIT"; + break; case GPUTarget::G77: os << "G77"; break; + case GPUTarget::G57: + os << "G57"; + break; case GPUTarget::G78: os << "G78"; break; - case GPUTarget::G31: - os << "G31"; + case GPUTarget::G68: + os << "G68"; + break; + case GPUTarget::G78AE: + os << "G78AE"; break; case GPUTarget::G710: os << "G710"; break; + case GPUTarget::G610: + os << "G610"; + break; + case GPUTarget::G510: + os << "G510"; + break; + case GPUTarget::G310: + os << "G310"; + break; + case GPUTarget::G715: + os << "G715"; + break; + case GPUTarget::G615: + os << "G615"; + break; default: ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); } @@ -3236,6 +3279,92 @@ inline std::string to_string(const Conv3dInfo &conv3d_info) return str.str(); } +/** Formatted output of the arm_compute::WeightFormat type. + * + * @param[in] wf arm_compute::WeightFormat Type to output. + * + * @return Formatted string. + */ +inline std::string to_string(const WeightFormat wf) +{ +#define __CASE_WEIGHT_FORMAT(wf) \ +case WeightFormat::wf: \ + return #wf; + switch(wf) + { + __CASE_WEIGHT_FORMAT(UNSPECIFIED) + __CASE_WEIGHT_FORMAT(ANY) + __CASE_WEIGHT_FORMAT(OHWI) + __CASE_WEIGHT_FORMAT(OHWIo2) + __CASE_WEIGHT_FORMAT(OHWIo4) + __CASE_WEIGHT_FORMAT(OHWIo8) + __CASE_WEIGHT_FORMAT(OHWIo16) + __CASE_WEIGHT_FORMAT(OHWIo32) + __CASE_WEIGHT_FORMAT(OHWIo64) + __CASE_WEIGHT_FORMAT(OHWIo128) + __CASE_WEIGHT_FORMAT(OHWIo4i2) + __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i2) + __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i2) + __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i2) + __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i2) + __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo4i4) + __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i4) + __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i4) + __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i4) + __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i4) + __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo2i8) + __CASE_WEIGHT_FORMAT(OHWIo4i8) + __CASE_WEIGHT_FORMAT(OHWIo8i8) + __CASE_WEIGHT_FORMAT(OHWIo16i8) + __CASE_WEIGHT_FORMAT(OHWIo32i8) + __CASE_WEIGHT_FORMAT(OHWIo64i8) + default: + return "invalid value"; + } +#undef __CASE_WEIGHT_FORMAT +} + +/** Formatted output of the arm_compute::WeightFormat type. + * + * @param[out] os Output stream. + * @param[in] wf WeightFormat to output. + * + * @return Modified output stream. + */ +inline ::std::ostream &operator<<(::std::ostream &os, const arm_compute::WeightFormat &wf) +{ + os << to_string(wf); + return os; +} + +/** Formatted output of the std::tuple tuple. + * + * @param[in] values tuple of input and output tensor shapes and WeightFormat used. + * + * @return Formatted string. 
+ */ +inline std::string to_string(const std::tuple values) +{ + std::stringstream str; + str << "[Input shape = " << std::get<0>(values); + str << ", "; + str << "Expected output shape = " << std::get<1>(values); + + str << ", "; + str << "WeightFormat = " << std::get<2>(values) << "]"; + return str.str(); +} + } // namespace arm_compute #endif /* __ARM_COMPUTE_TYPE_PRINTER_H__ */ diff --git a/utils/command_line/CommandLineParser.h b/utils/command_line/CommandLineParser.h index e8fabc4251..523f25e8a1 100644 --- a/utils/command_line/CommandLineParser.h +++ b/utils/command_line/CommandLineParser.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2020, 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "Option.h" #include "arm_compute/core/utils/misc/Utility.h" +#include #include #include #include
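For reference, the to_string(arm_compute::WeightFormat) printer added to utils/TypePrinter.h above keeps its 30-odd cases compact with a local stringify macro. Below is a minimal, self-contained sketch of the same pattern; the enum here is an illustrative subset, not the library's full WeightFormat list:

    #include <iostream>
    #include <string>

    enum class WeightFormat { UNSPECIFIED, ANY, OHWI, OHWIo4, OHWIo8 };

    inline std::string to_string(WeightFormat wf)
    {
    // Expands each case to `case WeightFormat::name: return "name";`
    #define CASE_WF(name)        \
        case WeightFormat::name: \
            return #name;
        switch(wf)
        {
            CASE_WF(UNSPECIFIED)
            CASE_WF(ANY)
            CASE_WF(OHWI)
            CASE_WF(OHWIo4)
            CASE_WF(OHWIo8)
        }
    #undef CASE_WF
        return "invalid value";
    }

    int main()
    {
        std::cout << to_string(WeightFormat::OHWIo8) << '\n'; // prints "OHWIo8"
    }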
[Fragment of the tested-toolchain table from the documentation diff; recoverable entries: armv8.2-a and armv8.2-a-sve built with gcc-arm-10.2-2020.11-x86_64-aarch64-none-linux-gnu, and the Android row updated from armv7a / NDK r18b to armv8a / NDK r18b.]