diff --git a/CMakeLists.txt b/CMakeLists.txt index 2752df7e68..a1f89a6682 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -107,8 +107,9 @@ option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) -option(DISABLE_BLIS_ARCH_TYPE "Disable BLIS_ARCH_TYPE functionality" OFF) +option(DISABLE_BLIS_ARCH_TYPE "Disable BLIS_ARCH_TYPE and BLIS_MODEL_TYPE functionality" OFF) option(RENAME_BLIS_ARCH_TYPE "Rename BLIS_ARCH_TYPE env var renamed to supplied value" BLIS_ARCH_TYPE) +option(RENAME_BLIS_MODEL_TYPE "Rename BLIS_MODEL_TYPE env var renamed to supplied value" BLIS_MODEL_TYPE) if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(REF_KERNEL_MIRRORING_PY "${CMAKE_SOURCE_DIR}/build/blis_ref_kernel_mirror.py") @@ -181,9 +182,11 @@ endif () if (ENABLE_JRIR_RR) message("Round robin thread method enabled") set(BLIS_ENABLE_JRIR_RR TRUE) + set(BLIS_ENABLE_JRIR_SLAB FALSE) elseif (ENABLE_JRIR_SLAB) message("SLAB thread method enabled") set(BLIS_ENABLE_JRIR_SLAB TRUE) + set(BLIS_ENABLE_JRIR_RR FALSE) else () message("Unsupported method of thread partitioning in jr and ir loops") endif () @@ -202,18 +205,23 @@ endif () if (ENABLE_BLAS) add_definitions(-DBLIS_ENABLE_BLAS) + set(BLIS_ENABLE_BLAS TRUE) else () add_definitions(-DBLIS_DISABLE_BLAS) + set(BLIS_ENABLE_BLAS FALSE) endif () if (ENABLE_CBLAS) add_definitions(-DBLIS_ENABLE_CBLAS) + set(BLIS_ENABLE_CBLAS TRUE) if (NOT ENABLE_BLAS) # Force BLAS layer when CBLAS is enabled add_definitions(-DBLIS_ENABLE_BLAS) + set(BLIS_ENABLE_BLAS TRUE) endif () else () add_definitions(-DBLIS_DISABLE_CBLAS) + set(BLIS_ENABLE_CBLAS FALSE) endif () if (ENABLE_BLASTEST) @@ -286,8 +294,10 @@ endif() if(DISABLE_BLIS_ARCH_TYPE) set(BLIS_DISABLE_BLIS_ARCH_TYPE TRUE) + set(BLIS_DISABLE_BLIS_MODEL_TYPE TRUE) else() set(BLIS_DISABLE_BLIS_ARCH_TYPE FALSE) + set(BLIS_DISABLE_BLIS_MODEL_TYPE FALSE) endif() if(RENAME_BLIS_ARCH_TYPE) @@ -298,6 +308,30 @@ else() set(rename_blis_arch_type "BLIS_ARCH_TYPE") endif() +if(RENAME_BLIS_MODEL_TYPE) + set(__blis_model_type_name TRUE) + set(rename_blis_model_type "${RENAME_BLIS_MODEL_TYPE}") +else() + set(__blis_model_type_name TRUE) + set(rename_blis_model_type "BLIS_MODEL_TYPE") +endif() + +find_package(Doxygen) +set(W_DIR "${CMAKE_CURRENT_SOURCE_DIR}/docs") +if(NOT (DOXYGEN_FOUND)) + message(STATUS "Doxygen not found please install and try again.") +else() + execute_process(COMMAND doxygen Doxyfile + WORKING_DIRECTORY ${W_DIR} + COMMAND_ECHO STDOUT) +endif() +if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/docs/html/index.html) + message(STATUS "Documentation generated successfully, to view documentation open docs/html/index.html .") +else() + message(STATUS "Document generation failed.") +endif() + +set(CMAKE_BUILD_TYPE ${CMAKE_CONFIGURATION_TYPES}) #print configurations message("---cmake configurations---") @@ -322,8 +356,9 @@ message(BLIS_ENABLE_SANDBOX : ${BLIS_ENABLE_SANDBOX}) message(BLIS_ENABLE_SHARED : ${BLIS_ENABLE_SHARED}) message(DISABLE_BLIS_ARCH_TYPE : ${DISABLE_BLIS_ARCH_TYPE}) message(RENAME_BLIS_ARCH_TYPE : ${RENAME_BLIS_ARCH_TYPE}) +message(RENAME_BLIS_MODEL_TYPE : ${RENAME_BLIS_MODEL_TYPE}) -SET(ENABLE_SIMD_FLAGS "AVX2" CACHE STRING "Set compiler SIMD flags") +SET(ENABLE_SIMD_FLAGS "none" CACHE STRING "Set compiler SIMD flags") SET_PROPERTY(CACHE ENABLE_SIMD_FLAGS PROPERTY STRINGS none SSE2 AVX AVX2) if(${ENABLE_SIMD_FLAGS} MATCHES "AVX2") @@ -334,15 +369,6 @@ elseif(${ENABLE_SIMD_FLAGS} MATCHES "SSE2") add_definitions(/arch:SSE2) endif() -if(${TARGET_ARCH} STREQUAL zen4 OR - ${TARGET_ARCH} STREQUAL amdzen) - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/1/bli_amaxv_zen_int_avx512.c PROPERTIES COMPILE_FLAGS /arch:AVX512) - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_dgemm_skx_asm_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) - set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c PROPERTIES COMPILE_FLAGS /arch:AVX512) -endif() - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP") @@ -588,10 +614,37 @@ set(BLIS_VERSION_STRING ${BLIS_VERSION}) string(TIMESTAMP BUILD_DATE "%Y%m%d") add_definitions(-DBLIS_VERSION_STRING="AOCL-BLIS ${BLIS_VERSION_STRING} Build ${BUILD_DATE}") +# Set object libraries created in kernels directory to be added into BLIS library. +set(OBJECT_LIBRARIES + $ + $ + $ + $ + $ + $ + $ + $ + $ +) +# Ammend the list of object libraries to include zen4 paths as appropriate. +if(${TARGET_ARCH} STREQUAL zen4 OR + ${TARGET_ARCH} STREQUAL amdzen) + set(OBJECT_LIBRARIES ${OBJECT_LIBRARIES} + $ + $ + $ + $ + $ + $ + ) +endif() + if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h - ${headers}) + ${headers} + ${OBJECT_LIBRARIES} + ) if(ENABLE_OPENMP) target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() @@ -601,7 +654,9 @@ endif() if(NOT BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" STATIC ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h - ${headers}) + ${headers} + ${OBJECT_LIBRARIES} + ) if(ENABLE_OPENMP) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() diff --git a/Makefile b/Makefile index 1f86acc7e5..0a1a4646ad 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -273,9 +273,11 @@ BASE_OBJ_CBLAS_PATH := $(BASE_OBJ_FRAME_PATH)/compat/cblas ifeq ($(MK_ENABLE_CBLAS),no) MK_BLIS_OBJS := $(filter-out $(BASE_OBJ_CBLAS_PATH)/%.o, $(MK_BLIS_OBJS) ) endif -ifeq ($(MK_ENABLE_BLAS),no) -MK_BLIS_OBJS := $(filter-out $(BASE_OBJ_BLAS_PATH)/%.o, $(MK_BLIS_OBJS) ) -endif +# Include bla_ files so that we get the *_blis_impl interfaces. Actual BLAS +# interfaces will not be included from these files when MK_ENABLE_BLAS is no. +##ifeq ($(MK_ENABLE_BLAS),no) +##MK_BLIS_OBJS := $(filter-out $(BASE_OBJ_BLAS_PATH)/%.o, $(MK_BLIS_OBJS) ) +##endif diff --git a/README.md b/README.md index 28179306c7..ce923198e4 100644 --- a/README.md +++ b/README.md @@ -1,714 +1,11 @@ -![The BLIS cat is sleeping.](http://www.cs.utexas.edu/users/field/blis_cat.png) +# AOCL-BLAS library -[![Build Status](https://travis-ci.org/flame/blis.svg?branch=master)](https://travis-ci.org/flame/blis) -[![Build Status](https://ci.appveyor.com/api/projects/status/github/flame/blis?branch=master&svg=true)](https://ci.appveyor.com/project/shpc/blis/branch/master) +AOCL-BLAS is AMD's optimized version of BLAS targeted for AMD EPYC and Ryzen CPUs. It is developed as a forked version of BLIS (https://github.com/flame/blis), which is developed by members of the [Science of High-Performance Computing](http://shpc.oden.utexas.edu/) (SHPC) group in the [Institute for Computational Engineering and Sciences](https://www.oden.utexas.edu/) at [The University of Texas at Austin](https://www.utexas.edu/) and other collaborators (including AMD). All known features and functionalities of BLIS are retained and supported in AOCL-BLAS library. AOCL-BLAS is regularly updated with the improvements from the upstream repository. -Contents --------- +AOCL BLAS is optimized with SSE2, AVX2, AVX512 instruction sets which would be enabled based on the target Zen architecture using the dynamic dispatch feature. All prominent Level 3, Level 2 and Level 1 APIs are designed and optimized for specific paths targeting different size spectrums e.g., Small, Medium and Large sizes. These algorithms are designed and customized to exploit the architectural improvements of the target platform. -* **[Introduction](#introduction)** -* **[Education and Learning](#education-and-learning)** -* **[What's New](#whats-new)** -* **[What People Are Saying About BLIS](#what-people-are-saying-about-blis)** -* **[Key Features](#key-features)** -* **[How to Download BLIS](#how-to-download-blis)** -* **[Getting Started](#getting-started)** -* **[Documentation](#documentation)** -* **[External Packages](#external-packages)** -* **[Discussion](#discussion)** -* **[Contributing](#contributing)** -* **[Citations](#citations)** -* **[Funding](#funding)** +For detailed instructions on how to configure, build, install, and link against AOCL-BLAS on AMD CPUs, please refer to the AOCL User Guide located on AMD developer [portal](https://www.amd.com/en/developer/aocl.html). -Introduction ------------- - -BLIS is a portable software framework for instantiating high-performance -BLAS-like dense linear algebra libraries. The framework was designed to isolate -essential kernels of computation that, when optimized, immediately enable -optimized implementations of most of its commonly used and computationally -intensive operations. BLIS is written in [ISO -C99](http://en.wikipedia.org/wiki/C99) and available under a -[new/modified/3-clause BSD -license](http://opensource.org/licenses/BSD-3-Clause). While BLIS exports a -[new BLAS-like API](docs/BLISTypedAPI.md), -it also includes a BLAS compatibility layer which gives application developers -access to BLIS implementations via traditional [BLAS routine -calls](http://www.netlib.org/lapack/lug/node145.html). -An [object-based API](docs/BLISObjectAPI.md) unique to BLIS is also available. - -For a thorough presentation of our framework, please read our -[ACM Transactions on Mathematical Software (TOMS)](https://toms.acm.org/) -journal article, ["BLIS: A Framework for Rapidly Instantiating BLAS -Functionality"](http://dl.acm.org/authorize?N91172). -For those who just want an executive summary, please see the -[Key Features](#key-features) section below. - -In a follow-up article (also in [ACM TOMS](https://toms.acm.org/)), -["The BLIS Framework: Experiments in -Portability"](http://dl.acm.org/authorize?N16240), -we investigate using BLIS to instantiate level-3 BLAS implementations on a -variety of general-purpose, low-power, and multicore architectures. - -An IPDPS'14 conference paper titled ["Anatomy of High-Performance Many-Threaded -Matrix -Multiplication"](http://www.cs.utexas.edu/users/flame/pubs/blis3_ipdps14.pdf) -systematically explores the opportunities for parallelism within the five loops -that BLIS exposes in its matrix multiplication algorithm. - -For other papers related to BLIS, please see the -[Citations section](#citations) below. - -It is our belief that BLIS offers substantial benefits in productivity when -compared to conventional approaches to developing BLAS libraries, as well as a -much-needed refinement of the BLAS interface, and thus constitutes a major -advance in dense linear algebra computation. While BLIS remains a -work-in-progress, we are excited to continue its development and further -cultivate its use within the community. - -The BLIS framework is primarily developed and maintained by individuals in the -[Science of High-Performance Computing](http://shpc.ices.utexas.edu/) -(SHPC) group in the -[Oden Institute for Computational Engineering and Sciences](https://www.oden.utexas.edu/) -at [The University of Texas at Austin](https://www.utexas.edu/). -Please visit the [SHPC](http://shpc.ices.utexas.edu/) website for more -information about our research group, such as a list of -[people](http://shpc.ices.utexas.edu/people.html) -and [collaborators](http://shpc.ices.utexas.edu/collaborators.html), -[funding sources](http://shpc.ices.utexas.edu/funding.html), -[publications](http://shpc.ices.utexas.edu/publications.html), -and [other educational projects](http://www.ulaff.net/) (such as MOOCs). - -Education and Learning ----------------------- - -Want to understand what's under the hood? -Many of the same concepts and principles employed when developing BLIS are -introduced and taught in a basic pedagogical setting as part of -[LAFF-On Programming for High Performance (LAFF-On-PfHP)](http://www.ulaff.net/), -one of several massive open online courses (MOOCs) in the -[Linear Algebra: Foundations to Frontiers](http://www.ulaff.net/) series, -all of which are available for free via the [edX platform](http://www.edx.org/). - -What's New ----------- - - * **Multithreaded small/skinny matrix support for sgemm now available!** Thanks to -funding and hardware support from Oracle, we have now accelerated `gemm` for -single-precision real matrix problems where one or two dimensions is exceedingly -small. This work is similar to the `gemm` optimization announced last year. -For now, we have only gathered performance results on an AMD Epyc Zen2 system, but -we hope to publish additional graphs for other architectures in the future. You may -find these Zen2 graphs via the [PerformanceSmall](docs/PerformanceSmall.md) document. - - * **BLIS awarded SIAM Activity Group on Supercomputing Best Paper Prize for 2020!** -We are thrilled to announce that the paper that we internally refer to as the -second BLIS paper, - - "The BLIS Framework: Experiments in Portability." Field G. Van Zee, Tyler Smith, Bryan Marker, Tze Meng Low, Robert A. van de Geijn, Francisco Igual, Mikhail Smelyanskiy, Xianyi Zhang, Michael Kistler, Vernon Austel, John A. Gunnels, Lee Killough. ACM Transactions on Mathematical Software (TOMS), 42(2):12:1--12:19, 2016. - - was selected for the [SIAM Activity Group on Supercomputing Best Paper Prize](https://www.siam.org/prizes-recognition/activity-group-prizes/detail/siag-sc-best-paper-prize) -for 2020. The prize is awarded once every two years to a paper judged to be -the most outstanding paper in the field of parallel scientific and engineering -computing, and has only been awarded once before (in 2016) since its inception -in 2015 (the committee did not award the prize in 2018). The prize -[was awarded](https://www.oden.utexas.edu/about/news/ScienceHighPerfomanceComputingSIAMBestPaperPrize/) -at the [2020 SIAM Conference on Parallel Processing for Scientific Computing](https://www.siam.org/conferences/cm/conference/pp20) in Seattle. Robert was present at -the conference to give -[a talk on BLIS](https://meetings.siam.org/sess/dsp_programsess.cfm?SESSIONCODE=68266) and accept the prize alongside other coauthors. -The selection committee sought to recognize the paper, "which validates BLIS, -a framework relying on the notion of microkernels that enables both productivity -and high performance." Their statement continues, "The framework will continue -having an important influence on the design and the instantiation of dense linear -algebra libraries." - - * **Multithreaded small/skinny matrix support for dgemm now available!** Thanks to -contributions made possible by our partnership with AMD, we have dramatically -accelerated `gemm` for double-precision real matrix problems where one or two -dimensions is exceedingly small. A natural byproduct of this optimization is -that the traditional case of small _m = n = k_ (i.e. square matrices) is also -accelerated, even though it was not targeted specifically. And though only -`dgemm` was optimized for now, support for other datatypes and/or other operations -may be implemented in the future. We've also added new graphs to the -[PerformanceSmall](docs/PerformanceSmall.md) document to showcase multithreaded -performance when one or more matrix dimensions are small. - - * **Performance comparisons now available!** We recently measured the -performance of various level-3 operations on a variety of hardware architectures, -as implemented within BLIS and other BLAS libraries for all four of the standard -floating-point datatypes. The results speak for themselves! Check out our -extensive performance graphs and background info in our new -[Performance](docs/Performance.md) document. - - * **BLIS is now in Debian Unstable!** Thanks to Debian developer-maintainers -[M. Zhou](https://github.com/cdluminate) and -[Nico Schlömer](https://github.com/nschloe) for sponsoring our package in Debian. -Their participation, contributions, and advocacy were key to getting BLIS into -the second-most popular Linux distribution (behind Ubuntu, which Debian packages -feed into). The Debian tracker page may be found -[here](https://tracker.debian.org/pkg/blis). - - * **BLIS now supports mixed-datatype gemm!** The `gemm` operation may now be -executed on operands of mixed domains and/or mixed precisions. Any combination -of storage datatype for A, B, and C is now supported, along with a separate -computation precision that can differ from the storage precision of A and B. -And even the 1m method now supports mixed-precision computation. -For more details, please see our [ACM TOMS](https://toms.acm.org/) journal -article submission ([current -draft](http://www.cs.utexas.edu/users/flame/pubs/blis7_toms_rev0.pdf)). - - * **BLIS now implements the 1m method.** Let's face it: writing complex -assembly `gemm` microkernels for a new architecture is never a priority--and -now, it almost never needs to be. The 1m method leverages existing real domain -`gemm` microkernels to implement all complex domain level-3 operations. For -more details, please see our [ACM TOMS](https://toms.acm.org/) journal article -submission ([current -draft](http://www.cs.utexas.edu/users/flame/pubs/blis6_toms_rev2.pdf)). - -What People Are Saying About BLIS ---------------------------------- - -*["I noticed a substantial increase in multithreaded performance on my own -machine, which was extremely satisfying."](https://groups.google.com/d/msg/blis-discuss/8iu9B5KCxpA/uftpjgIsBwAJ)* ... *["[I was] happy it worked so well!"](https://groups.google.com/d/msg/blis-discuss/8iu9B5KCxpA/uftpjgIsBwAJ)* (Justin Shea) - -*["This is an awesome library."](https://github.com/flame/blis/issues/288#issuecomment-447488637)* ... *["I want to thank you and the blis team for your efforts."](https://github.com/flame/blis/issues/288#issuecomment-448074704)* ([@Lephar](https://github.com/Lephar)) - -*["Any time somebody outside Intel beats MKL by a nontrivial amount, I report it to the MKL team. It is fantastic for any open-source project to get within 10% of MKL... [T]his is why Intel funds BLIS development."](https://github.com/flame/blis/issues/264#issuecomment-428673275)* ([@jeffhammond](https://github.com/jeffhammond)) - -*["So BLIS is now a part of Elk."](https://github.com/flame/blis/issues/267#issuecomment-429303902)* ... *["We have found that zgemm applied to a 15000x15000 matrix with multi-threaded BLIS on a 32-core Ryzen 2990WX processor is about twice as fast as MKL"](https://github.com/flame/blis/issues/264#issuecomment-428373946)* ... *["I'm starting to like this a lot."](https://github.com/flame/blis/issues/264#issuecomment-428926191)* ([@jdk2016](https://github.com/jdk2016)) - -*["I [found] BLIS because I was looking for BLAS operations on C-ordered arrays for NumPy. BLIS has that, but even better is the fact that it's developed in the open using a more modern language than Fortran."](https://github.com/flame/blis/issues/254#issuecomment-423838345)* ([@nschloe](https://github.com/nschloe)) - -*["The specific reason to have BLIS included [in Linux distributions] is the KNL and SKX [AVX-512] BLAS support, which OpenBLAS doesn't have."](https://github.com/flame/blis/issues/210#issuecomment-393126303)* ([@loveshack](https://github.com/loveshack)) - -*["All tests pass without errors on OpenBSD. Thanks!"](https://github.com/flame/blis/issues/202#issuecomment-389691543)* ([@ararslan](https://github.com/ararslan)) - -*["Thank you very much for your great help!... Looking forward to benchmarking."](https://github.com/flame/blis/issues/180#issuecomment-375895449)* ([@mrader1248](https://github.com/mrader1248)) - -*["Thanks for the beautiful work."](https://github.com/flame/blis/issues/163#issue-286575452)* ([@mmrmo](https://github.com/mmrmo)) - -*["[M]y software currently uses BLIS for its BLAS interface..."](https://github.com/flame/blis/issues/129#issuecomment-302904805)* ([@ShadenSmith](https://github.com/ShadenSmith)) - -*["[T]hanks so much for your work on this! Excited to test."](https://github.com/flame/blis/issues/129#issuecomment-341565071)* ... *["[On AMD Excavator], BLIS is competitive to / slightly faster than OpenBLAS for dgemms in my tests."](https://github.com/flame/blis/issues/129#issuecomment-341608673)* ([@iotamudelta](https://github.com/iotamudelta)) - -*["BLIS provided the only viable option on KNL, whose ecosystem is at present dominated by blackbox toolchains. Thanks again. Keep on this great work."](https://github.com/flame/blis/issues/116#issuecomment-281225101)* ([@heroxbd](https://github.com/heroxbd)) - -*["I want to definitely try this out..."](https://github.com/flame/blis/issues/12#issuecomment-48086295)* ([@ViralBShah](https://github.com/ViralBShah)) - -Key Features ------------- - -BLIS offers several advantages over traditional BLAS libraries: - - * **Portability that doesn't impede high performance.** Portability was a top -priority of ours when creating BLIS. With virtually no additional effort on the -part of the developer, BLIS is configurable as a fully-functional reference -implementation. But more importantly, the framework identifies and isolates a -key set of computational kernels which, when optimized, immediately and -automatically optimize performance across virtually all level-2 and level-3 -BLIS operations. In this way, the framework acts as a productivity multiplier. -And since the optimized (non-portable) code is compartmentalized within these -few kernels, instantiating a high-performance BLIS library on a new -architecture is a relatively straightforward endeavor. - - * **Generalized matrix storage.** The BLIS framework exports interfaces that -allow one to specify both the row stride and column stride of a matrix. This -allows one to compute with matrices stored in column-major order, row-major -order, or by general stride. (This latter storage format is important for those -seeking to implement tensor contractions on multidimensional arrays.) -Furthermore, since BLIS tracks stride information for each matrix, operands of -different storage formats can be used within the same operation invocation. By -contrast, BLAS requires column-major storage. And while the CBLAS interface -supports row-major storage, it does not allow mixing storage formats. - - * **Rich support for the complex domain.** BLIS operations are developed and -expressed in their most general form, which is typically in the complex domain. -These formulations then simplify elegantly down to the real domain, with -conjugations becoming no-ops. Unlike the BLAS, all input operands in BLIS that -allow transposition and conjugate-transposition also support conjugation -(without transposition), which obviates the need for thread-unsafe workarounds. -Also, where applicable, both complex symmetric and complex Hermitian forms are -supported. (BLAS omits some complex symmetric operations, such as `symv`, -`syr`, and `syr2`.) Another great example of BLIS serving as a portability -lever is its implementation of the 1m method for complex matrix multiplication, -a novel mechanism of providing high-performance complex level-3 operations using -only real domain microkernels. This new innovation guarantees automatic level-3 -support in the complex domain even when the kernel developers entirely forgo -writing complex kernels. - - * **Advanced multithreading support.** BLIS allows multiple levels of -symmetric multithreading for nearly all level-3 operations. (Currently, users -may choose to obtain parallelism via either OpenMP or POSIX threads). This -means that matrices may be partitioned in multiple dimensions simultaneously to -attain scalable, high-performance parallelism on multicore and many-core -architectures. The key to this innovation is a thread-specific control tree -infrastructure which encodes information about the logical thread topology and -allows threads to query and communicate data amongst one another. BLIS also -employs so-called "quadratic partitioning" when computing dimension sub-ranges -for each thread, so that arbitrary diagonal offsets of structured matrices with -unreferenced regions are taken into account to achieve proper load balance. -More recently, BLIS introduced a runtime abstraction to specify parallelism on -a per-call basis, which is useful for applications that want to handle most of -the parallelism. - - * **Ease of use.** The BLIS framework, and the library of routines it -generates, are easy to use for end users, experts, and vendors alike. An -optional BLAS compatibility layer provides application developers with -backwards compatibility to existing BLAS-dependent codes. Or, one may adjust or -write their application to take advantage of new BLIS functionality (such as -generalized storage formats or additional complex operations) by calling one -of BLIS's native APIs directly. BLIS's typed API will feel familiar to many -veterans of BLAS since these interfaces use BLAS-like calling sequences. And -many will find BLIS's object-based APIs a delight to use when customizing -or writing their own BLIS operations. (Objects are relatively lightweight -`structs` and passed by address, which helps tame function calling overhead.) - - * **Multilayered API, exposed kernels, and sandboxes.** The BLIS framework -exposes its -implementations in various layers, allowing expert developers to access exactly -the functionality desired. This layered interface includes that of the -lowest-level kernels, for those who wish to bypass the bulk of the framework. -Optimizations can occur at various levels, in part thanks to exposed packing -and unpacking facilities, which by default are highly parameterized and -flexible. And more recently, BLIS introduced sandboxes--a way to provide -alternative implementations of `gemm` that do not use any more of the BLIS -infrastructure than is desired. Sandboxes provide a convenient and -straightforward way of modifying the `gemm` implementation without disrupting -any other level-3 operation or any other part of the framework. This works -especially well when the developer wants to experiment with new optimizations -or try a different algorithm. - - * **Functionality that grows with the community's needs.** As its name -suggests, the BLIS framework is not a single library or static API, but rather -a nearly-complete template for instantiating high-performance BLAS-like -libraries. Furthermore, the framework is extensible, allowing developers to -leverage existing components to support new operations as they are identified. -If such operations require new kernels for optimal efficiency, the framework -and its APIs will be adjusted and extended accordingly. - - * **Code re-use.** Auto-generation approaches to achieving the aforementioned -goals tend to quickly lead to code bloat due to the multiple dimensions of -variation supported: operation (i.e. `gemm`, `herk`, `trmm`, etc.); parameter -case (i.e. side, [conjugate-]transposition, upper/lower storage, unit/non-unit -diagonal); datatype (i.e. single-/double-precision real/complex); matrix -storage (i.e. row-major, column-major, generalized); and algorithm (i.e. -partitioning path and kernel shape). These "brute force" approaches often -consider and optimize each operation or case combination in isolation, which is -less than ideal when the goal is to provide entire libraries. BLIS was designed -to be a complete framework for implementing basic linear algebra operations, -but supporting this vast amount of functionality in a manageable way required a -holistic design that employed careful abstractions, layering, and recycling of -generic (highly parameterized) codes, subject to the constraint that high -performance remain attainable. - - * **A foundation for mixed domain and/or mixed precision operations.** BLIS -was designed with the hope of one day allowing computation on real and complex -operands within the same operation. Similarly, we wanted to allow mixing -operands' numerical domains, floating-point precisions, or both domain and -precision, and to optionally compute in a precision different than one or both -operands' storage precisions. This feature has been implemented for the general -matrix multiplication (`gemm`) operation, providing 128 different possible type -combinations, which, when combined with existing transposition, conjugation, -and storage parameters, enables 55,296 different `gemm` use cases. For more -details, please see the documentation on [mixed datatype](docs/MixedDatatypes.md) -support and/or our [ACM TOMS](https://toms.acm.org/) journal paper on -mixed-domain/mixed-precision `gemm` ([linked below](#citations)). - -How to Download BLIS --------------------- - -There are a few ways to download BLIS. We list the most common four ways below. -We **highly recommend** using either Option 1 or 2. Otherwise, we recommend -Option 3 (over Option 4) so your compiler can perform optimizations specific -to your hardware. - -1. **Download a source repository with `git clone`.** -Generally speaking, we prefer using `git clone` to clone a `git` repository. -Having a repository allows the user to periodically pull in the latest changes -and quickly rebuild BLIS whenever they wish. Also, implicit in cloning a -repository is that the repository defaults to using the `master` branch, which -contains the latest "stable" commits since the most recent release. (This is -in contrast to Option 3 in which the user is opting for code that may be -slightly out of date.) - - In order to clone a `git` repository of BLIS, please obtain a repository -URL by clicking on the green button above the file/directory listing near the -top of this page (as rendered by GitHub). Generally speaking, it will amount -to executing the following command in your terminal shell: - ``` - git clone https://github.com/flame/blis.git - ``` - -2. **Download a source repository via a zip file.** -If you are uncomfortable with using `git` but would still like the latest -stable commits, we recommend that you download BLIS as a zip file. - - In order to download a zip file of the BLIS source distribution, please -click on the green button above the file listing near the top of this page. -This should reveal a link for downloading the zip file. - -3. **Download a source release via a tarball/zip file.** -Alternatively, if you would like to stick to the code that is included in -official releases, you may download either a tarball or zip file of any of -BLIS's previous [tagged releases](https://github.com/flame/blis/releases). -We consider this option to be less than ideal for most people since it will -likely mean you miss out on the latest bugfix or feature commits (in contrast -to Options 1 or 2), and you also will not be able to update your code with a -simple `git pull` command (in contrast to Option 1). - -4. **Download a binary package specific to your OS.** -While we don't recommend this as the first choice for most users, we provide -links to community members who generously maintain BLIS packages for various -Linux distributions such as Debian Unstable and EPEL/Fedora. Please see the -[External Packages](#external-packages) section below for more information. - -Getting Started ---------------- - -*NOTE: This section assumes you've either cloned a BLIS source code repository -via `git`, downloaded the latest source code via a zip file, or downloaded the -source code for a tagged version release---Options 1, 2, or 3, respectively, -as discussed in [the previous section](#how-to-download-blis).* - -If you just want to build a sequential (not parallelized) version of BLIS -in a hurry and come back and explore other topics later, you can configure -and build BLIS as follows: -``` -$ ./configure auto -$ make [-j] -``` -You can then verify your build by running BLAS- and BLIS-specific test -drivers via `make check`: -``` -$ make check [-j] -``` -And if you would like to install BLIS to the directory specified to `configure` -via the `--prefix` option, run the `install` target: -``` -$ make install -``` -Please read the output of `./configure --help` for a full list of configure-time -options. -If/when you have time, we *strongly* encourage you to read the detailed -walkthrough of the build system found in our [Build System](docs/BuildSystem.md) -guide. - -Documentation -------------- - -We provide extensive documentation on the BLIS build system, APIs, test -infrastructure, and other important topics. All documentation is formatted in -markdown and included in the BLIS source distribution (usually in the `docs` -directory). Slightly longer descriptions of each document may be found via in -the project's [wiki](https://github.com/flame/blis/wiki) section. - -**Documents for everyone:** - - * **[Build System](docs/BuildSystem.md).** This document covers the basics of -configuring and building BLIS libraries, as well as related topics. - - * **[Testsuite](docs/Testsuite.md).** This document describes how to run -BLIS's highly parameterized and configurable test suite, as well as the -included BLAS test drivers. - - * **[BLIS Typed API Reference](docs/BLISTypedAPI.md).** Here we document the -so-called "typed" (or BLAS-like) API. This is the API that many users who are -already familiar with the BLAS will likely want to use. You can find lots of -example code for the typed API in the [examples/tapi](examples/tapi) directory -included in the BLIS source distribution. - - * **[BLIS Object API Reference](docs/BLISObjectAPI.md).** Here we document -the object API. This is API abstracts away properties of vectors and matrices -within `obj_t` structs that can be queried with accessor functions. Many -developers and experts prefer this API over the typed API. You can find lots of -example code for the object API in the [examples/oapi](examples/oapi) directory -included in the BLIS source distribution. - - * **[Hardware Support](docs/HardwareSupport.md).** This document maintains a -table of supported microarchitectures. - - * **[Multithreading](docs/Multithreading.md).** This document describes how to -use the multithreading features of BLIS. - - * **[Mixed-Datatypes](docs/MixedDatatypes.md).** This document provides an -overview of BLIS's mixed-datatype functionality and provides a brief example -of how to take advantage of this new code. - - * **[Performance](docs/Performance.md).** This document reports empirically -measured performance of a representative set of level-3 operations on a variety -of hardware architectures, as implemented within BLIS and other BLAS libraries -for all four of the standard floating-point datatypes. - - * **[PerformanceSmall](docs/PerformanceSmall.md).** This document reports -empirically measured performance of `gemm` on select hardware architectures -within BLIS and other BLAS libraries when performing matrix problems where one -or two dimensions is exceedingly small. - - * **[Release Notes](docs/ReleaseNotes.md).** This document tracks a summary of -changes included with each new version of BLIS, along with contributor credits -for key features. - - * **[Frequently Asked Questions](docs/FAQ.md).** If you have general questions -about BLIS, please read this FAQ. If you can't find the answer to your question, -please feel free to join the [blis-devel](https://groups.google.com/group/blis-devel) -mailing list and post a question. We also have a -[blis-discuss](https://groups.google.com/group/blis-discuss) mailing list that -anyone can post to (even without joining). - -**Documents for github contributors:** - - * **[Contributing bug reports, feature requests, PRs, etc](CONTRIBUTING.md).** -Interested in contributing to BLIS? Please read this document before getting -started. It provides a general overview of how best to report bugs, propose new -features, and offer code patches. - - * **[Coding Conventions](docs/CodingConventions.md).** If you are interested or -planning on contributing code to BLIS, please read this document so that you can -format your code in accordance with BLIS's standards. - -**Documents for BLIS developers:** - - * **[Kernels Guide](docs/KernelsHowTo.md).** If you would like to learn more -about the types of kernels that BLIS exposes, their semantics, the operations -that each kernel accelerates, and various implementation issues, please read -this guide. - - * **[Configuration Guide](docs/ConfigurationHowTo.md).** If you would like to -learn how to add new sub-configurations or configuration families, or are simply -interested in learning how BLIS organizes its configurations and kernel sets, -please read this thorough walkthrough of the configuration system. - - * **[Sandbox Guide](docs/Sandboxes.md).** If you are interested in learning -about using sandboxes in BLIS--that is, providing alternative implementations -of the `gemm` operation--please read this document. - -External Packages ------------------ - -Generally speaking, we **highly recommend** building from source whenever -possible using the latest `git` clone. (Tarballs of each -[tagged release](https://github.com/flame/blis/releases) are also available, but -we consider them to be less ideal since they are not as easy to upgrade as -`git` clones.) - -That said, some users may prefer binary and/or source packages through their -Linux distribution. Thanks to generous involvement/contributions from our -community members, the following BLIS packages are now available: - - * **Debian**. [M. Zhou](https://github.com/cdluminate) has volunteered to -sponsor and maintain BLIS packages within the Debian Linux distribution. The -Debian package tracker can be found [here](https://tracker.debian.org/pkg/blis). -(Also, thanks to [Nico Schlömer](https://github.com/nschloe) for previously -volunteering his time to set up a standalone PPA.) - - * **Gentoo**. [M. Zhou](https://github.com/cdluminate) also maintains the -[BLIS package](https://packages.gentoo.org/packages/sci-libs/blis) entry for -[Gentoo](https://www.gentoo.org/), a Linux distribution known for its -source-based [portage](https://wiki.gentoo.org/wiki/Portage) package manager -and distribution system. - - * **EPEL/Fedora**. There are official BLIS packages in Fedora and EPEL (for -RHEL7+ and compatible distributions) with versions for 64-bit integers, OpenMP, -and pthreads, and shims which can be dynamically linked instead of reference -BLAS. (NOTE: For architectures other than intel64, amd64, and maybe arm64, the -performance of packaged BLIS will be low because it uses unoptimized generic -kernels; for those architectures, [OpenBLAS](https://github.com/xianyi/OpenBLAS) -may be a better solution.) [Dave -Love](https://github.com/loveshack) provides additional packages for EPEL6 in a -[Fedora Copr](https://copr.fedorainfracloud.org/coprs/loveshack/blis/), and -possibly versions more recent than the official repo for other EPEL/Fedora -releases. The source packages may build on other rpm-based distributions. - - * **OpenSuSE**. The copr referred to above has rpms for some OpenSuSE releases; -the source rpms may build for others. - - * **GNU Guix**. Guix has BLIS packages, provides builds only for the generic -target and some specific x86_64 micro-architectures. - - * **Conda**. conda channel [conda-forge](https://github.com/conda-forge/blis-feedstock) -has Linux, OSX and Windows binary packages for x86_64. - -Discussion ----------- - -You can keep in touch with developers and other users of the project by joining -one of the following mailing lists: - - * [blis-devel](https://groups.google.com/group/blis-devel): Please join and -post to this mailing list if you are a BLIS developer, or if you are trying -to use BLIS beyond simply linking to it as a BLAS library. -**Note:** Most of the interesting discussions happen here; don't be afraid to -join! If you would like to submit a bug report, or discuss a possible bug, -please consider opening a [new issue](https://github.com/flame/blis/issues) on -github. - - * [blis-discuss](https://groups.google.com/group/blis-discuss): Please join and -post to this mailing list if you have general questions or feedback regarding -BLIS. Application developers (end users) may wish to post here, unless they -have bug reports, in which case they should open a -[new issue](https://github.com/flame/blis/issues) on github. - -Contributing ------------- - -For information on how to contribute to our project, including preferred -[coding conventions](docs/CodingConventions.md), please refer to the -[CONTRIBUTING](CONTRIBUTING.md) file at the top-level of the BLIS source -distribution. - -Citations ---------- - -For those of you looking for the appropriate article to cite regarding BLIS, we -recommend citing our -[first ACM TOMS journal paper]( https://dl.acm.org/doi/10.1145/2764454?cid=81314495332) -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis1_toms_rev3.pdf)): - -``` -@article{BLIS1, - author = {Field G. {V}an~{Z}ee and Robert A. {v}an~{d}e~{G}eijn}, - title = {{BLIS}: A Framework for Rapidly Instantiating {BLAS} Functionality}, - journal = {ACM Transactions on Mathematical Software}, - volume = {41}, - number = {3}, - pages = {14:1--14:33}, - month = {June}, - year = {2015}, - issue_date = {June 2015}, - url = {http://doi.acm.org/10.1145/2764454}, -} -``` - -You may also cite the -[second ACM TOMS journal paper]( https://dl.acm.org/doi/10.1145/2755561?cid=81314495332) -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis2_toms_rev3.pdf)): - -``` -@article{BLIS2, - author = {Field G. {V}an~{Z}ee and Tyler Smith and Francisco D. Igual and - Mikhail Smelyanskiy and Xianyi Zhang and Michael Kistler and Vernon Austel and - John Gunnels and Tze Meng Low and Bryan Marker and Lee Killough and - Robert A. {v}an~{d}e~{G}eijn}, - title = {The {BLIS} Framework: Experiments in Portability}, - journal = {ACM Transactions on Mathematical Software}, - volume = {42}, - number = {2}, - pages = {12:1--12:19}, - month = {June}, - year = {2016}, - issue_date = {June 2016}, - url = {http://doi.acm.org/10.1145/2755561}, -} -``` - -We also have a third paper, submitted to IPDPS 2014, on achieving -[multithreaded parallelism in BLIS](https://dl.acm.org/doi/10.1109/IPDPS.2014.110) -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis3_ipdps14.pdf)): - -``` -@inproceedings{BLIS3, - author = {Tyler M. Smith and Robert A. {v}an~{d}e~{G}eijn and Mikhail Smelyanskiy and - Jeff R. Hammond and Field G. {V}an~{Z}ee}, - title = {Anatomy of High-Performance Many-Threaded Matrix Multiplication}, - booktitle = {28th IEEE International Parallel \& Distributed Processing Symposium - (IPDPS 2014)}, - year = {2014}, - url = {https://doi.org/10.1109/IPDPS.2014.110}, -} -``` - -A fourth paper, submitted to ACM TOMS, also exists, which proposes an -[analytical model](https://dl.acm.org/doi/10.1145/2925987) -for determining blocksize parameters in BLIS -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf)): - -``` -@article{BLIS4, - author = {Tze Meng Low and Francisco D. Igual and Tyler M. Smith and - Enrique S. Quintana-Ort\'{\i}}, - title = {Analytical Modeling Is Enough for High-Performance {BLIS}}, - journal = {ACM Transactions on Mathematical Software}, - volume = {43}, - number = {2}, - pages = {12:1--12:18}, - month = {August}, - year = {2016}, - issue_date = {August 2016}, - url = {http://doi.acm.org/10.1145/2925987}, -} -``` - -A fifth paper, submitted to ACM TOMS, begins the study of so-called -[induced methods for complex matrix multiplication]( https://dl.acm.org/doi/10.1145/3086466?cid=81314495332) -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis5_toms_rev2.pdf)): - -``` -@article{BLIS5, - author = {Field G. {V}an~{Z}ee and Tyler Smith}, - title = {Implementing High-performance Complex Matrix Multiplication via the 3m and 4m Methods}, - journal = {ACM Transactions on Mathematical Software}, - volume = {44}, - number = {1}, - pages = {7:1--7:36}, - month = {July}, - year = {2017}, - issue_date = {July 2017}, - url = {http://doi.acm.org/10.1145/3086466}, -} -``` - -A sixth paper, submitted to ACM TOMS, revisits the topic of the previous -article and derives a -[superior induced method](https://epubs.siam.org/doi/10.1137/19M1282040) -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis6_sisc_rev3.pdf)): - -``` -@article{BLIS6, - author = {Field G. {V}an~{Z}ee}, - title = {Implementing High-Performance Complex Matrix Multiplication via the 1m Method}, - journal = {SIAM Journal on Scientific Computing}, - volume = {42}, - number = {5}, - pages = {C221--C244}, - month = {September} - year = {2020}, - issue_date = {September 2020}, - url = {https://doi.org/10.1137/19M1282040} -} -``` - -A seventh paper, submitted to ACM TOMS, explores the implementation of `gemm` for -[mixed-domain and/or mixed-precision](https://www.cs.utexas.edu/users/flame/pubs/blis7_toms_rev0.pdf) operands -([unofficial backup link](https://www.cs.utexas.edu/users/flame/pubs/blis7_toms_rev0.pdf)): - -``` -@article{BLIS7, - author = {Field G. {V}an~{Z}ee and Devangi N. Parikh and Robert A. van~de~{G}eijn}, - title = {Supporting Mixed-domain Mixed-precision Matrix Multiplication -within the BLIS Framework}, - journal = {ACM Transactions on Mathematical Software}, - note = {submitted} -} -``` - -Funding -------- - -This project and its associated research were partially sponsored by grants from -[Microsoft](https://www.microsoft.com/), -[Intel](https://www.intel.com/), -[Texas Instruments](https://www.ti.com/), -[AMD](https://www.amd.com/), -[HPE](https://www.hpe.com/), -[Oracle](https://www.oracle.com/), -[Huawei](https://www.huawei.com/), -and -[Facebook](https://www.facebook.com/), -as well as grants from the -[National Science Foundation](https://www.nsf.gov/) (Awards -CCF-0917167, ACI-1148125/1340293, CCF-1320112, and ACI-1550493). - -_Any opinions, findings and conclusions or recommendations expressed in this -material are those of the author(s) and do not necessarily reflect the views of -the National Science Foundation (NSF)._ +The upstream repository (https://github.com/flame/blis) contains further information on BLIS, including background information on BLIS design, usage examples, and a complete BLIS API reference. +AOCL-BLAS is developed and maintained by AMD. You can contact us on the email-id toolchainsupport@amd.com. You can also raise any issue/suggestion on the git-hub repository at https://github.com/amd/blis/issues. diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index 4e971d932a..44de4ac658 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ b/addon/aocl_gemm/aocl_gemm.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,5 +37,18 @@ #include "aocl_gemm_post_ops.h" #include "aocl_gemm_interface_apis.h" +#include "aocl_util_interface_apis.h" +#include "aocl_bf16_type.h" +#include "lpgemm_config.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils_kernels.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_packa_s8.h" +#include "lpgemm_packb_s8.h" +#include "lpgemm_packb_s8s16.h" #endif // BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c index 7af08b751b..fd9d3be1f7 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -1,126 +1,136 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#include "aocl_gemm_interface_apis.h" -#include "lpgemm_types.h" -#include "lpgemm_config.h" -#include "lpgemm_utils.h" -#include "lpgemm_reorder_bf16.h" - -AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) -{ - if ( ( k <= 0 ) || ( n <= 0 ) ) - { - return 0; // Error. - } - - // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) - { - printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); - return 0; // Error. - } - - /* Initialize BLIS. */ - bli_init_auto(); - - // Set MC, NC, KC, NR, MR. - aocl_lpgemm_init_global_cntx(); - - AOCL_MATRIX_TYPE input_mat_type; - bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); - - if ( input_mat_type == A_MATRIX ) - { - return 0; // A reorder not supported. - } - - // Extra space since packing does width in multiples of 16. The bf16 - // instruction can be used as long as atleast one zmm register can be fully - // loaded; and since k_dim needs to be atleast 2, having n_dim atleast 16 - // should give 2x16=32 elements, enough for 1 zmm register.The padding is - // not rounded to NR (=64), since that would result in memory wastage. - dim_t n_reorder = make_multiple_of_n( n, 16 ); - - // Extra space since packing does length in multiples of 2. - dim_t k_reorder = make_multiple_of_n( k, 2 ); - - siz_t size_req = sizeof( int16_t ) * k_reorder * n_reorder; - - return size_req; -} - -AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) -{ - if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || - ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) - { - return; // Error. - } - - // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) - { - printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); - return; // Error. - } - - /* Initialize BLIS. */ - bli_init_auto(); - - // Set MC, NC, KC, NR, MR. - aocl_lpgemm_init_global_cntx(); - - AOCL_MATRIX_TYPE input_mat_type; - bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); - - if ( input_mat_type == A_MATRIX ) - { - return; // A reorder not supported. - } - - // Create dummy b_reorder obj. - lpgemm_obj_t b_reorder; - b_reorder.storage.aligned_buffer = reorder_buf_addr; - - // Create dummy original b obj; - lpgemm_obj_t b; - b.storage.aligned_buffer = ( void* )input_buf_addr; - b.rs = ldb; - b.width = n; - b.length = k; - - reorderb_nr64_bf16bf16f32of32( &b, &b_reorder ); -} +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The bf16 + // instruction can be used as long as at least one zmm register can be fully + // loaded; and since k_dim needs to be at least 2, having n_dim at least 16 + // should give 2x16=32 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder = make_multiple_of_n( k, 2 ); + + siz_t size_req = sizeof( int16_t ) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_bf16bf16f32of32( &b, &b_reorder, &rntm_g, lcntx_g ); +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index fedf3a43c5..0e0f93e191 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -46,10 +46,24 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) trans_t blis_transa; trans_t blis_transb; + // There is this use case where lpgemm will be compiled using gcc9.4 + // (where bf16 ISA is not supported), but deployed on a zen4+ sustem + // (which supports bf16 ISA). Here the bf16 kernels will be concealed + // and not compiled, and subsequently this api should error out and + // return early, even if bf16 ISA is supported by machine. +#if defined( BLIS_GCC ) && ( __GNUC__ < 10 ) + { + bli_print_msg("bf16bf16f32obf16 compiled using a compiler not " + "supporting BF16 ISA.", __FILE__, __LINE__ ); + return; // Error. + } +#endif + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) { - printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -158,6 +172,8 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) bli_rntm_init_from_global( &rntm_g ); bli_membrk_rntm_set_membrk( &rntm_g ); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); + #ifdef BLIS_ENABLE_OPENMP // Swapping inputs to induce row major computation for column major inputs. if ( is_column_major == TRUE ) @@ -169,7 +185,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) a, rs_a, cs_a, mtag_a, ( float* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); } @@ -182,7 +198,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) b, rs_b, cs_b, mtag_b, ( float* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); } @@ -197,7 +213,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) a, rs_a, cs_a, mtag_a, ( float* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); } @@ -210,7 +226,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) b, rs_b, cs_b, mtag_b, ( float* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); } diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index 8f87f4dff3..ca8b160220 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -1,218 +1,241 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#include "aocl_gemm_interface_apis.h" -#include "lpgemm_types.h" -#include "lpgemm_post_ops.h" -#include "lpgemm_thread_decor_openmp.h" -#include "lpgemm_5loop_interface_apis.h" -#include "lpgemm_config.h" -#include "lpgemm_utils.h" - -AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) -{ - trans_t blis_transa; - trans_t blis_transb; - - // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) - { - printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); - return; // Error. - } - - /* Initialize BLIS. */ - bli_init_auto(); - - // Set MC, NC, KC, NR, MR. - aocl_lpgemm_init_global_cntx(); - - // Null check for pointers. - if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) - { - return; // Error. - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); - - /* Perform BLAS parameter checking. */ - // Transpose not supported. - if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || - ( blis_transb != BLIS_NO_TRANSPOSE ) ) - { - return; // Error. - } - - // Sanitize order input. - char order_use = - ( ( order == 'r' ) || ( order == 'R' ) || - ( order == 'c' ) || ( order == 'C' ) ) ? - order : 'r'; - - bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); - bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); - - // Row major input expected with leading dimensions >= row stride. - if ( ( is_row_major == TRUE ) && - ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) - { - return; // Error. - } - // Column major input expected with leading dimensions >= column stride. - else if ( ( is_column_major == TRUE ) && - ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) - { - return; // Error. - } - - // Check if dimensions are valid. - if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || - ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) - { - return; // Error. - } - - const inc_t rs_a = lda; - const inc_t cs_a = 1; - const inc_t rs_b = ldb; - const inc_t cs_b = 1; - const inc_t rs_c = ldc; - const inc_t cs_c = 1; - - AOCL_MEMORY_TAG mtag_a; - AOCL_MEMORY_TAG mtag_b; - - bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); - bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); - - // B matrix needs to be packed in a certain format in order to be loaded - // and used in bf16 instrution. As such the mtag_b always needs to be either - // packed or reordered. B matrix as it is (unpacked) cannot be used, and - // the mtag_b is set to packed to enable runtime packing. - if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) - { - mtag_b = PACK; - } - // Inputs swapped in column major, A becomes B from kernel point of view. - else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) - { - mtag_a = PACK; - } - - // Only unpacked A supported now. - if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) - { - return; // Error. - } - // Inputs swapped in column major, B becomes A from kernel point of view. - else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) - { - return; // Error. - } - - // Convert post op struct to post op linked list format. - lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list - ( - post_op_unparsed, post_op_list, - ( void* )c, ( void* )( &order_use ) - ); - - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_g; - bli_rntm_init_from_global( &rntm_g ); - bli_membrk_rntm_set_membrk( &rntm_g ); - -#ifdef BLIS_ENABLE_OPENMP - // Swapping inputs to induce row major computation for column major inputs. - if ( is_column_major == TRUE ) - { - lpgemm_bf16bf16f32of32_openmp_thread_decorator - ( - n, m, k, - b, rs_b, cs_b, mtag_b, - a, rs_a, cs_a, mtag_a, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); - } - else - { - lpgemm_bf16bf16f32of32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); - } -#else - // Swapping inputs to induce row major computation for column major inputs. - if ( is_column_major == TRUE ) - { - lpgemm_bf16bf16f32of32_thread_decorator - ( - n, m, k, - b, rs_b, cs_b, mtag_b, - a, rs_a, cs_a, mtag_a, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); - } - else - { - lpgemm_bf16bf16f32of32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); - } -#endif -} +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // There is this use case where lpgemm will be compiled using gcc9.4 + // (where bf16 ISA is not supported), but deployed on a zen4+ sustem + // (which supports bf16 ISA). Here the bf16 kernels will be concealed + // and not compiled, and subsequently this api should error out and + // return early, even if bf16 ISA is supported by machine. +#if defined( BLIS_GCC ) && ( __GNUC__ < 10 ) + { + bli_print_msg("bf16bf16f32of32 compiled using a compiler not " + "supporting BF16 ISA.", __FILE__, __LINE__ ); + return; // Error. + } +#endif + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512bf16_supported() == FALSE ) + { + bli_print_msg(" AVX512_BF16 ISA not supported by processor, " + "cannot perform bf16bf16f32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) + { + return; // Error. + } + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + // The strides are set assuming a row major kernel. + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) + { + // Reorder not supported with column major inputs. + return; + } + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } + + // Only unpacked A supported now. + if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + { + return; // Error. + } + // Inputs swapped in column major, B becomes A from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( BF16BF16F32OF32 ); + +#ifdef BLIS_ENABLE_OPENMP + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } +#else + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 8366f746cb..f3eed1aa65 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,6 +37,7 @@ #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_config.h" #include "lpgemm_utils.h" #include "lpgemm_5loop_interface_apis.h" @@ -45,16 +46,20 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) trans_t blis_transa; trans_t blis_transb; - // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform f32f32f32 gemm.", __FILE__, __LINE__ ); return; // Error. } /* Initialize BLIS. */ bli_init_auto(); + // Initialize lpgemm context. + aocl_lpgemm_init_global_cntx(); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(s), transa, transb, m, n, k,\ (void*)&alpha, lda, ldb, (void*)&beta, ldc); @@ -86,16 +91,20 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) ( ( order == 'r' ) || ( order == 'R' ) || ( order == 'c' ) || ( order == 'C' ) ) ? order : 'r'; - if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) { - return; // Only row major supported. + return; // Error. } - - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "Column major and general stride not supported."); return; // Error. } @@ -108,6 +117,7 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) return; // Error. } + // The strides are set assuming a row major kernel. const inc_t rs_a = lda; const inc_t cs_a = 1; const inc_t rs_b = ldb; @@ -121,11 +131,38 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); - // Only unreordered A supported now. - if ( mtag_a != UNPACKED ) + if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Reordered B matrix not supported in column major case."); + return; + } + + // By default enable packing for B matrix. Before the 5 loop, based on + // the input dimensions, the smart threading logic will adjust it + // (disable/enable) accordingly. + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) + { + mtag_b = PACK; + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } + + // Reordered A not supported now. + if ( ( is_row_major == TRUE ) && ( mtag_a == REORDERED ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "A matrix reordering not supported for row major inputs."); + return; // Error. + } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b == REORDERED ) ) { AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ - "A matrix packing/reordering not supported."); + "B matrix reordering not supported for column major inputs."); return; // Error. } @@ -143,31 +180,71 @@ AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) bli_rntm_init_from_global( &rntm_g ); bli_membrk_rntm_set_membrk( &rntm_g ); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( F32F32F32OF32 ); + #ifdef BLIS_ENABLE_OPENMP - lpgemm_f32f32f32of32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); + // The lpgemm_cntx_t argument will be NULL for f32 since it still uses + // BLIS cntx_t internally. Its a workaround for now and will be replaced + // with lpgemm_cntx_t eventually. + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_f32f32f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_f32f32f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } #else - // Setting pack A by default for non open mp case. + // Setting pack A and B by default for non open mp case. bli_rntm_set_pack_a( 1, &rntm_g ); + bli_rntm_set_pack_b( 1, &rntm_g ); - lpgemm_f32f32f32of32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, cs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_f32f32f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_f32f32f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); + } #endif AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c index 948c1383de..2116e418af 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,6 +34,7 @@ #include "blis.h" #include "aocl_gemm_interface_apis.h" +#include "lpgemm_config.h" #include "lpgemm_utils.h" AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) @@ -43,16 +44,20 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) return 0; // Error. } - // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform f32f32f32 gemm.", __FILE__, __LINE__ ); return 0; // Error. } /* Initialize BLIS. */ bli_init_auto(); + // Initialize lpgemm context. + aocl_lpgemm_init_global_cntx(); + // Query the global cntx. cntx_t* cntx = bli_gks_query_cntx(); @@ -85,16 +90,20 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) return; // Error. } - // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform f32f32f32 gemm.", __FILE__, __LINE__ ); return; // Error. } /* Initialize BLIS. */ bli_init_auto(); + // Initialize lpgemm context. + aocl_lpgemm_init_global_cntx(); + // Query the global cntx. cntx_t* cntx = bli_gks_query_cntx(); @@ -122,7 +131,7 @@ AOCL_GEMM_REORDER(float,f32f32f32of32) float* restrict kappa_cast = &one_local; // Set the schema to "row stored column panels" to indicate packing to - // conventional column-stored row panels. + // conventional row-stored column panels. pack_t schema = BLIS_PACKED_COL_PANELS; trans_t transc = BLIS_NO_TRANSPOSE; conj_t conjc = bli_extract_conj( transc ); diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index 40101cbe6a..718c0c3de2 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -51,6 +51,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32); AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32); AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16); AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16); // Performs reordering of input matrix. Reordering is the process of packing // the entire matrix upfront, so that the benefits of packed matrix is obtained @@ -70,6 +72,8 @@ AOCL_GEMM_REORDER(float,f32f32f32of32); AOCL_GEMM_REORDER(int8_t,u8s8s32os32); AOCL_GEMM_REORDER(int8_t,u8s8s16os16); AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32); +AOCL_GEMM_REORDER(int8_t,s8s8s32os32); +AOCL_GEMM_REORDER(int8_t,s8s8s16os16); // Only supports matrices in row major format. This api can perform gemm with // both normal as well as reordered B matrix as opposesd to sgemm (only @@ -103,5 +107,9 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8); AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8); AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16); +AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32); +AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8); +AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16); +AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8); #endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h index 86034598ac..70084e741a 100644 --- a/addon/aocl_gemm/aocl_gemm_post_ops.h +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,6 +41,9 @@ typedef enum { RELU = 0, PRELU = 1, + GELU_TANH = 2, + GELU_ERF = 3, + CLIP = 4, } AOCL_ELT_ALGO_TYPE; typedef enum @@ -81,7 +84,7 @@ typedef struct typedef struct { aocl_post_op_sum sum; - aocl_post_op_eltwise eltwise; + aocl_post_op_eltwise* eltwise; //Multiple eltwise allowed. aocl_post_op_bias bias; // eg: seq_length = 2 diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c new file mode 100644 index 0000000000..ca5ee12fc2 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16.c @@ -0,0 +1,170 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_utils_s8.h" + +AOCL_GEMM_MATMUL(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform s8s8s16 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_s8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); +#else + lpgemm_s8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c new file mode 100644 index 0000000000..92a2663944 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os16_utils.c @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_reorder_s8s16.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s16os16) +{ + if ((k <= 0) || (n <= 0)) + { + return 0; // Error. + } + + // Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform s8s8s16 gemm.", __FILE__, __LINE__ ); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The vpmaddubsw + // instruction can be used as long as atleast one ymm register can be fully + // loaded; and since k_dim needs to be at least 2, having n_dim atleast 16 + // should give 2x16=32 elements, enough for 1 ymm register.The padding is + // not rounded to NR (=16), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n(n, 16); + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder = make_multiple_of_n(k, 2); + + // Extra memory of n_reorder * sizeof( int16_t ) to store sum of every column of B matrix buffer + siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder + ( n_reorder * sizeof( int16_t )); + + return size_req; +} + +AOCL_GEMM_REORDER(int8_t,s8s8s16os16) +{ + if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || + (k <= 0) || (n <= 0) || (ldb < n)) + { + return; // Error. + } + + // Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform s8s8s16 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return; // A reorder not supported. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = (void *)input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + aocl_reorderb_nr32_s8s8s16o16( &b, &b_reorder, &rntm_g, lcntx_g ); +} diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c new file mode 100644 index 0000000000..a036612c82 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s16os8.c @@ -0,0 +1,170 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_utils_s8.h" + +AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if AVX2 ISA is supported, lpgemm s8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform s8s8s16 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S16OS16 ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_s8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, TRUE + ); +#else + lpgemm_s8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c new file mode 100644 index 0000000000..b9ddecdba5 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32.c @@ -0,0 +1,171 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils_s8.h" + +AOCL_GEMM_MATMUL(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform s8s8s32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_s8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); +#else + lpgemm_s8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, FALSE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c new file mode 100644 index 0000000000..4c41d8e184 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os32_utils.c @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_reorder_s8.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform s8s8s32 gemm.", __FILE__, __LINE__ ); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The vnni + // instruction can be used as long as atleast one zmm register can be fully + // loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16 + // should give 4x16=64 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 4. + dim_t k_reorder = make_multiple_of_n( k, 4 ); + + //extra memory of n_reorder * sizeof(int32_t) to store sum of every column of B matrix buffer + siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder + ( n_reorder * sizeof( int32_t ) ); + + return size_req; +} + +AOCL_GEMM_REORDER(int8_t,s8s8s32os32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform s8s8s32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_s8s8s32o32( &b, &b_reorder, &rntm_g, lcntx_g ); +} diff --git a/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c new file mode 100644 index 0000000000..7abc392a4e --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_s8s8s32os8.c @@ -0,0 +1,171 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils_s8.h" + +AOCL_GEMM_MATMUL(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform s8s8s32 gemm.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + const inc_t cs_c = 1; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_s8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, TRUE + ); +#else + lpgemm_s8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, lcntx_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 1c6b0899ad..f851a283d5 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -46,10 +46,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) trans_t blis_transa; trans_t blis_transb; - // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform u8s8s16 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -141,6 +142,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) bli_rntm_init_from_global(&rntm_g); bli_membrk_rntm_set_membrk(&rntm_g); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); + #ifdef BLIS_ENABLE_OPENMP lpgemm_u8s8s16o16_openmp_thread_decorator ( @@ -149,7 +152,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) b, rs_b, cs_b, mtag_b, c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, FALSE ); #else @@ -160,7 +163,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) b, rs_b, cs_b, mtag_b, c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, FALSE ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c index 5cadd206d5..98d8828f22 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -46,10 +46,11 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16) return 0; // Error. } - // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform u8s8s16 gemm.", __FILE__, __LINE__ ); return 0; // Error. } @@ -68,8 +69,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16) } // Extra space since packing does width in multiples of 16. The vpmaddubsw - // instruction can be used as long as atleast one ymm register can be fully - // loaded; and since k_dim needs to be at least 2, having n_dim atleast 16 + // instruction can be used as long as at least one ymm register can be fully + // loaded; and since k_dim needs to be at least 2, having n_dim at least 16 // should give 2x16=32 elements, enough for 1 ymm register.The padding is // not rounded to NR (=16), since that would result in memory wastage. dim_t n_reorder = make_multiple_of_n(n, 16); @@ -90,10 +91,11 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16) return; // Error. } - // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform u8s8s16 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -111,6 +113,14 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16) return; // A reorder not supported. } + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); + // Create dummy b_reorder obj. lpgemm_obj_t b_reorder; b_reorder.storage.aligned_buffer = reorder_buf_addr; @@ -122,5 +132,5 @@ AOCL_GEMM_REORDER(int8_t,u8s8s16os16) b.width = n; b.length = k; - aocl_reorderb_nr32_u8s8s16o16(&b, &b_reorder); + aocl_reorderb_nr32_u8s8s16o16( &b, &b_reorder, &rntm_g, lcntx_g ); } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c index fed10c1e01..c4ca0ac572 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -46,10 +46,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) trans_t blis_transa; trans_t blis_transb; - // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. - if ( bli_cpuid_is_avx_supported() == FALSE ) + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) { - printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX2 ISA not supported by processor, " + "cannot perform u8s8s16 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -141,6 +142,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) bli_rntm_init_from_global(&rntm_g); bli_membrk_rntm_set_membrk(&rntm_g); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S16OS16 ); + #ifdef BLIS_ENABLE_OPENMP lpgemm_u8s8s16o16_openmp_thread_decorator ( @@ -149,7 +152,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) b, rs_b, cs_b, mtag_b, ( int16_t* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); #else @@ -160,7 +163,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) b, rs_b, cs_b, mtag_b, ( int16_t* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 39fd49bca4..5580001d69 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { - printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform u8s8s32 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -142,6 +143,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) bli_rntm_init_from_global( &rntm_g ); bli_membrk_rntm_set_membrk( &rntm_g ); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); + #ifdef BLIS_ENABLE_OPENMP lpgemm_u8s8s32o32_openmp_thread_decorator ( @@ -150,7 +153,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) b, rs_b, cs_b, mtag_b, c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, FALSE ); #else @@ -161,7 +164,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) b, rs_b, cs_b, mtag_b, c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, FALSE ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c index 11f9f6937a..20f0b322d9 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { - printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform u8s8s32 gemm.", __FILE__, __LINE__ ); return 0; // Error. } @@ -68,8 +69,8 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) } // Extra space since packing does width in multiples of 16. The vnni - // instruction can be used as long as atleast one zmm register can be fully - // loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16 + // instruction can be used as long as at least one zmm register can be fully + // loaded; and since k_dim needs to be at least 4, having n_dim at least 16 // should give 4x16=64 elements, enough for 1 zmm register.The padding is // not rounded to NR (=64), since that would result in memory wastage. dim_t n_reorder = make_multiple_of_n( n, 16 ); @@ -93,7 +94,8 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { - printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform u8s8s32 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -111,6 +113,14 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) return; // A reorder not supported. } + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); + // Create dummy b_reorder obj. lpgemm_obj_t b_reorder; b_reorder.storage.aligned_buffer = reorder_buf_addr; @@ -122,5 +132,5 @@ AOCL_GEMM_REORDER(int8_t,u8s8s32os32) b.width = n; b.length = k; - reorderb_nr64_u8s8s32o32( &b, &b_reorder ); + reorderb_nr64_u8s8s32o32( &b, &b_reorder, &rntm_g, lcntx_g ); } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c index e4a4ce3f2d..55f062ee8f 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -49,7 +49,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { - printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + bli_print_msg(" AVX512_VNNI ISA not supported by processor, " + "cannot perform u8s8s32 gemm.", __FILE__, __LINE__ ); return; // Error. } @@ -142,6 +143,8 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) bli_rntm_init_from_global( &rntm_g ); bli_membrk_rntm_set_membrk( &rntm_g ); + lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( U8S8S32OS32 ); + #ifdef BLIS_ENABLE_OPENMP lpgemm_u8s8s32o32_openmp_thread_decorator ( @@ -150,7 +153,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) b, rs_b, cs_b, mtag_b, ( int32_t* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); #else @@ -161,7 +164,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) b, rs_b, cs_b, mtag_b, ( int32_t* )c, rs_c, cs_c, alpha, beta, - &rntm_g, + &rntm_g, lcntx_g, post_op_list, TRUE ); #endif diff --git a/addon/aocl_gemm/aocl_util_interface_apis.h b/addon/aocl_gemm/aocl_util_interface_apis.h new file mode 100644 index 0000000000..d2983b8a64 --- /dev/null +++ b/addon/aocl_gemm/aocl_util_interface_apis.h @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_UTIL_INTERFACE_H +#define AOCL_UTIL_INTERFACE_H + +#define AOCL_UTIL_L1_OP(V_type,OP_type) \ +BLIS_EXPORT_ADDON void aocl_ ## OP_type \ + ( \ + const dim_t n, \ + V_type* x, \ + const inc_t incx \ + ) \ + +AOCL_UTIL_L1_OP(float,gelu_tanh_f32); +AOCL_UTIL_L1_OP(float,gelu_erf_f32); +AOCL_UTIL_L1_OP(float,softmax_f32); + +#endif //AOCL_UTIL_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_util_l1_ops.c b/addon/aocl_gemm/aocl_util_l1_ops.c new file mode 100644 index 0000000000..11a4b83078 --- /dev/null +++ b/addon/aocl_gemm/aocl_util_l1_ops.c @@ -0,0 +1,114 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_util_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils_kernels.h" + +AOCL_UTIL_L1_OP(float,gelu_tanh_f32) +{ + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, AOCL GEMM " + "utility l1 operations not supported.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + if ( ( n <= 0 ) || ( x == NULL ) || ( incx <= 0 ) ) + { + return; // Error. + } + + lpgemm_util_cntx_t* lutil_cntx_g = lpgemm_util_get_global_cntx_obj( F32_GELU_TANH ); + ( ( lpgemm_util_l1_op_f32_kernel_t )lutil_cntx_g->kern_fun_ptr )( n, x, incx ); +} + +AOCL_UTIL_L1_OP(float,gelu_erf_f32) +{ + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, AOCL GEMM " + "utility l1 operations not supported.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + if ( ( n <= 0 ) || ( x == NULL ) || ( incx <= 0 ) ) + { + return; // Error. + } + + lpgemm_util_cntx_t* lutil_cntx_g = lpgemm_util_get_global_cntx_obj( F32_GELU_ERF ); + ( ( lpgemm_util_l1_op_f32_kernel_t )lutil_cntx_g->kern_fun_ptr )( n, x, incx ); +} + +AOCL_UTIL_L1_OP(float,softmax_f32) +{ + // Check if AVX2 ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx2fma3_supported() == FALSE ) + { + bli_print_msg(" AVX2 ISA not supported by processor, AOCL GEMM " + "utility l1 operations not supported.", __FILE__, __LINE__ ); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + if ( ( n <= 0 ) || ( x == NULL ) || ( incx <= 0 ) ) + { + return; // Error. + } + + lpgemm_util_cntx_t* lutil_cntx_g = lpgemm_util_get_global_cntx_obj( F32_SOFTMAX ); + ( ( lpgemm_util_l1_op_f32_kernel_t )lutil_cntx_g->kern_fun_ptr )( n, x, incx ); +} diff --git a/addon/aocl_gemm/config/lpgemm_blksz_map.h b/addon/aocl_gemm/config/lpgemm_blksz_map.h new file mode 100644 index 0000000000..9991a3eb70 --- /dev/null +++ b/addon/aocl_gemm/config/lpgemm_blksz_map.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_BLKSZ_MAP_H +#define LPGEMM_BLKSZ_MAP_H + +// The XMACRO follows the format ID,MC,NC,KC,MR,NR,PACKA_RS,PACKA_CS,PACKB_RS,PACKB_CS: +// ID = One of the AOCL_OPERATION_TYPE enum. + +#define LPGEMM_BLKSZ_MAP_ZEN4 \ + XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ + XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \ + XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ + XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + +#define LPGEMM_BLKSZ_MAP_ZEN \ + XMACRO(U8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + XMACRO(U8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ + XMACRO(BF16BF16F32OF32, 144, 1024, 2048, 6, 64, 0, 0, 2*64, 64/2) \ + XMACRO(S8S8S32OS32, 144, 1024, 2048, 6, 64, 4, 24, 4*64, 64) \ + XMACRO(S8S8S16OS16, 252, 2048, 2048, 6, 32, 0, 0, 2*32, 32) \ + +#endif //LPGEMM_BLKSZ_MAP_H diff --git a/addon/aocl_gemm/config/lpgemm_config.c b/addon/aocl_gemm/config/lpgemm_config.c new file mode 100644 index 0000000000..0dad8c88a7 --- /dev/null +++ b/addon/aocl_gemm/config/lpgemm_config.c @@ -0,0 +1,297 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_config.h" +#include "lpgemm_func_map.h" +#include "lpgemm_blksz_map.h" +#include "lpgemm_kernels.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_packa_s8.h" +#include "lpgemm_packb_s8.h" +#include "lpgemm_packb_s8s16.h" + +static lpgemm_cntx_t global_cntx_t_list[AOCL_OPERATION_TYPE_LEN] \ + __attribute__((aligned(64))); //Only one op type supported now. +static lpgemm_util_cntx_t global_util_cntx_t_list[AOCL_UTIL_OPERATION_TYPE_LEN] \ + __attribute__((aligned(64))); //Only post-ops like utils. + +static bli_pthread_once_t once_check_lpgemm_func_map_init = BLIS_PTHREAD_ONCE_INIT; + +static void _lpgemm_util_cntx_init_func_map() +{ +#define UMACRO(ID,FUNC_PTR) global_util_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR; + + global_util_cntx_t_list[F32_GELU_TANH].kern_fun_ptr = NULL; + global_util_cntx_t_list[F32_GELU_ERF].kern_fun_ptr = NULL; + + // Kernel dispatch object factory. + if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN4 + LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI_BF16 +#endif + } + else if ( bli_cpuid_is_avx512vnni_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN4 + LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI +#endif + } + else if ( bli_cpuid_is_avx2fma3_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN3 + LPGEMM_UTIL_KERN_FUNC_MAP_AVX2 +#endif + } + +#undef UMACRO +} + +static void _lpgemm_cntx_init_func_map() +{ +#define KMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].kern_fun_ptr = FUNC_PTR; +#define PAMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packa_fun_ptr = FUNC_PTR; +#define PBMACRO(ID,FUNC_PTR) global_cntx_t_list[ID].packb_fun_ptr = FUNC_PTR; + + //TODO: Default initialize with reference kernels so that kernel pointer + // will be valid even in case none of the zen optimized kernels are + // available. This scenario could happen if the addon was built using + // a different arch config (eg: skx). + + global_cntx_t_list[U8S8S16OS16].kern_fun_ptr = NULL; + global_cntx_t_list[U8S8S32OS32].kern_fun_ptr = NULL; + global_cntx_t_list[F32F32F32OF32].kern_fun_ptr = NULL; + global_cntx_t_list[BF16BF16F32OF32].kern_fun_ptr = NULL; + + // Kernel dispatch object factory. + if ( bli_cpuid_is_avx512bf16_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN4 + LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 + LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 + LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 +#endif + } + else if ( bli_cpuid_is_avx512vnni_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN4 + LPGEMM_KERN_FUNC_MAP_AVX512_VNNI + LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI + LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI +#endif + } + else if ( bli_cpuid_is_avx2fma3_supported() == TRUE ) + { +#ifdef BLIS_KERNELS_ZEN3 + LPGEMM_KERN_FUNC_MAP_AVX2 + LPGEMM_PACKA_FUNC_MAP_AVX2 + LPGEMM_PACKB_FUNC_MAP_AVX2 +#endif + } + // If built with a config not supporting zen3/zen4/amdzen, error out + // since reference kernels are not available. + if ( global_cntx_t_list[F32F32F32OF32].kern_fun_ptr == NULL ) + { + bli_print_msg( "AOCL_GEMM is not compiled using correct Zen config." + " Compile using zen3/zen4/amdzen config.", + __FILE__, __LINE__ ); + bli_abort(); + } + +#undef PBMACRO +#undef PAMACRO +#undef KMACRO +} + +BLIS_INLINE void lpgemm_set_block_sizes_global_cntx + ( + AOCL_OPERATION_TYPE op_type, + dim_t MC, + dim_t NC, + dim_t KC, + dim_t MR, + dim_t NR + ) +{ + global_cntx_t_list[op_type].blksz.MC = MC; + global_cntx_t_list[op_type].blksz.NC = NC; + global_cntx_t_list[op_type].blksz.KC = KC; + global_cntx_t_list[op_type].blksz.MR = MR; + global_cntx_t_list[op_type].blksz.NR = NR; +} + +BLIS_INLINE void lpgemm_set_pack_strides_global_cntx + ( + AOCL_OPERATION_TYPE op_type, + dim_t packa_rs, + dim_t packa_cs, + dim_t packb_rs, + dim_t packb_cs + ) +{ + global_cntx_t_list[op_type].pack_s.packa_rs = packa_rs; + global_cntx_t_list[op_type].pack_s.packa_cs = packa_cs; + global_cntx_t_list[op_type].pack_s.packb_rs = packb_rs; + global_cntx_t_list[op_type].pack_s.packb_cs = packb_cs; +} + +static void _lpgemm_cntx_init_blksz_map() +{ +#define XMACRO(ID,MC,NC,KC,MR,NR,PACKA_RS,PACKA_CS,PACKB_RS,PACKB_CS) \ + lpgemm_set_block_sizes_global_cntx(ID, MC, NC, KC, MR, NR); \ + lpgemm_set_pack_strides_global_cntx(ID, PACKA_RS, PACKA_CS, PACKB_RS, PACKB_CS); + + // Ideally the blocksize needs to be set based on arch id. However + // since this code is also expected to work on other vendor machines, + // the blocksize for a particular version of zen id is generalized + // for all machines that support the ISA supported by that particular + // zen id. + if ( bli_cpuid_is_avx512vnni_supported() == TRUE ) + { + LPGEMM_BLKSZ_MAP_ZEN4 + } + else if ( bli_cpuid_is_avx2fma3_supported() == TRUE ) + { + LPGEMM_BLKSZ_MAP_ZEN + } + else + { + LPGEMM_BLKSZ_MAP_ZEN + } + +#undef XMACRO +} + +static void lpgemm_cntx_init_map() +{ + _lpgemm_cntx_init_func_map(); + _lpgemm_cntx_init_blksz_map(); + _lpgemm_util_cntx_init_func_map(); +} + +// Sets default block sizes for lpgemm. Currently only u8s8s32 supported. +void aocl_lpgemm_init_global_cntx() +{ + bli_pthread_once + ( + &once_check_lpgemm_func_map_init, + lpgemm_cntx_init_map + ); +} + +lpgemm_cntx_t* lpgemm_get_global_cntx_obj( AOCL_OPERATION_TYPE op ) +{ + return &global_cntx_t_list[op]; +} + +lpgemm_util_cntx_t* lpgemm_util_get_global_cntx_obj( AOCL_UTIL_OPERATION_TYPE op ) +{ + return &global_util_cntx_t_list[op]; +} + +dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MC; +} + +dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NC; +} + +dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.KC; +} + +dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NR; +} + +dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MR; +} + +void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ) +{ + *rs = lcntx->pack_s.packa_rs; + *cs = lcntx->pack_s.packa_cs; +} + +void lpgemm_get_packb_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ) +{ + *rs = lcntx->pack_s.packb_rs; + *cs = lcntx->pack_s.packb_cs; +} + +void lpgemm_mod_block_size_s16 + ( + dim_t m, + dim_t n, + dim_t k, + dim_t* MC, + dim_t* NC, + dim_t* KC + ) +{ + const dim_t range[4] = {1024, 512, 256, 128}; + + if (n < *NC) + { + for (dim_t i = 0; i < 4; ++i) + { + if (n <= range[i]) + { + *NC = range[i]; + } + } + } + + if (k < *KC) + { + for (dim_t i = 0; i < 4; ++i) + { + if (k <= range[i]) + { + *KC = range[i]; + } + } + } +} diff --git a/addon/aocl_gemm/frame/lpgemm_config.h b/addon/aocl_gemm/config/lpgemm_config.h similarity index 75% rename from addon/aocl_gemm/frame/lpgemm_config.h rename to addon/aocl_gemm/config/lpgemm_config.h index 7e7f3bb2ad..91863e416a 100644 --- a/addon/aocl_gemm/frame/lpgemm_config.h +++ b/addon/aocl_gemm/config/lpgemm_config.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,10 +38,15 @@ #include "lpgemm_types.h" // equals to number of ops in enum AOCL_OPERATION_TYPE. -extern lpgemm_cntx_t lpgemm_global_cntx_t_list[4]; +extern lpgemm_cntx_t lpgemm_global_cntx_t_list[AOCL_OPERATION_TYPE_LEN]; +extern lpgemm_cntx_t lpgemm_util_global_cntx_t_list[AOCL_UTIL_OPERATION_TYPE_LEN]; void aocl_lpgemm_init_global_cntx(); +lpgemm_cntx_t* lpgemm_get_global_cntx_obj( AOCL_OPERATION_TYPE op ); + +lpgemm_util_cntx_t* lpgemm_util_get_global_cntx_obj( AOCL_UTIL_OPERATION_TYPE op ); + dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ); dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ); @@ -52,4 +57,18 @@ dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ); dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ); +void lpgemm_get_packa_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ); + +void lpgemm_get_packb_strides( lpgemm_cntx_t* lcntx, dim_t* rs, dim_t* cs ); + +void lpgemm_mod_block_size_s16 + ( + dim_t m, + dim_t n, + dim_t k, + dim_t* MC, + dim_t* NC, + dim_t* KC + ); + #endif //LPGEMM_CONFIG_H diff --git a/addon/aocl_gemm/config/lpgemm_func_map.h b/addon/aocl_gemm/config/lpgemm_func_map.h new file mode 100644 index 0000000000..864f84aef2 --- /dev/null +++ b/addon/aocl_gemm/config/lpgemm_func_map.h @@ -0,0 +1,159 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_FUNC_MAP_H +#define LPGEMM_FUNC_MAP_H + +// The XMACRO follows the format ID,FUNC_PTR: +// ID = One of the AOCL_OPERATION_TYPE enum. +// FUNC_PTR = Kernel associated with the AOCL_OPERATION_TYPE. +// It is to be noted that the main macros are defined for combinations +// of ISA types, and in case a kernel is not implemented for a particualr +// ISA combination, the reference kernel should be set as FUNC_PTR. +// TODO: Add reference kernels for BF16/VNNI kernels for ISA combinations +// that is not supported. + +// Genoa +#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI_BF16 \ + KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ + KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ + KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ + KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ + KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ + +#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI_BF16 \ + PAMACRO(U8S8S16OS16, NULL) \ + PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ + PAMACRO(S8S8S16OS16, NULL) \ + +#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI_BF16 \ + PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ + PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ + PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ + PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ + PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + +#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI_BF16 \ + UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ + UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ + UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ + +// Icelake +#define LPGEMM_KERN_FUNC_MAP_AVX512_VNNI \ + KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ + KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ + KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ + KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ + KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ + +#define LPGEMM_PACKA_FUNC_MAP_AVX512_VNNI \ + PAMACRO(U8S8S16OS16, NULL) \ + PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ + PAMACRO(S8S8S16OS16, NULL) \ + +#define LPGEMM_PACKB_FUNC_MAP_AVX512_VNNI \ + PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ + PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ + PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ + PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ + PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + +#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512_VNNI \ + UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ + UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ + UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ + +// Skylake +#define LPGEMM_KERN_FUNC_MAP_AVX512 \ + KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ + KMACRO(U8S8S32OS32, lpgemm_rowvar_u8s8s32o32_6x64) \ + KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_avx512_6x64m) \ + KMACRO(BF16BF16F32OF32, lpgemm_rowvar_bf16bf16f32of32_6x64) \ + KMACRO(S8S8S32OS32, lpgemm_rowvar_s8s8s32os32_6x64) \ + KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ + +#define LPGEMM_PACKA_FUNC_MAP_AVX512 \ + PAMACRO(U8S8S16OS16, NULL) \ + PAMACRO(U8S8S32OS32, packa_k64_u8s8s32o32) \ + PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(S8S8S32OS32, packa_k64_s8s8s32os32) \ + PAMACRO(S8S8S16OS16, NULL) \ + +#define LPGEMM_PACKB_FUNC_MAP_AVX512 \ + PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ + PBMACRO(U8S8S32OS32, packb_nr64_u8s8s32o32) \ + PBMACRO(BF16BF16F32OF32, packb_nr64_bf16bf16f32of32) \ + PBMACRO(S8S8S32OS32, packb_nr64_s8s8s32os32) \ + PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + +#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX512 \ + UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx512_kernel) \ + UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx512_kernel) \ + UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx512_kernel) \ + +// Milan, Haswell +#define LPGEMM_KERN_FUNC_MAP_AVX2 \ + KMACRO(U8S8S16OS16, lpgemm_rowvar_u8s8s16o16_6x32) \ + KMACRO(U8S8S32OS32, NULL) \ + KMACRO(F32F32F32OF32, lpgemm_rowvar_f32f32f32of32_6x16m) \ + KMACRO(BF16BF16F32OF32, NULL) \ + KMACRO(S8S8S32OS32, NULL) \ + KMACRO(S8S8S16OS16, lpgemm_rowvar_s8s8s16o16_6x32) \ + +#define LPGEMM_PACKA_FUNC_MAP_AVX2 \ + PAMACRO(U8S8S16OS16, NULL) \ + PAMACRO(U8S8S32OS32, NULL) \ + PAMACRO(BF16BF16F32OF32, NULL) \ + PAMACRO(S8S8S32OS32, NULL) \ + PAMACRO(S8S8S16OS16, NULL) \ + +#define LPGEMM_PACKB_FUNC_MAP_AVX2 \ + PBMACRO(U8S8S16OS16, packb_nr32_u8s8s16o16) \ + PBMACRO(U8S8S32OS32, NULL) \ + PBMACRO(BF16BF16F32OF32, NULL) \ + PBMACRO(S8S8S32OS32, NULL) \ + PBMACRO(S8S8S16OS16, packb_nr32_s8s8s16o16) \ + +#define LPGEMM_UTIL_KERN_FUNC_MAP_AVX2 \ + UMACRO(F32_GELU_TANH, lpgemm_util_f32_gelu_tanh_avx2_kernel) \ + UMACRO(F32_GELU_ERF, lpgemm_util_f32_gelu_erf_avx2_kernel) \ + UMACRO(F32_SOFTMAX, lpgemm_util_f32_softmax_avx2_kernel) \ + +#endif //LPGEMM_FUNC_MAP_H diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 5db523f987..1ece1db727 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,14 +40,36 @@ #include "lpgemm_thrinfo_utils.h" #include "lpgemm_config.h" +// Kernel function prototypes +typedef void (*lpgemm_rowvar_bf16) + ( + const dim_t, + const dim_t, + const dim_t, + const bfloat16*, + const dim_t, + const dim_t, + const dim_t, + const bfloat16*, + const dim_t, + const dim_t, + float*, + const dim_t, + const dim_t, + const float, + const float, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + // B should always be packed. LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { - dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); - dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); - dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); - dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; const int16_t* a_use = NULL; dim_t cs_a_use = cs_a; @@ -80,9 +102,22 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) dim_t k_updated = k; k_updated += (k_updated & 0x1); - // Is required to decide whether to apply post ops or not. + // To decide whether to apply post ops or not. bool is_last_k = FALSE; + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. thrinfo_t thread_jc; thrinfo_t thread_ic; @@ -102,7 +137,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) dim_t jc_cur_loop = jc; dim_t jc_cur_loop_rem = 0; - dim_t n_sub_updated; + dim_t n_sub_updated = 0; if ( mtag_b == REORDERED ) { @@ -121,45 +156,24 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) // Temp accumulaton buffer for C allocation. else if ( c_downscale == TRUE ) { - mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start ); - - lpgemm_alloc_mem_panel - ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, - &mem_scale_c, rntm - ); + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if ( k > KC ) + { + mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start ); - temp_scal_c_buffer_bf16 = bli_mem_buffer( &mem_scale_c ); + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); - c_use_jc = ( float* )temp_scal_c_buffer_bf16; + temp_scal_c_buffer_bf16 = bli_mem_buffer( &mem_scale_c ); - if ( beta != 0 ) - { - dim_t i_temp = 0; - dim_t j_temp = 0; - int32_t temp_conv_buf = 0; - // Upscale out C to temporary C matrix. - for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) - { - j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) - { - // Implemented with the idea sizeof(float)=4. - temp_conv_buf = 0; - temp_conv_buf = *( ( int16_t* )( ( bfloat16* )c + - ( rs_c * i_dscale ) + j_dscale ) ); - - // Add 16 bits in the fractional part. - temp_conv_buf = temp_conv_buf << 16; - - // Store the bytes in float format. - *( temp_scal_c_buffer_bf16 + ( nc0 * i_temp ) + j_temp ) - = *( ( float* )( &temp_conv_buf ) ); - - j_temp++; - } - i_temp++; - } + c_use_jc = ( float* )temp_scal_c_buffer_bf16; } // The temp c buffer stride is modified as opposed to original C matrix. @@ -171,6 +185,13 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) float beta0 = ( pc == 0 ) ? beta : 1; dim_t kc0 = bli_min( ( k - pc ), KC ); + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; + // kc0 needs to be a multiple of 2 so that it can be // used with dpbf16_ps instruction. Padding is added in // cases this condition is not satisfied, and therefore @@ -179,8 +200,6 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) dim_t kc0_updated = kc0; kc0_updated += (kc0_updated & 0x1); - is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); - if ( mtag_b == PACK ) { // Pack B chunks are based on jc work id. @@ -235,8 +254,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) if ( ( jc_packb_end > jc_packb_start ) && ( jc_packb_start < ( jc + nc0 ) ) ) { -#ifdef BLIS_KERNELS_ZEN4 - packb_nr64_bf16bf16f32of32 + ( ( packb_bf16 )lcntx->packb_fun_ptr ) ( pack_b_buffer_bf16 + ( jc_packb_start * kc0_updated ), ( b + ( rs_b * pc ) + ( cs_b * jc ) + @@ -244,11 +262,10 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); -#endif } else { - get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } // All threads in work group should wait till B matrix packing @@ -271,7 +288,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0_updated ); - get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) @@ -304,30 +321,21 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { dim_t nr0 = bli_min( ( nc0 - jr ), NR ); -#ifdef BLIS_KERNELS_ZEN4 + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + // Reorder/Packed B, Reorder/Packed/Unpacked A call. - lpgemm_rowvar_bf16bf16f32of32_6x64 + ( ( lpgemm_rowvar_bf16 )lcntx->kern_fun_ptr ) ( mc0, nr0, kc0, a_use, rs_a, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, ( c_use_ic + jr ), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + post_op_list, post_ops_attr ); -#else - // Silence compiler warnings. - ( void )b_use; - ( void )a_block_stride; - ( void )rs_c_downscale; - ( void )is_last_k; - ( void )c_use_ic; - ( void )a_use; - ( void )beta0; - ( void )nr0; - ( void )mc0; - ( void )cs_a_use; -#endif } } } diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c index 5bb217facd..b90d339664 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -1,180 +1,169 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#include "lpgemm_utils.h" -#include "lpgemm_reorder_bf16.h" -#include "lpgemm_packb_bf16.h" -#include "lpgemm_config.h" -#include "aocl_bf16_type.h" - -void reorderb_nr64_bf16bf16f32of32 - ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder - ) -{ - dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); - dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); - - // Extracting the matrix properties from the lpgemm object - dim_t rs_b = b->rs; - dim_t n = b->width; - dim_t k = b->length; - - dim_t rs_b_reorder; - dim_t cs_b_reorder; - - // k needs to be a multiple of 2 so that it can be used with dpbf - // instruction. Padding is added in cases this condition is not - // satisfied, and therefore the k offset used for packed/reordered - // buffer needs to be updated. - dim_t k_updated = k; - k_updated += (k_updated & 0x1); - - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_g; - bli_rntm_init_from_global( &rntm_g ); - - dim_t n_threads = bli_rntm_num_threads( &rntm_g ); - n_threads = ( n_threads > 0 ) ? n_threads : 1; - -#ifdef BLIS_ENABLE_OPENMP - _Pragma( "omp parallel num_threads(n_threads)" ) - { - // Initialise a local thrinfo obj for work split across threads. - thrinfo_t thread_jc; - bli_thrinfo_set_n_way( n_threads, &thread_jc ); - bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); -#else - { - // Initialise a local thrinfo obj for work split across threads. - thrinfo_t thread_jc; - bli_thrinfo_set_n_way( 1, &thread_jc ); - bli_thrinfo_set_work_id( 0, &thread_jc ); -#endif - // Compute the JC loop thread range for the current thread. - dim_t jc_start, jc_end; - bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); - - for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) - { - dim_t nc0 = bli_min( ( jc_end - jc ), NC ); - - dim_t jc_cur_loop = jc; - dim_t jc_cur_loop_rem = 0; - dim_t n_sub_updated; - - get_B_panel_reordered_start_offset_width - ( - jc, n, NC, 16, - &jc_cur_loop, &jc_cur_loop_rem, - &nc0, &n_sub_updated - ); - - for ( dim_t pc = 0; pc < k; pc += KC ) - { - dim_t kc0 = bli_min( ( k - pc ), KC ); - - // k needs to be a multiple of 2 so that it can be used with dpbf - // instruction. Padding is added in cases this condition is not - // satisfied, and therefore the k offset used for packed/reordered - // buffer needs to be updated. - dim_t kc0_updated = kc0; - kc0_updated += (kc0_updated & 0x1); - - // The offsets are calculated in such a way that it resembles - // the reorder buffer traversal in single threaded reordering. - // The panel boundaries (KCxNC) remain as it is accessed in - // single thread, and as a consequence a thread with jc_start - // inside the panel cannot consider NC range for reorder. It - // has to work with NC' < NC, and the offset is calulated using - // prev NC panels spanning k dim + cur NC panel spaning pc loop - // cur iteration + (NC - NC') spanning current kc0 (<= KC). - // - //Eg: Consider the following reordered buffer diagram: - // t1 t2 - // | | - // | |..NC..| - // | | | - // |.NC. |.NC. |NC'|NC" - // pc=0-+-----+-----+---+--+ - // KC| | | | | - // | 1 | 3 | 5 | - // pc=KC-+-----+-----+---st-+ - // KC| | | | | - // | 2 | 4 | 6 | 7| - // pc=k=2KC-+-----+-----+---+--+ - // |jc=0 |jc=NC|jc=2NC| - // - // The numbers 1,2..6,7 denotes the order in which reordered - // KCxNC blocks are stored in memory, ie: block 1 followed by 2 - // followed by 3, etc. Given two threads t1 and t2, and t2 needs - // to acces point st in the reorder buffer to write the data: - // The offset calulation logic will be: - // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, - // n_sub_updated = NC, k = 2KC, kc0_updated = KC - // - // st = ( jc_cur_loop * k ) - // + ( n_sub_updated * pc ) - // + ( NC' * kc0_updated) -#ifdef BLIS_KERNELS_ZEN4 - // B should always be packed. - packb_nr64_bf16bf16f32of32 - ( - ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + - ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + - ( jc_cur_loop_rem * kc0_updated ) ), - ( ( ( bfloat16* )b->storage.aligned_buffer ) + - ( rs_b * pc ) + jc ), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder - ); -#else - // Silence compiler warnings. - rs_b_reorder = 0; - cs_b_reorder = 0; - ( void )rs_b; -#endif - } - - adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); - } - } - - b_reorder->rs = rs_b_reorder; - b_reorder->cs = cs_b_reorder; - b_reorder->mtag = REORDERED; -} +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_config.h" +#include "aocl_bf16_type.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + dim_t n_threads = bli_rntm_num_threads( rntm ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + ( ( packb_bf16 )lcntx->packb_fun_ptr ) + ( + ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( bfloat16* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h index c1b83c1b75..42c8cb9ef6 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h @@ -1,46 +1,48 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef LPGEMM_REORDER_BF16_H -#define LPGEMM_REORDER_BF16_H - -#include "lpgemm_types.h" - -void reorderb_nr64_bf16bf16f32of32 - ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder - ); - -#endif // LPGEMM_REORDER_H +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_BF16_H +#define LPGEMM_REORDER_BF16_H + +#include "lpgemm_types.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t * b, + lpgemm_obj_t * b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + +#endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 6242ceebe8..1864d78330 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,6 +37,29 @@ #include "lpgemm_types.h" #include "lpgemm_utils.h" #include "lpgemm_thrinfo_utils.h" +#include "lpgemm_kernels.h" + +// Kernel function prototypes +typedef void (*lpgemm_rowvar_f32) + ( + const dim_t, + const dim_t, + const dim_t, + const float*, + const dim_t, + const dim_t, + const dim_t, + const float*, + const dim_t, + const dim_t, + float*, + const dim_t, + const dim_t, + const float, + const float, + lpgemm_post_op*, + lpgemm_post_op_attr + ); void lpgemm_pack_a_f32f32f32of32 ( @@ -51,197 +74,338 @@ void lpgemm_pack_a_f32f32f32of32 cntx_t* cntx ); +void lpgemm_pack_b_f32f32f32of32 + ( + const float* input_buf_addr_b, + float* reorder_buf_addr_b, + const dim_t n, + const dim_t k, + const dim_t rs_b, + const dim_t cs_b, + const dim_t ps_p, + const dim_t NR, + cntx_t* cntx + ); + LPGEMM_5LOOP(float,float,float,f32f32f32of32) { - // Query the global cntx. - cntx_t* cntx = bli_gks_query_cntx(); - - num_t dt = BLIS_FLOAT; - - // Query the context for various blocksizes. - const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); - const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); - const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); - const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); - - // Strides are updated based on matrix packing/reordering. - const float* a_use = NULL; - dim_t rs_a_use = rs_a; - dim_t cs_a_use = cs_a; - - const float* b_use = NULL; - dim_t rs_b_use = rs_b; - dim_t cs_b_use = cs_b; - - float* c_use_jc = NULL; - float* c_use_ic = NULL; - - // Only supporting row major with unit column strided C for now. - const dim_t cs_c_use = 1; - - /* Compute partitioning step values for each matrix of each loop. */ - inc_t ps_a_use; - inc_t ps_b_use; - auxinfo_t aux; - - // Check if packing of A is required. - bool should_pack_A = bli_rntm_pack_a( rntm ); - - // Pack buffer for A. - float* pack_a_buffer_f32f32f32of32; - mem_t mem_a = BLIS_MEM_INITIALIZER; - siz_t mem_a_size_req = 0; - - float one_local = *PASTEMAC(s,1); - - trans_t transc = BLIS_NO_TRANSPOSE; - conj_t conjc = bli_extract_conj( transc ); - - // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. - thrinfo_t thread_jc; - thrinfo_t thread_ic; - - lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); - - // Compute the JC loop thread range for the current thread. - dim_t jc_start, jc_end; - bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); - - for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) - { - dim_t nc0 = bli_min( ( jc_end - jc ), NC ); - c_use_jc = c + jc; - - dim_t jc_cur_loop = jc; - dim_t jc_cur_loop_rem = 0; - dim_t n_sub_updated; - - if ( mtag_b == REORDERED ) - { - get_B_panel_reordered_start_offset_width - ( - jc, n, NC, NR, - &jc_cur_loop, &jc_cur_loop_rem, - &nc0, &n_sub_updated - ); - } - - for ( dim_t pc = 0; pc < k; pc += KC ) - { - float beta0 = ( pc == 0 ) ? beta : one_local; - dim_t kc0 = bli_min( ( k - pc ), KC ); - - if ( mtag_b == REORDERED ) - { - // In multi-threaded scenarios, an extra offset into a given - // packed B panel is required, since the jc loop split can - // result in per thread start offset inside the panel, instead - // of panel boundaries. - b_use = b + ( jc_cur_loop * k ) + - ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); - - rs_b_use = NR; - cs_b_use = 1; - ps_b_use = kc0; - } - else - { - b_use = b + ( pc * rs_b ) + ( jc * cs_b ); - ps_b_use = 1; - } - - dim_t ic_start, ic_end; - bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); - - for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) - { - dim_t mc0 = bli_min( ( ic_end - ic ), MC ); - c_use_ic = c_use_jc + ( rs_c * ic ); - - if ( mtag_a == REORDERED ) - { - // Extra space since packing does width in multiples of MR. - const dim_t m_updated = ( ( m + MR - 1 ) / MR ) * MR; - a_use = a + ( pc * m_updated ) + ( kc0 * ic ); - - rs_a_use = 1; - cs_a_use = MR; - ps_a_use = MR * kc0; - } - else if ( should_pack_A == TRUE ) - { - // Extra space since packing does width in multiples of MR. - const dim_t mc0_updated = ( ( mc0 + MR - 1 ) / MR ) * MR; - mem_a_size_req = sizeof( float ) * mc0_updated * kc0; - - lpgemm_alloc_mem_panel - ( - mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, - &mem_a, rntm - ); - pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a ); - - rs_a_use = 1; - cs_a_use = MR; - ps_a_use = MR * kc0; - - lpgemm_pack_a_f32f32f32of32 - ( - ( a + ( rs_a * ic ) + pc ), - pack_a_buffer_f32f32f32of32, - mc0, kc0, - rs_a, cs_a, ps_a_use, MR, - cntx - ); - - a_use = pack_a_buffer_f32f32f32of32; - } - else - { - a_use = a + ( rs_a * ic ) + pc; - ps_a_use = MR * rs_a; - } - - // Embed the panel stride of A within the auxinfo_t object. The - // millikernel will query and use this to iterate through - // micropanels of A (if needed). + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + /*ToDo: Based on context kernel 6x64m or 6x16m will be picked here */ + + // Strides are updated based on matrix packing/reordering. + const float* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + const float* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + + dim_t rs_c_downscale = rs_c; + + // Only supporting row major with unit column strided C for now. + const dim_t cs_c_use = 1; + + /* Compute partitioning step values for each matrix of each loop. */ + inc_t ps_a_use; + inc_t ps_b_use; + auxinfo_t aux; + + // Check if packing of A is required. + bool should_pack_A = bli_rntm_pack_a( rntm ); + + // Pack buffer for A. + float* pack_a_buffer_f32f32f32of32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + // Check if packing of A is required. + bool should_pack_B = bli_rntm_pack_b( rntm ); + + // Pack buffer for B. + float* pack_b_buffer_f32f32f32of32; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + + float one_local = *PASTEMAC(s,1); + + // To decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + float beta0 = ( pc == 0 ) ? beta : one_local; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; + + if ( ( mtag_b == PACK ) && ( should_pack_B == TRUE ) ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n( nc0, NR ); + mem_b_size_req = sizeof( float ) * nc0_updated * kc0; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = bli_mem_buffer(&mem_b); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_f32f32f32of32 = + ( float* ) thread->comm[jc_work_id].sent_object; + // Set the strides for pack buffer. + rs_b_use = NR; + cs_b_use = 1; + ps_b_use = kc0; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { + lpgemm_pack_b_f32f32f32of32 + ( + ( b + ( rs_b * pc ) + ( cs_b * jc ) + ( cs_b * jc_packb_start ) ), + pack_b_buffer_f32f32f32of32 + ( jc_packb_start * kc0 ), + ( jc_packb_end - jc_packb_start ), kc0, + rs_b, cs_b, ( NR * ps_b_use ), NR, + cntx + ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_f32f32f32of32; + } + else if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k ) + + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); + + rs_b_use = NR; + cs_b_use = 1; + ps_b_use = kc0; + } + else + { + b_use = b + ( pc * rs_b ) + ( jc * cs_b ); + ps_b_use = 1; + } + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + c_use_ic = c_use_jc + ( rs_c * ic ); + + if ( mtag_a == REORDERED ) + { + // Extra space since packing does width in multiples of MR. + const dim_t m_updated = ( ( m + MR - 1 ) / MR ) * MR; + a_use = a + ( pc * m_updated ) + ( kc0 * ic ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + } + else if ( should_pack_A == TRUE ) + { + // Extra space since packing does width in multiples of MR. + const dim_t mc0_updated = ( ( mc0 + MR - 1 ) / MR ) * MR; + mem_a_size_req = sizeof( float ) * mc0_updated * kc0; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + + lpgemm_pack_a_f32f32f32of32 + ( + ( a + ( rs_a * ic ) + pc ), + pack_a_buffer_f32f32f32of32, + mc0, kc0, + rs_a, cs_a, ps_a_use, MR, + cntx + ); + + a_use = pack_a_buffer_f32f32f32of32; + } + else + { + a_use = a + ( rs_a * ic ) + pc; + ps_a_use = MR * rs_a; + } + + // Embed the panel stride of A within the auxinfo_t object. The + // millikernel will query and use this to iterate through + // micropanels of A (if needed). bli_auxinfo_set_ps_a( ps_a_use, &aux ); - for ( dim_t jr = 0; jr < nc0; jr += NR ) - { - dim_t nr0 = bli_min( ( nc0 - jr ), NR ); - - // Reordered/unpacked B, reordered/unpacked A. - bli_sgemmsup_rv_zen_asm_6x16m - ( - conjc, - conjc, - mc0, nr0, kc0, - &alpha, - ( float* )a_use, rs_a_use, cs_a_use, - ( float* )( b_use + ( jr * ps_b_use ) ), rs_b_use, cs_b_use, - &beta0, - ( c_use_ic + jr ), rs_c, cs_c_use, - &aux, cntx - ); - } - } - } - if ( mtag_b == REORDERED ) - { - adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); - } - } - - // Release pack buffers. - if ( should_pack_A == TRUE ) - { - if ( bli_mem_is_alloc( &mem_a ) ) - { - bli_membrk_release( rntm, &mem_a ); - } - } + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + + // Reordered/unpacked B, reordered/unpacked A. + ( ( lpgemm_rowvar_f32 )lcntx->kern_fun_ptr ) + ( + mc0, nr0, kc0, + ( float* )a_use, rs_a_use, cs_a_use, ps_a_use, + ( float* )( b_use + ( jr * ps_b_use ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c, cs_c_use, + alpha , beta0, + post_op_list, post_ops_attr + ); + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } + } + if ( should_pack_A == TRUE ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } } void lpgemm_pack_a_f32f32f32of32 @@ -257,44 +421,99 @@ void lpgemm_pack_a_f32f32f32of32 cntx_t* cntx ) { - float one_local = *PASTEMAC(s,1); - float* restrict kappa_cast = &one_local; - - // Set the schema to "column stored row panels" to indicate packing to conventional - // column-stored row panels. - pack_t schema = BLIS_PACKED_ROW_PANELS; - trans_t transc = BLIS_NO_TRANSPOSE; - conj_t conjc = bli_extract_conj( transc ); - - // Compute the total number of iterations we'll need. - dim_t m_iter = ( m + MR - 1 ) / MR; - - inc_t cs_p = MR; - - float* p_temp = reorder_buf_addr_a; - dim_t ir, it; - // Iterate over every logical micropanel in the source matrix. - for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 ) - { - dim_t panel_dim_i = bli_min( MR, m - ir ); - - const float* a_use = input_buf_addr_a + ( ir * rs_a ); - float* p_use = p_temp; - - PASTEMAC(s,packm_cxk) - ( - conjc, - schema, - panel_dim_i, - MR, - k, - k, - kappa_cast, - ( float* )a_use, rs_a, cs_a, - p_use, cs_p, - cntx - ); - - p_temp += ps_p; - } + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "column stored row panels" to indicate packing to conventional + // column-stored row panels. + pack_t schema = BLIS_PACKED_ROW_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + // Compute the total number of iterations we'll need. + dim_t m_iter = ( m + MR - 1 ) / MR; + + inc_t cs_p = MR; + + float* p_temp = reorder_buf_addr_a; + + dim_t ir, it; + // Iterate over every logical micropanel in the source matrix. + for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 ) + { + dim_t panel_dim_i = bli_min( MR, m - ir ); + + const float* a_use = input_buf_addr_a + ( ir * rs_a ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + MR, + k, + k, + kappa_cast, + ( float* )a_use, rs_a, cs_a, + p_use, cs_p, + cntx + ); + + p_temp += ps_p; + } +} + +void lpgemm_pack_b_f32f32f32of32 + ( + const float* input_buf_addr_b, + float* reorder_buf_addr_b, + const dim_t n, + const dim_t k, + const dim_t rs_b, + const dim_t cs_b, + const dim_t ps_p, + const dim_t NR, + cntx_t* cntx + ) +{ + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "row stored column panels" to indicate packing to + // conventional row-stored column panels. + pack_t schema = BLIS_PACKED_COL_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + // Compute the total number of iterations we'll need. + dim_t n_iter = ( n + NR - 1 ) / NR; + + inc_t rs_p = NR; + + float* p_temp = reorder_buf_addr_b; + + dim_t jr, it; + // Iterate over every logical micropanel in the source matrix. + for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 ) + { + dim_t panel_dim_i = bli_min( NR, n - jr ); + + const float* b_use = input_buf_addr_b + ( jr * cs_b ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + NR, + k, + k, + kappa_cast, + ( float* )b_use, cs_b, rs_b, + p_use, rs_p, + cntx + ); + + p_temp += ps_p; + } } diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index 45328669de..62fc678faa 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,10 +56,11 @@ void lpgemm_rowvar_ ## LP_SFX \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ - C_type alpha, \ - C_type beta, \ + const C_type alpha, \ + const C_type beta, \ rntm_t* rntm, \ lpgemm_thrinfo_t* thread, \ + lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ bool c_downscale \ ) \ @@ -68,4 +69,6 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16); LPGEMM_5LOOP(float,float,float,f32f32f32of32); LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32); +LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32); +LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16); #endif // LPGEMM_5LOOP_INTF_H diff --git a/addon/aocl_gemm/frame/lpgemm_config.c b/addon/aocl_gemm/frame/lpgemm_config.c deleted file mode 100644 index 901ec087d2..0000000000 --- a/addon/aocl_gemm/frame/lpgemm_config.c +++ /dev/null @@ -1,90 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#include "lpgemm_config.h" - -lpgemm_cntx_t global_cntx_t_list[4]; //Only one op type supported now. - -BLIS_INLINE void lpgemm_set_block_sizes_global_cntx - ( - AOCL_OPERATION_TYPE op_type, - dim_t MC, - dim_t NC, - dim_t KC, - dim_t NR, - dim_t MR - ) -{ - global_cntx_t_list[op_type].blksz.MC = MC; - global_cntx_t_list[op_type].blksz.NC = NC; - global_cntx_t_list[op_type].blksz.KC = KC; - global_cntx_t_list[op_type].blksz.NR = NR; - global_cntx_t_list[op_type].blksz.MR = MR; -} - -// Sets default block sizes for lpgemm. Currently only u8s8s32 supported. -// Thread safety is not considered now since the block sizes are not expected -// to be configurable from application. -void aocl_lpgemm_init_global_cntx() -{ - lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); - lpgemm_set_block_sizes_global_cntx( U8S8S16OS16, 144, 1024, 1024, 32, 6 ); - lpgemm_set_block_sizes_global_cntx( BF16BF16F32OF32, 144, 1024, 2048, 64, 6 ); -} - -dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) -{ - return global_cntx_t_list[op_type].blksz.MC; -} - -dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ) -{ - return global_cntx_t_list[op_type].blksz.NC; -} - -dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ) -{ - return global_cntx_t_list[op_type].blksz.KC; -} - -dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ) -{ - return global_cntx_t_list[op_type].blksz.NR; -} - -dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ) -{ - return global_cntx_t_list[op_type].blksz.MR; -} diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c index 63fb25765f..fffe14c0f8 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,7 @@ void lpgemm_translate_to_post_ops_list return; //Error, seq length exceeds max post ops permitted. } + dim_t e_i = 0; //Multiple eltwise supported. for ( dim_t i = 0; i < post_op_unparsed->seq_length; ++i ) { // Dispatcher code @@ -103,7 +104,7 @@ void lpgemm_translate_to_post_ops_list { LPGEMM_POST_OP_CODE tmp_code = POST_OPS_DISABLE; // Eltwise algo dispatcher. - switch ( post_op_unparsed->eltwise.algo.algo_type ) + switch ( ( post_op_unparsed->eltwise + e_i )->algo.algo_type ) { case RELU: tmp_code = POST_OPS_RELU; @@ -111,6 +112,15 @@ void lpgemm_translate_to_post_ops_list case PRELU: tmp_code = POST_OPS_RELU_SCALE; break; + case GELU_TANH: + tmp_code = POST_OPS_GELU_TANH; + break; + case GELU_ERF: + tmp_code = POST_OPS_GELU_ERF; + break; + case CLIP: + tmp_code = POST_OPS_CLIP; + break; default: break; } @@ -118,11 +128,12 @@ void lpgemm_translate_to_post_ops_list ( ( post_op_list + i ), tmp_code, NULL, - post_op_unparsed->eltwise.algo.alpha, - post_op_unparsed->eltwise.algo.beta, - post_op_unparsed->eltwise.scale_factor, - post_op_unparsed->eltwise.is_power_of_2 + ( post_op_unparsed->eltwise + e_i )->algo.alpha, + ( post_op_unparsed->eltwise + e_i )->algo.beta, + ( post_op_unparsed->eltwise + e_i )->scale_factor, + ( post_op_unparsed->eltwise + e_i )->is_power_of_2 ); + e_i += 1; } break; case BIAS: diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index 3932daf602..7509e57a39 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,8 +41,11 @@ typedef enum POST_OPS_BIAS = 1, POST_OPS_RELU = 2, POST_OPS_RELU_SCALE = 3, - POST_OPS_DOWNSCALE = 4, - POST_OPS_SUM = 5, + POST_OPS_GELU_TANH = 4, + POST_OPS_GELU_ERF = 5, + POST_OPS_CLIP = 6, + POST_OPS_DOWNSCALE = 7, + POST_OPS_SUM = 8, } LPGEMM_POST_OP_CODE; // Used as an internal structure. @@ -57,6 +60,21 @@ typedef struct lpgemm_post_op_t struct lpgemm_post_op_t* next; } lpgemm_post_op; +// Used as an internal structure. +typedef struct lpgemm_post_op_attr_t +{ + dim_t post_op_c_i; + dim_t post_op_c_j; + dim_t rs_c_downscale; + dim_t cs_c_downscale; + void* buf_downscale; + bool is_first_k; + bool is_last_k; + dim_t b_sum_offset; + int32_t* b_col_sum_vec; + int16_t* b_col_sum_vec_s16; +} lpgemm_post_op_attr; + void lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, @@ -66,7 +84,7 @@ void lpgemm_translate_to_post_ops_list ); #define POST_OP_LABEL_LASTK_SAFE_JUMP \ - if ( ( is_last_k == TRUE ) && ( post_ops_list_temp != NULL ) ) \ + if ( ( post_ops_attr.is_last_k == TRUE ) && ( post_ops_list_temp != NULL ) ) \ { \ goto *post_ops_labels[post_ops_list_temp->op_code]; \ } \ diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h index aebd485d0d..b700c03878 100644 --- a/addon/aocl_gemm/frame/lpgemm_types.h +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -47,9 +47,20 @@ typedef enum { U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C - F16F16F16OF16 = 2, // float16 - A, float16 - B, float16 - C - BF16BF16F32OF32 = 3 // bf16 - A, bf16 - B, float - C + F32F32F32OF32 = 2, // float - A, float - B, float - C + BF16BF16F32OF32 = 3, // bf16 - A, bf16 - B, float - C + S8S8S32OS32 = 4, // int8_t - A, int8_t - B, int32_t - C + S8S8S16OS16 = 5 // int8_t - A, int8_t - B, int16_t - C } AOCL_OPERATION_TYPE; +#define AOCL_OPERATION_TYPE_LEN 6 + +typedef enum +{ + F32_GELU_TANH = 0, + F32_GELU_ERF = 1, + F32_SOFTMAX = 2 +} AOCL_UTIL_OPERATION_TYPE; +#define AOCL_UTIL_OPERATION_TYPE_LEN 3 typedef enum { @@ -100,11 +111,28 @@ typedef struct dim_t MR; } lpgemm_block_size_t; +typedef struct +{ + dim_t packa_rs; + dim_t packa_cs; + dim_t packb_rs; + dim_t packb_cs; +} lpgemm_pack_strides_t; + typedef struct { lpgemm_block_size_t blksz; + void_fp kern_fun_ptr; + void_fp packa_fun_ptr; + void_fp packb_fun_ptr; + lpgemm_pack_strides_t pack_s; } lpgemm_cntx_t; +typedef struct +{ + void_fp kern_fun_ptr; +} lpgemm_util_cntx_t; + typedef struct { dim_t n_threads; diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c new file mode 100644 index 0000000000..474014d5df --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.c @@ -0,0 +1,187 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_reorder_s8s16.h" +#include "lpgemm_packb_s8s16.h" +#include "lpgemm_config.h" + +void aocl_reorderb_nr32_s8s8s16o16 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC); + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t k_updated = k; + + // Making multiple of 2 to suit k in vpmaddubsw + k_updated += (k_updated & 0x1); + + dim_t n_updated = make_multiple_of_n( n, 16 ); + + dim_t n_threads = bli_rntm_num_threads( rntm ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + + // To access the last row of B matrix - Column sum of B matrix + int16_t* pack_b_column_sum = ( int16_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated )); + for (int idx = 0; idx < n_updated; idx++ ) + { + *( pack_b_column_sum + idx ) = 0; + } + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 2 so that it can be used with + // vmaddubsw instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 2 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + ( ( packb_s16_s8 )lcntx->packb_fun_ptr ) + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + pack_b_column_sum + jc, + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + // for (int i =0; i< k_updated; i++) + // { + // for (int j=0; j< n_updated; j++) + // { + // printf(" %d ", *( int8_t* )(b->storage.aligned_buffer + i*n_updated + j )); + // } + // printf(" \n "); + // } + // for (int i =0; i< n_updated; i++) + // printf(" %d ", *(pack_b_column_sum + i)); + + // Changing the packed matrix properties in the packed matrix object + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.h b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.h new file mode 100644 index 0000000000..8a87474ad4 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_reorder_s8s16.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#ifndef LPGEMM_REORDER_S8S16_H +#define LPGEMM_REORDER_S8S16_H + +#include "lpgemm_types.h" + +void aocl_reorderb_nr32_s8s8s16o16 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + +#endif // LPGEMM_REORDER_S8S16_H diff --git a/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c new file mode 100644 index 0000000000..86ee194eb5 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s16/lpgemm_s8s8s16.c @@ -0,0 +1,402 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packb_s8s16.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_config.h" +#include "lpgemm_thrinfo_utils.h" + +// Kernel function prototypes +typedef void (*lpgemm_rowvar_s16_s8) + ( + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + int16_t*, + const dim_t, + const dim_t, + const int16_t, + const int16_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + +// B should always be packed. +LPGEMM_5LOOP(int8_t,int8_t,int16_t,s8s8s16o16) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + const dim_t NR = lcntx->blksz.NR; + const dim_t MR = lcntx->blksz.MR; + + lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC); + + if (mtag_b == UNPACKED) + { + // Error: can only work with packed B now. + return; + } + + const int8_t *b_use; + const int8_t *a_use; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int16_t *c_use_jc = NULL; + int16_t *c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for B. + int8_t *pack_b_buffer_s8s8s16o16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + dim_t packb_min_NR = 16; + siz_t mem_b_size_req = 0; + + // Temporary buffer for C accumulation when downscaling is required. + int16_t* temp_scal_c_buffer_s8s8s16o16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // Making multiple of 2 to suit k in vpmaddubsw + dim_t k_updated = make_multiple_of_n( k, 2 ); + + // Making multiple of 16 + dim_t n_updated = make_multiple_of_n( n, 16 ); + + // To decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if ( k > KC ) + { + mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_s8s8s16o16 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int16_t* )temp_scal_c_buffer_s8s8s16o16; + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + int16_t* pack_b_column_sum = NULL; + + for (dim_t pc = 0; pc < k; pc += KC) + { + int16_t beta0 = (pc == 0) ? beta : 1; + dim_t kc0 = bli_min((k - pc), KC); + + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; + + // kc0 needs to be a multiple of 2 so that it can be + // used with vpmaddubsw instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n(kc0, 2); + + if (mtag_b == PACK) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id(&thread_jc); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR); + + if (bli_thread_am_ochief(&thread_ic)) + { + mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated + ( nc0_updated * sizeof( int16_t ) ); + + lpgemm_alloc_mem_panel( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer(&mem_b); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_s8s8s16o16 = + (int8_t *)thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + if ( pc == 0) + { + pack_b_column_sum = ( int16_t* )( pack_b_buffer_s8s8s16o16 + ( sizeof( int8_t ) * nc0_updated * kc0_updated ) ); + } + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ((jc_packb_end > jc_packb_start) && + (jc_packb_start < (jc + nc0))) + { + if ( pc == 0 ) + { + for (int idx = jc_packb_start; idx < jc_packb_end; idx++ ) + { + *( pack_b_column_sum + idx ) = 0; + } + } + + ( ( packb_s16_s8 )lcntx->packb_fun_ptr ) + ( + pack_b_buffer_s8s8s16o16 + + (jc_packb_start * kc0_updated), + pack_b_column_sum + ( cs_b * jc_packb_start ), + (b + (rs_b * pc) + (cs_b * jc) + + (cs_b * jc_packb_start)), + rs_b, + (jc_packb_end - jc_packb_start), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id] + ); + + b_use = pack_b_buffer_s8s8s16o16; + post_ops_attr.b_col_sum_vec_s16 = pack_b_column_sum; + } + else if (mtag_b == REORDERED) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + (jc_cur_loop * k_updated) + + (n_sub_updated * pc) + + (jc_cur_loop_rem * kc0_updated); + + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + + post_ops_attr.b_col_sum_vec_s16 = ( ( int16_t* )( b + ( k_updated * n_updated ) ) ) + jc; + } + else + { + // Unpacked B not supported. + return; + } + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } + + a_use = a + (rs_a * ic) + (cs_a * pc); + cs_a_use = 1; + + dim_t a_block_stride = rs_a; + + post_ops_attr.b_sum_offset = 0; + + for (dim_t jr = 0; jr < nc0; jr += NR) + { + dim_t nr0 = bli_min((nc0 - jr), NR); + + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + + // Calls for reorder B + ( ( lpgemm_rowvar_s16_s8 )lcntx->kern_fun_ptr ) + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, + (c_use_ic + jr), rs_c_use, 1, + alpha, beta0, + post_op_list, post_ops_attr + ); + post_ops_attr.b_sum_offset += NR; + } + } + } + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } + + // Release pack buffers. + if (mtag_b == PACK) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_jc), + &thread->comm[bli_thread_work_id(&thread_jc)]); + + if (bli_thread_am_ochief(&thread_ic)) + { + if (bli_mem_is_alloc(&mem_b)) + { + bli_membrk_release(rntm, &mem_b); + } + } + } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } +} diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c new file mode 100644 index 0000000000..ece6c48762 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.c @@ -0,0 +1,220 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_reorder_s8.h" +#include "lpgemm_packa_s8.h" +#include "lpgemm_packb_s8.h" +#include "lpgemm_config.h" + +void reorderb_nr64_s8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; + + dim_t rs_b = b->rs; + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t n = b->width; + dim_t k = b->length; + + // k needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + dim_t n_updated = make_multiple_of_n( n, 16 ); + + dim_t n_threads = bli_rntm_num_threads( rntm ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + + int32_t* pack_b_column_sum = ( int32_t* ) ( b_reorder->storage.aligned_buffer + ( sizeof( int8_t ) * n_updated * k_updated )); + for ( dim_t idx = 0; idx < n_updated; idx++ ) + { + *( pack_b_column_sum + idx ) = 0; + } + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, get_packb_s8s8s32o32_min_NR(), + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + ( ( packb_s32_s8 )lcntx->packb_fun_ptr ) + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + pack_b_column_sum + jc, + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} + +void reordera_mr6_s8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ) +{ + dim_t MC = lcntx->blksz.MC; + dim_t KC = lcntx->blksz.KC; + + dim_t rs_a = a->rs; + dim_t rs_a_reorder; + dim_t cs_a_reorder; + + dim_t k = a->width; + dim_t m = a->length; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + for ( dim_t ic = 0; ic < m; ic += MC ) + { + dim_t mc0 = bli_min( ( m - ic ), MC ); + + ( ( packa_s32_s8 )lcntx->packa_fun_ptr ) + ( + ( ( ( int8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + + ( ic * kc0_updated ) ), + ( ( ( int8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), + rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder + ); + } + } + + a_reorder->rs = rs_a_reorder; + a_reorder->cs = cs_a_reorder; + a_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.h b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.h new file mode 100644 index 0000000000..62bbfdeb64 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_reorder_s8.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_H_S8 +#define LPGEMM_REORDER_H_S8 + +#include "lpgemm_types.h" + +void reorderb_nr64_s8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + +void reordera_mr6_s8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx + ); + +#endif //LPGEMM_REORDER_H_S8 + diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c new file mode 100644 index 0000000000..98b8081b51 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_s8s8s32.c @@ -0,0 +1,447 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packa_s8.h" +#include "lpgemm_packb_s8.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils_s8.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// Kernel function prototypes +typedef void (*lpgemm_rowvar_s32_s8) + ( + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + int32_t*, + const dim_t, + const dim_t, + const int32_t, + const int32_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + +// B should always be packed. +LPGEMM_5LOOP(int8_t,int8_t,int32_t,s8s8s32o32) +{ + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; + + if ( mtag_b == UNPACKED ) + { + //Error: can only work with packed B now. + return; + } + + // Strides are updated based on matrix packing/reordering. + const int8_t* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int8_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int32_t* c_use_jc = NULL; + int32_t* c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; + + // Pack buffer for A. + int8_t* pack_a_buffer_s8s8s32o32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + // Pack buffer for B. + int8_t* pack_b_buffer_s8s8s32o32; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + dim_t packb_min_NR = get_packb_s8s8s32o32_min_NR(); + + // Temporary buffer for C accumulation when downscaling is required. + int32_t* temp_scal_c_buffer_s8s8s32o32; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + + // kc needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + dim_t n_updated = make_multiple_of_n( n, 16 ); + + // To decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC, IC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if ( k > KC ) + { + mem_scale_c_size_req = sizeof( int32_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_s8s8s32o32 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int32_t* )temp_scal_c_buffer_s8s8s32o32; + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + + int32_t* pack_b_column_sum = NULL; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + int32_t beta0 = ( pc == 0 ) ? beta : 1; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be + // used with vpdpbusd instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; + + if ( mtag_b == PACK ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated.pack + + mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated + ( nc0_updated * sizeof( int32_t ) ); + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = bli_mem_buffer( &mem_b ); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_s8s8s32o32 = + ( int8_t* ) thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + if ( pc == 0) + { + pack_b_column_sum = ( int32_t* )( pack_b_buffer_s8s8s32o32 + ( sizeof( int8_t ) * nc0_updated * kc0_updated ) ); + } + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { + if ( pc == 0 ) + { + for (dim_t idx = jc_packb_start; idx < jc_packb_end; idx++ ) + { + *( pack_b_column_sum + idx ) = 0; + } + } + + ( ( packb_s32_s8 )lcntx->packb_fun_ptr ) + ( + pack_b_buffer_s8s8s32o32 + ( jc_packb_start * kc0_updated ), + pack_b_column_sum + ( cs_b * jc_packb_start ), + ( b + ( rs_b * pc ) + ( cs_b * jc ) + + ( cs_b * jc_packb_start ) ), rs_b, + ( jc_packb_end - jc_packb_start ), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_s8s8s32o32; + + post_ops_attr.b_col_sum_vec = pack_b_column_sum; + } + else if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ); + + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); + + post_ops_attr.b_col_sum_vec = ( ( int32_t* )( b + ( k_updated * n_updated ) ) ) + jc; + } + else + { + //Unpacked B not supported. + return; + } + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } + + // Matrix A packed and reordered code path is not triggerred + // currently since we do not support it yet. + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( int8_t ) * mc0 * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_s8s8s32o32 = ( int8_t* )bli_mem_buffer( &mem_a ); + + ( ( packa_s32_s8 )lcntx->packa_fun_ptr ) + ( + pack_a_buffer_s8s8s32o32, + ( a + ( rs_a * ic ) + pc ), rs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_s8s8s32o32; + a_block_stride = kc0_updated; + } + + else + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + + // Int8 kernel reads 4 elements, totalling 4 bytes in a + // single broadcast for use in vnni instruction. + // Non vnni based kernel requires update to this code. + cs_a_use = 4; + a_block_stride = rs_a; + } + + post_ops_attr.b_sum_offset = 0; + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + //post_ops_attr.b_col_sum_vec = ( int32_t* )( b_use + ( rs_b * kc0_updated ) ); + + // Reorder/Packed B, Reorder/Packed/Unpacked A call. + ( ( lpgemm_rowvar_s32_s8 )lcntx->kern_fun_ptr ) + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c_use, 1, + alpha, beta0, + post_op_list, post_ops_attr + ); + post_ops_attr.b_sum_offset += NR; + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } + } + if ( mtag_a == PACK ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } +} diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.c b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.c new file mode 100644 index 0000000000..dc3413d89d --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.c @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils_s8.h" + +dim_t get_64byte_aligned_memory_s8 + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ) +{ + // Get 64 byte aligned memory. + int8_t* t1_original = ( int8_t* ) malloc( allocate_size + 64 ); + if ( t1_original == NULL ) + { + //Error in malloc. + *original_memory = NULL; + *aligned_memory = NULL; + return -1; + } + + int8_t* ta_original = t1_original + 64; + ta_original = ta_original - ( ( int64_t )( ta_original ) % 64 ); + + *original_memory = t1_original; + *aligned_memory = ta_original; + return 0; +} + +static lpgemm_obj_t* alloc_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme, + AOCL_MEMORY_TAG mtag + ) +{ + lpgemm_obj_t* obj = ( lpgemm_obj_t* ) malloc( sizeof( lpgemm_obj_t ) ); + + if ( obj == NULL ) + { + return NULL; //failure + } + + // Allocate aligned buffers. + get_64byte_aligned_memory_s8( &obj->storage.origin_buffer, + &obj->storage.aligned_buffer, + ( elem_size * length * width ) ); + + if ( obj->storage.origin_buffer == NULL ) + { + // Buffer allocation failed. + free( obj ); + return NULL; + } + + obj->length = length; + obj->width = width; + obj->elem_size = elem_size; + + if ( stor_scheme == ROW_MAJOR ) + { + obj->rs = stride; + obj->cs = 4; // 4 elements read at a time. + } + else if ( stor_scheme == COLUMN_MAJOR ) + { + obj->cs = stride; + obj->rs = 1; + } + obj->mtag = mtag; + + return obj; +} + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_s8s8s32( length, width, stride, elem_size, stor_scheme, UNPACKED ); +} + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_s8s8s32( length, width, stride, elem_size, stor_scheme, PACK ); +} + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + // Extra space since packing does width in multiples of 16. + dim_t width_reorder = make_multiple_of_n( width, 16 ); + // Extra space since packing does length in multiples of 4. + dim_t length_reorder = make_multiple_of_n( length, 4 ); + + return alloc_lpgemm_obj_t_s8s8s32( length_reorder, width_reorder, stride, elem_size, stor_scheme, REORDERED ); +} + +void dealloc_lpgemm_obj_t_s8s8s32( lpgemm_obj_t* obj ) +{ + free( obj->storage.origin_buffer ); + free( obj ); +} diff --git a/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h new file mode 100644 index 0000000000..e91d0f8816 --- /dev/null +++ b/addon/aocl_gemm/frame/s8s8s32/lpgemm_utils_s8.h @@ -0,0 +1,226 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_UTILS_H_S8 +#define LPGEMM_UTILS_H_S8 + +#include "lpgemm_types.h" + +// Users of this API needs to free the allocated memory on their own. +dim_t get_64byte_aligned_memory_s8 + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ); + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_s8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +void dealloc_lpgemm_obj_t_s8s8s32( lpgemm_obj_t* obj ); + +BLIS_INLINE void bli_param_map_char_to_lpmtag + ( + char mtag, + AOCL_MEMORY_TAG* lp_mtag + ) +{ + if ( mtag == 'n' || mtag == 'N' ) *lp_mtag = UNPACKED; + else if ( mtag == 'p' || mtag == 'P' ) *lp_mtag = PACK; + else if ( mtag == 'r' || mtag == 'R' ) *lp_mtag = REORDERED; + else + { + *lp_mtag = UNPACKED; + } +} + +BLIS_INLINE void bli_param_map_char_to_lpmat_type + ( + const char mtag, + AOCL_MATRIX_TYPE* lp_mat_type + ) +{ + if ( mtag == 'a' || mtag == 'A' ) *lp_mat_type = A_MATRIX; + else if ( mtag == 'b' || mtag == 'B' ) *lp_mat_type = B_MATRIX; + else + { + *lp_mat_type = B_MATRIX; + } +} + +BLIS_INLINE dim_t make_multiple_of_n( dim_t k, dim_t n ) +{ + if ( n <= 0 ) + { + return 0; + } + + return ( ( ( k + n - 1 ) / n ) * n ); +} + +BLIS_INLINE void lpgemm_alloc_mem_panel + ( + dim_t size_req, + packbuf_t buf_type, + mem_t* mem, + rntm_t* rntm_l + ) +{ + if ( bli_mem_is_unalloc( mem ) ) + { + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + else + { + siz_t mem_size = bli_mem_size( mem ); + if ( mem_size < size_req ) + { + bli_membrk_release( rntm_l, mem ); + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + } +} + +BLIS_INLINE dim_t get_Bpanel_width_for_kdim_traversal + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR + ) +{ + dim_t n_mod_NR = n % NR; + dim_t n_sub_updated = NC; + + if ( ( n % NC ) != 0 ) + { + // Only applicable to final NC part of jc loop where jc + remaining + // elements is less than NC; or when n < NC in which case panel width + // is atmost n. + dim_t n_last_loop = ( n / NC ) * NC; + if ( jc >= n_last_loop ) + { + n_sub_updated = n - n_last_loop; + if ( n_mod_NR != 0 ) + { + n_sub_updated += ( NR - n_mod_NR ); + } + } + } + + return n_sub_updated; +} + +BLIS_INLINE void get_B_panel_reordered_start_offset_width + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR, + dim_t* panel_start, + dim_t* panel_offset, + dim_t* panel_width, + dim_t* panel_width_kdim_trav + ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case nc0 is updated such that + // the jr loop only accesses the remaining portion of current NCxKC + // panel, with the next jc iteration taking care of the other panel. + // This ensures that jr loop does not cross panel boundaries. + ( *panel_start ) = ( jc / NC ) * NC; + ( *panel_offset ) = jc - ( *panel_start ); + + // Check if jc + current_panel_width (nc0) crosses panel boundaries. + if ( ( jc + ( *panel_width ) ) > ( ( *panel_start ) + NC ) ) + { + ( *panel_width ) = NC - ( *panel_offset ); + } + + ( *panel_width_kdim_trav ) = get_Bpanel_width_for_kdim_traversal + ( + jc, n, NC, NR + ); +} + +BLIS_INLINE void adjust_B_panel_reordered_jc( dim_t* jc, dim_t panel_start ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case jc is reset to immediate + // previous panel offset so that in the next iteration, the + // following panel belonging to the B chunk is accessed. This + // ensures that jr loop does not cross panel boundaries. + ( *jc ) = panel_start; +} + +#endif //LPGEMM_UTILS_H_S8 + diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 0c1df5e7c3..32615afc9e 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -168,11 +168,18 @@ BLIS_INLINE void lpgemm_pnl_wrk_heur_adjust_ic_jc_ways BLIS_INLINE void lpgemm_adjust_ic_jc_ways ( - dim_t m, - dim_t n, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t MC, + const dim_t NC, + const dim_t KC, + const dim_t MR, + const dim_t NR, dim_t* n_threads, dim_t* ic_ways, - dim_t* jc_ways + dim_t* jc_ways, + dim_t m_boost ) { const dim_t m_ic = m / ( *ic_ways ); @@ -192,16 +199,56 @@ BLIS_INLINE void lpgemm_adjust_ic_jc_ways const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + const dim_t MCx2 = MC * 2; + const dim_t k_factor = k / KC; + const dim_t n_jc_modulo_NR = n_jc % NR; + const dim_t n_prev_jc_modulo_NR = n_prev_jc % NR; + bool can_increase_ic = FALSE; bool can_increase_jc = FALSE; - if ( next_ic_work_per_thread <= cur_work_per_thread ) + if ( ( ( *ic_ways ) > 1 ) && ( ( *jc_ways ) < ( *n_threads ) ) ) { - can_increase_ic = TRUE; + if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + // Check whether m_prev_ic remains in good l2 load zone. + else if ( ( ( ( m_ic <= MC ) && ( m_prev_ic <= MC ) ) || + ( m_ic > MC ) ) && + ( ( n_jc > NR ) && ( n_next_jc == NR ) ) ) + { + can_increase_jc = TRUE; + } } - else if ( next_jc_work_per_thread < cur_work_per_thread ) + if ( ( ( *ic_ways ) < ( *n_threads ) ) && ( ( *jc_ways ) > 1) ) { - can_increase_jc = TRUE; + if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + // ic adjustment towards next highest factor if it results in + // m_next_ic <= MC. This helps in reducing number of A matrix + // loads per thread to l2 from main memory. + else if ( ( m_ic > MC ) && ( m_next_ic <= MC ) && + ( m_next_ic >= MR ) && ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // ic adjustment towards next highest factor resulted in better + // performance when m is sufficiently larger than n. + else if ( ( m > ( m_boost * n ) ) && ( m_ic >= MCx2 ) && + ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // Performance improvement also observed when n_jc is a multiple + // of NR. + else if ( ( n_jc_modulo_NR != 0 ) && ( n_prev_jc_modulo_NR == 0 ) && + ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } } if ( can_increase_ic ) @@ -315,8 +362,6 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); - lpgemm_pnl_wrk_heur_adjust_ic_jc_ways ( MR, NR, m, n, @@ -375,7 +420,7 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways ( MR, NR, m, n, @@ -416,6 +461,13 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + const dim_t MT_2 = MT / 2; *n_threads = bli_rntm_num_threads( rntm_g ); @@ -436,7 +488,12 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); - lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + lpgemm_adjust_ic_jc_ways + ( + m, n, k, + MC, NC, KC, MR, NR, + n_threads, ic_ways, jc_ways, 5 + ); } else { @@ -458,13 +515,126 @@ BLIS_INLINE void lpgemm_f32f32f32of32_get_threading { if ( ( k > page_size_b_floatx2 ) || ( ( k <= page_size_b_floatx2 ) && - ( m_ic > MT_2 ) && ( n_jc >= NT ) ) ) + ( m_ic > MT_2 ) && ( n_jc >= NT ) ) ) { + bli_rntm_set_pack_b( 1, rntm_g ); bli_rntm_set_pack_a( 1, rntm_g ); } } } +BLIS_INLINE void lpgemm_s8s8s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( S8S8S32OS32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( S8S8S32OS32 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + +BLIS_INLINE void lpgemm_s8s8s16o16_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( S8S8S16OS16 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + + #define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ ( \ @@ -482,9 +652,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ - C_type alpha, \ - C_type beta, \ + const C_type alpha, \ + const C_type beta, \ rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ bool c_downscale \ ) \ @@ -546,6 +717,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ beta, \ &rntm_l, \ &thread, \ + lcntx, \ post_op_list, c_downscale \ ); \ } \ @@ -559,6 +731,8 @@ GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) +GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16) #else @@ -579,9 +753,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ - C_type alpha, \ - C_type beta, \ + const C_type alpha, \ + const C_type beta, \ rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ bool c_downscale \ ) \ @@ -622,6 +797,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ beta, \ rntm_g, \ &thread, \ + lcntx, \ post_op_list, c_downscale \ ); \ } \ @@ -630,5 +806,7 @@ GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) +GEN_LPGEMM_DECORATOR(int8_t,int8_t,int32_t,s8s8s32o32) +GEN_LPGEMM_DECORATOR(int8_t,int8_t,int16_t,s8s8s16o16) #endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 8055d623e6..80c657b230 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -58,9 +58,10 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ - C_type alpha, \ - C_type beta, \ + const C_type alpha, \ + const C_type beta, \ rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ bool c_downscale \ ); \ @@ -69,6 +70,8 @@ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16) #else @@ -89,9 +92,10 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ C_type* c, \ const dim_t rs_c, \ const dim_t cs_c, \ - C_type alpha, \ - C_type beta, \ + const C_type alpha, \ + const C_type beta, \ rntm_t* rntm_g, \ + lpgemm_cntx_t* lcntx, \ lpgemm_post_op* post_op_list, \ bool c_downscale \ ); \ @@ -100,6 +104,8 @@ GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) +GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int32_t,s8s8s32o32) +GEN_LPGEMM_DECORATOR_FN(int8_t,int8_t,int16_t,s8s8s16o16) #endif diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c index 0b55f31215..2786117131 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,19 +39,23 @@ void aocl_reorderb_nr32_u8s8s16o16 ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ) { - const dim_t NC = lpgemm_get_block_size_NC_global_cntx(U8S8S16OS16); - const dim_t KC = lpgemm_get_block_size_KC_global_cntx(U8S8S16OS16); - const dim_t NR = lpgemm_get_block_size_NR_global_cntx(U8S8S16OS16); + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; // Extracting the matrix properties from the lpgemm object dim_t rs_b = b->rs; dim_t n = b->width; dim_t k = b->length; + lpgemm_mod_block_size_s16(0, n, k, NULL, &NC, &KC); + dim_t rs_b_reorder; dim_t cs_b_reorder; @@ -60,12 +64,7 @@ void aocl_reorderb_nr32_u8s8s16o16 // Making multiple of 2 to suit k in vpmaddubsw k_updated += (k_updated & 0x1); - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_g; - bli_rntm_init_from_global( &rntm_g ); - - dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + dim_t n_threads = bli_rntm_num_threads( rntm ); n_threads = ( n_threads > 0 ) ? n_threads : 1; #ifdef BLIS_ENABLE_OPENMP @@ -146,7 +145,7 @@ void aocl_reorderb_nr32_u8s8s16o16 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) - packb_nr32_u8s8s16o16 + ( ( packb_s16 )lcntx->packb_fun_ptr ) ( ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h index 6018978bc7..65647d9903 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,8 +38,10 @@ void aocl_reorderb_nr32_u8s8s16o16 ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ); #endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index b8f5115429..5a03493a44 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,15 +40,39 @@ #include "lpgemm_config.h" #include "lpgemm_thrinfo_utils.h" +// Kernel function prototypes +typedef void (*lpgemm_rowvar_s16) + ( + const dim_t, + const dim_t, + const dim_t, + const uint8_t*, + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + int16_t*, + const dim_t, + const dim_t, + const int16_t, + const int16_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + // B should always be packed. LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) { - const dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S16OS16 ); - const dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S16OS16 ); - const dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S16OS16 ); - const dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); - const dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S16OS16 ); - + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + const dim_t NR = lcntx->blksz.NR; + const dim_t MR = lcntx->blksz.MR; + + lpgemm_mod_block_size_s16(m, n, k, &MC, &NC, &KC); + if (mtag_b == UNPACKED) { // Error: can only work with packed B now. @@ -82,9 +106,22 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) // Making multiple of 2 to suit k in vpmaddubsw dim_t k_updated = make_multiple_of_n( k, 2 ); - // Is required to decide whether to apply post ops or not. + // To decide whether to apply post ops or not. bool is_last_k = FALSE; + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. thrinfo_t thread_jc; thrinfo_t thread_ic; @@ -123,37 +160,24 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) // Temp accumulaton buffer for C allocation. else if ( c_downscale == TRUE ) { - mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start ); - - lpgemm_alloc_mem_panel - ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, - &mem_scale_c, rntm - ); + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if ( k > KC ) + { + mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start ); - temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c ); + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); - c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16; + temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c ); - if ( beta != 0 ) - { - dim_t i_temp = 0; - dim_t j_temp = 0; - // Upscale out C to temporary C matrix. - for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) - { - j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) - { - *( temp_scal_c_buffer_u8s8s16o16 + - ( nc0 * i_temp ) + j_temp ) = - ( int16_t )( *( ( ( int8_t* )c ) + - ( rs_c * i_dscale ) + j_dscale ) ); - - j_temp++; - } - i_temp++; - } + c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16; } // The temp c buffer stride is modified as opposed to original C matrix. @@ -165,7 +189,12 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) int16_t beta0 = (pc == 0) ? beta : 1; dim_t kc0 = bli_min((k - pc), KC); + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; // kc0 needs to be a multiple of 2 so that it can be // used with vpmaddubsw instruction. Padding is added in @@ -200,9 +229,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) // All threads in work group should wait till chief thread has // finished allocating the packing buffers. - bli_thrcomm_barrier( - bli_thread_ocomm_id(&thread_ic), - &thread->comm[jc_work_id]); + bli_thrcomm_barrier + ( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id] + ); pack_b_buffer_u8s8s16o16 = (int8_t *)thread->comm[jc_work_id].sent_object; @@ -224,9 +255,9 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) if ((jc_packb_end > jc_packb_start) && (jc_packb_start < (jc + nc0))) { - packb_nr32_u8s8s16o16 + ( ( packb_s16 )lcntx->packb_fun_ptr ) ( - pack_b_buffer_u8s8s16o16 + + pack_b_buffer_u8s8s16o16 + (jc_packb_start * kc0_updated), (b + (rs_b * pc) + (cs_b * jc) + (cs_b * jc_packb_start)), @@ -237,7 +268,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) } else { - get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } // All threads in work group should wait till B matrix packing @@ -260,7 +291,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) (n_sub_updated * pc) + (jc_cur_loop_rem * kc0_updated); - get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } else { @@ -292,15 +323,20 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) { dim_t nr0 = bli_min((nc0 - jr), NR); + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + // Calls for reorder B - lpgemm_rowvar_u8s8s16o16_6x32 + ( ( lpgemm_rowvar_s16 )lcntx->kern_fun_ptr ) ( mc0, nr0, kc0, a_use, rs_a_use, cs_a_use, a_block_stride, (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, (c_use_ic + jr), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + post_op_list, post_ops_attr ); } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c index 746a134100..224e0791ff 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,13 +41,15 @@ void reorderb_nr64_u8s8s32o32 ( - lpgemm_obj_t* b, - lpgemm_obj_t* b_reorder + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ) { - dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); - dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t NR = lcntx->blksz.NR; dim_t rs_b = b->rs; dim_t rs_b_reorder; @@ -62,12 +64,7 @@ void reorderb_nr64_u8s8s32o32 // buffer needs to be updated. dim_t k_updated = make_multiple_of_n( k, 4 ); - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_g; - bli_rntm_init_from_global( &rntm_g ); - - dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + dim_t n_threads = bli_rntm_num_threads( rntm ); n_threads = ( n_threads > 0 ) ? n_threads : 1; #ifdef BLIS_ENABLE_OPENMP @@ -148,8 +145,7 @@ void reorderb_nr64_u8s8s32o32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) -#ifdef BLIS_KERNELS_ZEN4 - packb_nr64_u8s8s32o32 + ( ( packb_s32 )lcntx->packb_fun_ptr ) ( ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + @@ -158,14 +154,6 @@ void reorderb_nr64_u8s8s32o32 ( rs_b * pc ) + jc ), rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); -#else - // Silence compiler warnings. - rs_b_reorder = 0; - cs_b_reorder = 0; - ( void )kc0_updated; - ( void )k_updated; - ( void )rs_b; -#endif } adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); @@ -179,12 +167,14 @@ void reorderb_nr64_u8s8s32o32 void reordera_mr6_u8s8s32o32 ( - lpgemm_obj_t* a, - lpgemm_obj_t* a_reorder + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ) { - dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + dim_t MC = lcntx->blksz.MC; + dim_t KC = lcntx->blksz.KC; dim_t rs_a = a->rs; dim_t rs_a_reorder; @@ -207,21 +197,13 @@ void reordera_mr6_u8s8s32o32 { dim_t mc0 = bli_min( ( m - ic ), MC ); -#ifdef BLIS_KERNELS_ZEN4 - packa_k64_u8s8s32o32 + ( ( packa_s32 )lcntx->packa_fun_ptr ) ( ( ( ( uint8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + ( ic * kc0_updated ) ), ( ( ( uint8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder ); -#else - rs_a_reorder = 0; - cs_a_reorder = 0; - ( void )kc0_updated; - ( void )rs_a; - ( void )mc0; -#endif } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h index eb8dad9cfc..232b02238d 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,14 +39,18 @@ void reorderb_nr64_u8s8s32o32 ( - lpgemm_obj_t* b, - lpgemm_obj_t* b_reorder + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ); void reordera_mr6_u8s8s32o32 ( - lpgemm_obj_t* a, - lpgemm_obj_t* a_reorder + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder, + rntm_t* rntm, + lpgemm_cntx_t* lcntx ); #endif //LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 82a745fcf5..feedda0212 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,14 +41,36 @@ #include "lpgemm_thrinfo_utils.h" #include "lpgemm_config.h" +// Kernel function prototypes +typedef void (*lpgemm_rowvar_s32) + ( + const dim_t, + const dim_t, + const dim_t, + const uint8_t*, + const dim_t, + const dim_t, + const dim_t, + const int8_t*, + const dim_t, + const dim_t, + int32_t*, + const dim_t, + const dim_t, + const int32_t, + const int32_t, + lpgemm_post_op*, + lpgemm_post_op_attr + ); + // B should always be packed. LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { - dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); - dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); - dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); - dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + dim_t NC = lcntx->blksz.NC; + dim_t KC = lcntx->blksz.KC; + dim_t MC = lcntx->blksz.MC; + dim_t NR = lcntx->blksz.NR; + dim_t MR = lcntx->blksz.MR; if ( mtag_b == UNPACKED ) { @@ -93,9 +115,22 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) // buffer needs to be updated. dim_t k_updated = make_multiple_of_n( k, 4 ); - // Is required to decide whether to apply post ops or not. + // To decide whether to apply post ops or not. bool is_last_k = FALSE; + // To decide whether to use original s8 C or temp buffer for beta scale. + bool is_first_k = FALSE; + + lpgemm_post_op_attr post_ops_attr; + if ( c_downscale == TRUE ) + { + post_ops_attr.buf_downscale = c; + } + else + { + post_ops_attr.buf_downscale = NULL; + } + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. thrinfo_t thread_jc; thrinfo_t thread_ic; @@ -115,7 +150,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) dim_t jc_cur_loop = jc; dim_t jc_cur_loop_rem = 0; - dim_t n_sub_updated; + dim_t n_sub_updated = 0; if ( mtag_b == REORDERED ) { @@ -134,37 +169,24 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) // Temp accumulaton buffer for C allocation. else if ( c_downscale == TRUE ) { - mem_scale_c_size_req = sizeof( int32_t ) * nc0 * ( ic_end - ic_start ); - - lpgemm_alloc_mem_panel - ( - mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, - &mem_scale_c, rntm - ); + // Buffer memory is only required if output needs to be + // persisted across iterations of the pc/KC loop. + // It was observed that the locks used while checking out + // a buffer from memory pool had an impact on performance + // and is better to not checkout if k <= KC. + if ( k > KC ) + { + mem_scale_c_size_req = sizeof( int32_t ) * nc0 * ( ic_end - ic_start ); - temp_scal_c_buffer_u8s8s32o32 = bli_mem_buffer( &mem_scale_c ); + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); - c_use_jc = ( int32_t* )temp_scal_c_buffer_u8s8s32o32; + temp_scal_c_buffer_u8s8s32o32 = bli_mem_buffer( &mem_scale_c ); - if ( beta != 0 ) - { - dim_t i_temp = 0; - dim_t j_temp = 0; - // Upscale out C to temporary C matrix. - for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) - { - j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) - { - *( temp_scal_c_buffer_u8s8s32o32 + - ( nc0 * i_temp ) + j_temp ) = - ( int32_t )( *( ( ( int8_t* )c ) + - ( rs_c * i_dscale ) + j_dscale ) ); - - j_temp++; - } - i_temp++; - } + c_use_jc = ( int32_t* )temp_scal_c_buffer_u8s8s32o32; } // The temp c buffer stride is modified as opposed to original C matrix. @@ -183,7 +205,12 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) // needs to be updated. dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + // No parallelization in k dim, k always starts at 0. + is_first_k = ( pc == 0 ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_first_k = is_first_k; + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + post_ops_attr.is_last_k = is_last_k; if ( mtag_b == PACK ) { @@ -239,8 +266,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) if ( ( jc_packb_end > jc_packb_start ) && ( jc_packb_start < ( jc + nc0 ) ) ) { -#ifdef BLIS_KERNELS_ZEN4 - packb_nr64_u8s8s32o32 + ( ( packb_s32 )lcntx->packb_fun_ptr ) ( pack_b_buffer_u8s8s32o32 + ( jc_packb_start * kc0_updated ), ( b + ( rs_b * pc ) + ( cs_b * jc ) + @@ -248,11 +274,10 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); -#endif } else { - get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } // All threads in work group should wait till B matrix packing @@ -274,7 +299,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0_updated ); - get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + lpgemm_get_packb_strides( lcntx, &rs_b_use, &cs_b_use ); } else { @@ -310,21 +335,19 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ); pack_a_buffer_u8s8s32o32 = ( uint8_t* )bli_mem_buffer( &mem_a ); -#ifdef BLIS_KERNELS_ZEN4 - packa_k64_u8s8s32o32 + ( ( packa_s32 )lcntx->packa_fun_ptr ) ( pack_a_buffer_u8s8s32o32, ( a + ( rs_a * ic ) + pc ), rs_a, mc0, kc0, &rs_a_use, &cs_a_use ); -#endif a_use = pack_a_buffer_u8s8s32o32; a_block_stride = kc0_updated; } else if ( mtag_a == REORDERED ) { - get_packa_k64_u8s8s32o32_strides( &rs_a_use, &cs_a_use ); + lpgemm_get_packa_strides( lcntx, &rs_a_use, &cs_a_use ); a_use = a + ( pc * m ) + ( kc0_updated * ic ); a_block_stride = kc0_updated; } @@ -343,28 +366,21 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { dim_t nr0 = bli_min( ( nc0 - jr ), NR ); -#ifdef BLIS_KERNELS_ZEN4 + // Post ops meta attributes. + post_ops_attr.post_op_c_i = ic; + post_ops_attr.post_op_c_j = ( jc + jr ); + post_ops_attr.rs_c_downscale = rs_c_downscale; + // Reorder/Packed B, Reorder/Packed/Unpacked A call. - lpgemm_rowvar_u8s8s32o32_6x64 + ( ( lpgemm_rowvar_s32 )lcntx->kern_fun_ptr ) ( mc0, nr0, kc0, a_use, rs_a_use, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, ( c_use_ic + jr ), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale + post_op_list, post_ops_attr ); -#else - // Silence compiler warnings. - ( void )b_use; - ( void )a_block_stride; - ( void )rs_c_downscale; - ( void )is_last_k; - ( void )c_use_ic; - ( void )a_use; - ( void )beta0; - ( void )nr0; -#endif } } } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c deleted file mode 100644 index 65a4963dcb..0000000000 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ /dev/null @@ -1,1146 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS dim_tERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include - -#include "blis.h" -#include "lpgemm_kernels.h" -#include "lpgemm_f32_kern_macros.h" - -#ifdef BLIS_KERNELS_ZEN4 -// 6x64 bf16 kernel -LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x64_DISABLE, - &&POST_OPS_BIAS_6x64, - &&POST_OPS_RELU_6x64, - &&POST_OPS_RELU_SCALE_6x64, - &&POST_OPS_DOWNSCALE_6x64 - }; - dim_t MR = 6; - dim_t NR = 64; - - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 2; - dim_t k_partial_pieces = k0 % 2; - - int32_t a_kfringe_buf = 0; - - if ( n0 < NR ) - { - dim_t n0_rem = n0 % 16; - - // Split dim_to multiple smaller fringe kernels, so as to maximize - // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` - // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. - dim_t n0_48 = n0 / 48; - dim_t n0_32 = n0 / 32; - dim_t n0_16 = n0 / 16; - - // KC when not multiple of 2 will have padding to make it multiple of - // 2 in packed buffer. Also the k0 cannot be passed as the updated - // value since A matrix is not packed and requires original k0. - dim_t k0_updated = k0; - k0_updated += (k0_updated & 0x1); - - if ( n0_48 == 1 ) - { - lpgemm_rowvar_bf16bf16f32of32_6x48 - ( - m0, k0, - a, rs_a, cs_a, ps_a, - b, ( ( rs_b / 4 ) * 3 ), cs_b, - c, rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - - b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. - c = c + 48; - post_op_c_j += 48; - } - - else if ( n0_32 == 1 ) - { - lpgemm_rowvar_bf16bf16f32of32_6x32 - ( - m0, k0, - a, rs_a, cs_a, ps_a, - b, ( ( rs_b / 4 ) * 2 ), cs_b, - c, rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - - b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. - c = c + 32; - post_op_c_j += 32; - } - - else if ( n0_16 == 1 ) - { - lpgemm_rowvar_bf16bf16f32of32_6x16 - ( - m0, k0, - a, rs_a, cs_a, ps_a, - b, ( ( rs_b / 4 ) * 1 ), cs_b, - c, rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - - b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. - c = c + 16; - post_op_c_j += 16; - } - - if ( n0_rem > 0 ) - { - lpgemm_rowvar_bf16bf16f32of32_6xlt16 - ( - m0, k0, - a, rs_a, cs_a, ps_a, - b, ( ( rs_b / 4 ) * 1 ), cs_b, - c, rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - - // No leftover fringe after this podint. - } - return; - } - - // B matrix storage bfloat type - __m512bh b0; - __m512bh b1; - __m512bh b2; - __m512bh b3; - - // A matrix storage bfloat type - __m512bh a_bf16_0; - __m512bh a_bf16_1; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512 c_float_0p0 = _mm512_setzero_ps(); - __m512 c_float_0p1 = _mm512_setzero_ps(); - __m512 c_float_0p2 = _mm512_setzero_ps(); - __m512 c_float_0p3 = _mm512_setzero_ps(); - - __m512 c_float_1p0 = _mm512_setzero_ps(); - __m512 c_float_1p1 = _mm512_setzero_ps(); - __m512 c_float_1p2 = _mm512_setzero_ps(); - __m512 c_float_1p3 = _mm512_setzero_ps(); - - __m512 c_float_2p0 = _mm512_setzero_ps(); - __m512 c_float_2p1 = _mm512_setzero_ps(); - __m512 c_float_2p2 = _mm512_setzero_ps(); - __m512 c_float_2p3 = _mm512_setzero_ps(); - - __m512 c_float_3p0 = _mm512_setzero_ps(); - __m512 c_float_3p1 = _mm512_setzero_ps(); - __m512 c_float_3p2 = _mm512_setzero_ps(); - __m512 c_float_3p3 = _mm512_setzero_ps(); - - __m512 c_float_4p0 = _mm512_setzero_ps(); - __m512 c_float_4p1 = _mm512_setzero_ps(); - __m512 c_float_4p2 = _mm512_setzero_ps(); - __m512 c_float_4p3 = _mm512_setzero_ps(); - - __m512 c_float_5p0 = _mm512_setzero_ps(); - __m512 c_float_5p1 = _mm512_setzero_ps(); - __m512 c_float_5p2 = _mm512_setzero_ps(); - __m512 c_float_5p3 = _mm512_setzero_ps(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // The instructions are arranged in a mixed way to reduce data - // chain dependencies. - - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2] - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); - c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); - - // Broadcast a[2,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); - c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); - c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); - c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); - - // Broadcast a[4,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); - c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); - c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); - c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); - c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); - c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); - c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); - - // Broadcast a[2,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); - c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); - c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); - c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); - - // Broadcast a[4,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); - c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); - c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); - c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); - c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); - c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); - } - - // Load alpha and beta - __m512 selector1 = _mm512_set1_ps ( alpha ); - __m512 selector2 = _mm512_set1_ps ( beta ); - - // Scale by alpha - c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); - c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); - c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); - c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); - - c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); - c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); - c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); - c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); - - c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); - c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); - c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); - c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); - - c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); - c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); - c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); - c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); - - c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); - c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); - c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); - c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); - - c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); - c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); - c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); - c_float_5p3 = _mm512_mul_ps( selector1, c_float_5p3 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); - - // c[1,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); - - // c[1,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); - - // c[2,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); - - // c[2,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); - - // c[3,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); - - // c[3,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); - - // c[4,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); - - // c[4,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); - - // c[4,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); - - // c[5,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); - - // c[5,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); - - // c[5,48-63] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p3 = _mm512_add_ps( selector1, c_float_5p3 ); - } - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x64: - { - __m512 selector3; - __m512 selector4; - - if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || - ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) - { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - - // c[3,48-63] - c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); - - // c[4,48-63] - c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); - - // c[5,48-63] - c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); - } - else - { - // If original output was columns major, then by the time - // kernel sees it, the matrix would be accessed as if it were - // transposed. Due to this the bias array will be accessed by - // the ic index, and each bias element corresponds to an - // entire row of the transposed output array, instead of an - // entire column. - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 2 ) ); - selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + - post_op_c_i + 5 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); - - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); - - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); - - // c[3,48-63] - c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); - - // c[4,48-63] - c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); - - // c[5,48-63] - c_float_5p3 = _mm512_add_ps( selector6, c_float_5p3 ); - } - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x64: - { - selector1 = _mm512_setzero_ps(); - - // c[0,0-15] - c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); - - // c[0,48-63] - c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); - - // c[1,0-15] - c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); - - // c[1,48-63] - c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); - - // c[2,0-15] - c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); - - // c[2,48-63] - c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); - - // c[3,0-15] - c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); - - // c[3,48-63] - c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); - - // c[4,0-15] - c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); - - // c[4,48-63] - c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); - - // c[5,0-15] - c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); - - // c[5,48-63] - c_float_5p3 = _mm512_max_ps( selector1, c_float_5p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x64: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_0p3) - - // c[1, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_1p2) - - // c[1, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_1p3) - - // c[2, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_2p2) - - // c[2, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_2p3) - - // c[3, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_3p2) - - // c[3, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_3p3) - - // c[4, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_4p1) - - // c[4, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_4p2) - - // c[4, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_4p3) - - // c[5, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_5p0) - - // c[5, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_5p1) - - // c[5, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_5p2) - - // c[5, 48-63] - RELU_SCALE_OP_F32_AVX512(c_float_5p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x64: -{ - // c[0, 0-15] - CVT_F32_BF16(c_float_0p0,0,0); - - // c[0, 16-31] - CVT_F32_BF16(c_float_0p1,0,1); - - // c[0, 32-47] - CVT_F32_BF16(c_float_0p2,0,2); - - // c[0, 48-63] - CVT_F32_BF16(c_float_0p3,0,3); - - // c[1, 0-15] - CVT_F32_BF16(c_float_1p0,1,0); - - // c[1, 16-31] - CVT_F32_BF16(c_float_1p1,1,1); - - // c[1, 32-47] - CVT_F32_BF16(c_float_1p2,1,2); - - // c[1, 48-63] - CVT_F32_BF16(c_float_1p3,1,3); - - // c[2, 0-15] - CVT_F32_BF16(c_float_2p0,2,0); - - // c[2, 16-31] - CVT_F32_BF16(c_float_2p1,2,1); - - // c[2, 32-47] - CVT_F32_BF16(c_float_2p2,2,2); - - // c[2, 48-63] - CVT_F32_BF16(c_float_2p3,2,3); - - // c[3, 0-15] - CVT_F32_BF16(c_float_3p0,3,0); - - // c[3, 16-31] - CVT_F32_BF16(c_float_3p1,3,1); - - // c[3, 32-47] - CVT_F32_BF16(c_float_3p2,3,2); - - // c[3, 48-63] - CVT_F32_BF16(c_float_3p3,3,3); - - // c[4, 0-15] - CVT_F32_BF16(c_float_4p0,4,0); - - // c[4, 16-31] - CVT_F32_BF16(c_float_4p1,4,1); - - // c[4, 32-47] - CVT_F32_BF16(c_float_4p2,4,2); - - // c[4, 48-63] - CVT_F32_BF16(c_float_4p3,4,3); - - // c[5, 0-15] - CVT_F32_BF16(c_float_5p0,5,0); - - // c[5, 16-31] - CVT_F32_BF16(c_float_5p1,5,1); - - // c[5, 32-47] - CVT_F32_BF16(c_float_5p2,5,2); - - // c[5, 48-63] - CVT_F32_BF16(c_float_5p3,5,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR -} - -POST_OPS_6x64_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); - - // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); - - // c[0,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); - - // c[0,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_float_0p3 ); - - // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); - - // c[1,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); - - // c[1,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); - - // c[1,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_float_1p3 ); - - // c[2,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); - - // c[2,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); - - // c[2,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); - - // c[2,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_float_2p3 ); - - // c[3,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); - - // c[3,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); - - // c[3,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); - - // c[3,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_float_3p3 ); - - // c[4,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); - - // c[4,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); - - // c[4,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); - - // c[4,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_float_4p3 ); - - // c[5,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); - - // c[5,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); - - // c[5,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); - - // c[5,48-63] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_float_5p3 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - // In cases where A matrix is packed cs_a is set to 12, since the - // next column in a given row is accessed after 2*6 elements, where - // 6 is MR and 2 elements are broadcasted each time from A (bf16). - // In fringe case, where m < MR, the next column will be after m'*2 - // elements, and subsequently following adjustment of cs_a is - // required before calling m fringe kernels. - dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_bf16bf16f32of32_5x64 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_bf16bf16f32of32_4x64 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_bf16bf16f32of32_3x64 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_bf16bf16f32of32_2x64 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_bf16bf16f32of32_1x64 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} -#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c deleted file mode 100644 index 1a37ab071a..0000000000 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ /dev/null @@ -1,2502 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -#include -#include - -#include "blis.h" -#include "lpgemm_kernels.h" -#include "lpgemm_f32_kern_macros.h" - -#ifdef BLIS_KERNELS_ZEN4 -// 6xlt16 bf16 fringe kernel -LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6xLT16_DISABLE, - &&POST_OPS_BIAS_6xLT16, - &&POST_OPS_RELU_6xLT16, - &&POST_OPS_RELU_SCALE_6xLT16, - &&POST_OPS_DOWNSCALE_6xLT16 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 2; - dim_t k_partial_pieces = k0 % 2; - - int32_t a_kfringe_buf = 0; - - // B matrix storage bfloat type - __m512bh b0; - - // A matrix storage bfloat type - __m512bh a_bf16_0; - - // For corner cases. - float buf0[16]; - float buf1[16]; - float buf2[16]; - float buf3[16]; - float buf4[16]; - float buf5[16]; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512 c_float_0p0 = _mm512_setzero_ps(); - - __m512 c_float_1p0 = _mm512_setzero_ps(); - - __m512 c_float_2p0 = _mm512_setzero_ps(); - - __m512 c_float_3p0 = _mm512_setzero_ps(); - - __m512 c_float_4p0 = _mm512_setzero_ps(); - - __m512 c_float_5p0 = _mm512_setzero_ps(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 2 rows with 16 extended elements each from B to 1 ZMM - // registers. It is to be noted that the B matrix is packed for use - // in bf16 instructions and each load to ZMM register will have 2 - // elements along k direction and 16 elements across n directions, - // so 2x16 elements to a ZMM register. - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - - // Broadcast a[2,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - - // Broadcast a[4,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - } - - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - - // Broadcast a[2,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - - // Broadcast a[4,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - } - - // Load alpha and beta - __m512 selector1 = _mm512_set1_ps( alpha ); - __m512 selector2 = _mm512_set1_ps( beta ); - - // Scale by alpha - c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); - - c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); - - c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); - - c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); - - c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); - - c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( float ) ) ); - memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( float ) ) ); - memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( float ) ) ); - memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( float) ) ); - memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( float ) ) ); - memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( float ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_ps( buf0 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_ps( buf1 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_ps( buf2 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_ps( buf3 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_ps( buf4 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - selector1 = _mm512_loadu_ps( buf5 ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - } - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6xLT16: - { - if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || - ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) - { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - } - else - { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 5 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); - } - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6xLT16: - { - selector1 = _mm512_setzero_ps(); - - // c[0,0-15] - c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6xLT16: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_4p0) - - // c[5, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_5p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6xLT16: - { - // c[0, 0-15] - CVT_F32_BF16_LT16(c_float_0p0,0,0); - - // c[1, 0-15] - CVT_F32_BF16_LT16(c_float_1p0,1,0); - - // c[2, 0-15] - CVT_F32_BF16_LT16(c_float_2p0,2,0); - - // c[3, 0-15] - CVT_F32_BF16_LT16(c_float_3p0,3,0); - - // c[4, 0-15] - CVT_F32_BF16_LT16(c_float_4p0,4,0); - - // c[5, 0-15] - CVT_F32_BF16_LT16(c_float_5p0,5,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( buf0, c_float_0p0 ); - - // c[1,0-15] - _mm512_storeu_ps( buf1, c_float_1p0 ); - - // c[2,0-15] - _mm512_storeu_ps( buf2, c_float_2p0 ); - - // c[3,0-15] - _mm512_storeu_ps( buf3, c_float_3p0 ); - - // c[4,0-15] - _mm512_storeu_ps( buf4, c_float_4p0 ); - - // c[5,0-15] - _mm512_storeu_ps( buf5, c_float_5p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); - - // c[2,0-15] - memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); - - // c[3,0-15] - memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); - - // c[4,0-15] - memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( float ) ) ); - - // c[5,0-15] - memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( float ) ) ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_bf16bf16f32of32_5xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_bf16bf16f32of32_4xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_bf16bf16f32of32_3xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_bf16bf16f32of32_2xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_bf16bf16f32of32_1xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x16 bf16 fringe kernel -LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x16_DISABLE, - &&POST_OPS_BIAS_6x16, - &&POST_OPS_RELU_6x16, - &&POST_OPS_RELU_SCALE_6x16, - &&POST_OPS_DOWNSCALE_6x16 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 2; - dim_t k_partial_pieces = k0 % 2; - - int32_t a_kfringe_buf = 0; - - // B matrix storage bfloat type - __m512bh b0; - - // A matrix storage bfloat type - __m512bh a_bf16_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512 c_float_0p0 = _mm512_setzero_ps(); - - __m512 c_float_1p0 = _mm512_setzero_ps(); - - __m512 c_float_2p0 = _mm512_setzero_ps(); - - __m512 c_float_3p0 = _mm512_setzero_ps(); - - __m512 c_float_4p0 = _mm512_setzero_ps(); - - __m512 c_float_5p0 = _mm512_setzero_ps(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 2 rows with 16 elements each from B to 1 ZMM registers. It - // is to be noted that the B matrix is packed for use in bf16 - // instructions and each load to ZMM register will have 2 elements - // along k direction and 16 elements across n directions, so 2x16 - // elements to a ZMM register. - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - - // Broadcast a[2,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - - // Broadcast a[4,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - } - // Handle k remainder. - - if ( k_partial_pieces > 0 ) - { - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - - // Broadcast a[1,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - - // Broadcast a[2,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - - // Broadcast a[3,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - - // Broadcast a[4,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - - // Broadcast a[5,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - } - - // Load alpha and beta - __m512 selector1 = _mm512_set1_ps( alpha ); - __m512 selector2 = _mm512_set1_ps( beta ); - - // Scale by alpha - c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); - - c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); - - c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); - - c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); - - c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); - - c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - } - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x16: - { - if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || - ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) - { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - } - else - { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 5 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); - } - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x16: - { - selector1 = _mm512_setzero_ps(); - - // c[0,0-15] - c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x16: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_4p0) - - // c[5, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_5p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x16: - { - // c[0, 0-15] - CVT_F32_BF16(c_float_0p0,0,0); - - // c[1, 0-15] - CVT_F32_BF16(c_float_1p0,1,0); - - // c[2, 0-15] - CVT_F32_BF16(c_float_2p0,2,0); - - // c[3, 0-15] - CVT_F32_BF16(c_float_3p0,3,0); - - // c[4, 0-15] - CVT_F32_BF16(c_float_4p0,4,0); - - // c[5, 0-15] - CVT_F32_BF16(c_float_5p0,5,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); - - // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); - - // c[2,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); - - // c[3,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); - - // c[4,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); - - // c[5,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_bf16bf16f32of32_5x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_bf16bf16f32of32_4x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - int cs_a_use = ( cs_a == 2) ? 2 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_bf16bf16f32of32_3x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_bf16bf16f32of32_2x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_bf16bf16f32of32_1x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x32 bf16 fringe kernel -LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x32_DISABLE, - &&POST_OPS_BIAS_6x32, - &&POST_OPS_RELU_6x32, - &&POST_OPS_RELU_SCALE_6x32, - &&POST_OPS_DOWNSCALE_6x32 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 2; - dim_t k_partial_pieces = k0 % 2; - - int32_t a_kfringe_buf = 0; - - // B matrix storage bfloat type - __m512bh b0; - __m512bh b1; - - // A matrix storage bfloat type - __m512bh a_bf16_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512 c_float_0p0 = _mm512_setzero_ps(); - __m512 c_float_0p1 = _mm512_setzero_ps(); - - __m512 c_float_1p0 = _mm512_setzero_ps(); - __m512 c_float_1p1 = _mm512_setzero_ps(); - - __m512 c_float_2p0 = _mm512_setzero_ps(); - __m512 c_float_2p1 = _mm512_setzero_ps(); - - __m512 c_float_3p0 = _mm512_setzero_ps(); - __m512 c_float_3p1 = _mm512_setzero_ps(); - - __m512 c_float_4p0 = _mm512_setzero_ps(); - __m512 c_float_4p1 = _mm512_setzero_ps(); - - __m512 c_float_5p0 = _mm512_setzero_ps(); - __m512 c_float_5p1 = _mm512_setzero_ps(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 2 rows with 32 elements each from B to 2 ZMM registers. It - // is to be noted that the B matrix is packed for use in bf16 - // instructions and each load to ZMM register will have 2 elements - // along k direction and 32 elements across n directions, so 2x16 - // elements to a ZMM register. - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - - // Broadcast a[1,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); - - // Broadcast a[2,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - - // Broadcast a[3,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); - - // Broadcast a[4,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - - // Broadcast a[5,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - - // Broadcast a[1,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); - - // Broadcast a[2,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - - // Broadcast a[3,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); - - // Broadcast a[4,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - - // Broadcast a[5,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); - } - // Load alpha and beta - __m512 selector1 = _mm512_set1_ps( alpha ); - __m512 selector2 = _mm512_set1_ps( beta ); - - // Scale by alpha - c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); - c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); - - c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); - c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); - - c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); - c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); - - c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); - c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); - - c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); - c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); - - c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); - c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); - - // c[2,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); - - // c[3,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); - - // c[4,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); - - // c[5,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); - } - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x32: - { - if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || - ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) - { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); - } - else - { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 1 ) ); - __m512 selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 5 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); - } - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x32: - { - selector1 = _mm512_setzero_ps(); - - // c[0,0-15] - c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); - - // c[1,0-15] - c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); - - // c[2,0-15] - c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); - - // c[3,0-15] - c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); - - // c[4,0-15] - c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); - - // c[5,0-15] - c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x32: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_1p1) - - // c[2, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_2p1) - - // c[3, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_3p1) - - // c[4, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_4p1) - - // c[5, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_5p0) - - // c[5, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_5p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x32: - { - // c[0, 0-15] - CVT_F32_BF16(c_float_0p0,0,0); - - // c[0, 16-31] - CVT_F32_BF16(c_float_0p1,0,1); - - // c[1, 0-15] - CVT_F32_BF16(c_float_1p0,1,0); - - // c[1, 16-31] - CVT_F32_BF16(c_float_1p1,1,1); - - // c[2, 0-15] - CVT_F32_BF16(c_float_2p0,2,0); - - // c[2, 16-31] - CVT_F32_BF16(c_float_2p1,2,1); - - // c[3, 0-15] - CVT_F32_BF16(c_float_3p0,3,0); - - // c[3, 16-31] - CVT_F32_BF16(c_float_3p1,3,1); - - // c[4, 0-15] - CVT_F32_BF16(c_float_4p0,4,0); - - // c[4, 16-31] - CVT_F32_BF16(c_float_4p1,4,1); - - // c[5, 0-15] - CVT_F32_BF16(c_float_5p0,5,0); - - // c[5, 16-31] - CVT_F32_BF16(c_float_5p1,5,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); - - // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); - - // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); - - // c[1,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); - - // c[2,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); - - // c[2,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); - - // c[3,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); - - // c[3,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); - - // c[4,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); - - // c[4,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); - - // c[5,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); - - // c[5,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_bf16bf16f32of32_5x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_bf16bf16f32of32_4x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_bf16bf16f32of32_3x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_bf16bf16f32of32_2x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_bf16bf16f32of32_1x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x48 bf16 fringe kernel -LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x48_DISABLE, - &&POST_OPS_BIAS_6x48, - &&POST_OPS_RELU_6x48, - &&POST_OPS_RELU_SCALE_6x48, - &&POST_OPS_DOWNSCALE_6x48 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 2; - dim_t k_partial_pieces = k0 % 2; - - int32_t a_kfringe_buf = 0; - - // B matrix storage bfloat type - __m512bh b0; - __m512bh b1; - __m512bh b2; - - // A matrix storage bfloat type - __m512bh a_bf16_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512 c_float_0p0 = _mm512_setzero_ps(); - __m512 c_float_0p1 = _mm512_setzero_ps(); - __m512 c_float_0p2 = _mm512_setzero_ps(); - - __m512 c_float_1p0 = _mm512_setzero_ps(); - __m512 c_float_1p1 = _mm512_setzero_ps(); - __m512 c_float_1p2 = _mm512_setzero_ps(); - - __m512 c_float_2p0 = _mm512_setzero_ps(); - __m512 c_float_2p1 = _mm512_setzero_ps(); - __m512 c_float_2p2 = _mm512_setzero_ps(); - - __m512 c_float_3p0 = _mm512_setzero_ps(); - __m512 c_float_3p1 = _mm512_setzero_ps(); - __m512 c_float_3p2 = _mm512_setzero_ps(); - - __m512 c_float_4p0 = _mm512_setzero_ps(); - __m512 c_float_4p1 = _mm512_setzero_ps(); - __m512 c_float_4p2 = _mm512_setzero_ps(); - - __m512 c_float_5p0 = _mm512_setzero_ps(); - __m512 c_float_5p1 = _mm512_setzero_ps(); - __m512 c_float_5p2 = _mm512_setzero_ps(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - - // Load 2 rows with 48 elements each from B to 3 ZMM registers. It - // is to be noted that the B matrix is packed for use in bf16 - // instructions and each load to ZMM register will have 2 elements - // along k direction and 16 elements across n directions, so 2x16 - // elements to a ZMM register. - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); - - // Broadcast a[1,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); - c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); - - // Broadcast a[2,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); - - // Broadcast a[3,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); - c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); - - // Broadcast a[4,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); - - // Broadcast a[5,kr:kr+2]. - a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); - c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); - - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] - c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); - c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); - c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); - - // Broadcast a[1,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] - c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); - c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); - c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); - - // Broadcast a[2,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] - c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); - c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); - c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); - - // Broadcast a[3,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] - c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); - c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); - c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); - - // Broadcast a[4,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] - c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); - c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); - c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); - - // Broadcast a[5,kr:kr+2]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( bfloat16 ) ) - ); - a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 2. - // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] - c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); - c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); - c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); - } - - // Load alpha and beta - __m512 selector1 = _mm512_set1_ps( alpha ); - __m512 selector2 = _mm512_set1_ps( beta ); - - // Scale by alpha - c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); - c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); - c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); - - c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); - c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); - c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); - - c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); - c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); - c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); - - c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); - c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); - c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); - - c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); - c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); - c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); - - c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); - c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); - c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); - - // c[2,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); - - // c[3,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); - - // c[4,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); - - // c[4,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); - - // c[5,0-15] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); - - // c[5,32-47] - selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); - selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); - } - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x48: - { - __m512 selector3; - - if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || - ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) - { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); - } - else - { - selector1 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 0 ) ); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 1 ) ); - selector3 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 2 ) ); - __m512 selector4 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 3 ) ); - __m512 selector5 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 4 ) ); - __m512 selector6 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 - + post_op_c_i + 5 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); - - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); - - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); - - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); - } - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x48: - { - //printf("relu\n"); - selector1 = _mm512_setzero_ps(); - - // c[0,0-15] - c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); - - // c[1,0-15] - c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); - - // c[1,16-31] - c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); - - // c[1,32-47] - c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); - - // c[2,0-15] - c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); - - // c[2,16-31] - c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); - - // c[2,32-47] - c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); - - // c[3,0-15] - c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); - - // c[3,16-31] - c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); - - // c[3,32-47] - c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); - - // c[4,0-15] - c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); - - // c[4,16-31] - c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); - - // c[4,32-47] - c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); - - // c[5,0-15] - c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); - - // c[5,16-31] - c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); - - // c[5,32-47] - c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x48: - { - selector1 = _mm512_setzero_ps(); - selector2 = - _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_1p2) - - // c[2, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_2p2) - - // c[3, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_3p2) - - // c[4, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_4p1) - - // c[4, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_4p2) - - // c[5, 0-15] - RELU_SCALE_OP_F32_AVX512(c_float_5p0) - - // c[5, 16-31] - RELU_SCALE_OP_F32_AVX512(c_float_5p1) - - // c[5, 32-47] - RELU_SCALE_OP_F32_AVX512(c_float_5p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x48: - { - // c[0, 0-15] - CVT_F32_BF16(c_float_0p0,0,0); - - // c[0, 16-31] - CVT_F32_BF16(c_float_0p1,0,1); - - // c[0, 32-47] - CVT_F32_BF16(c_float_0p2,0,2); - - // c[1, 0-15] - CVT_F32_BF16(c_float_1p0,1,0); - - // c[1, 16-31] - CVT_F32_BF16(c_float_1p1,1,1); - - // c[1, 32-47] - CVT_F32_BF16(c_float_1p2,1,2); - - // c[2, 0-15] - CVT_F32_BF16(c_float_2p0,2,0); - - // c[2, 16-31] - CVT_F32_BF16(c_float_2p1,2,1); - - // c[2, 32-47] - CVT_F32_BF16(c_float_2p2,2,2); - - // c[3, 0-15] - CVT_F32_BF16(c_float_3p0,3,0); - - // c[3, 16-31] - CVT_F32_BF16(c_float_3p1,3,1); - - // c[3, 32-47] - CVT_F32_BF16(c_float_3p2,3,2); - - // c[4, 0-15] - CVT_F32_BF16(c_float_4p0,4,0); - - // c[4, 16-31] - CVT_F32_BF16(c_float_4p1,4,1); - - // c[4, 32-47] - CVT_F32_BF16(c_float_4p2,4,2); - - // c[5, 0-15] - CVT_F32_BF16(c_float_5p0,5,0); - - // c[5, 16-31] - CVT_F32_BF16(c_float_5p1,5,1); - - // c[5, 32-47] - CVT_F32_BF16(c_float_5p2,5,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); - - // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); - - // c[0,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); - - // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); - - // c[1,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); - - // c[1,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); - - // c[2,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); - - // c[2,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); - - // c[2,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); - - // c[3,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); - - // c[3,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); - - // c[3,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); - - // c[4,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); - - // c[4,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); - - // c[4,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); - - // c[5,0-15] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); - - // c[5,16-31] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); - - // c[5,32-47] - _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_bf16bf16f32of32_5x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_bf16bf16f32of32_4x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_bf16bf16f32of32_3x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_bf16bf16f32of32_2x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_bf16bf16f32of32_1x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} -#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h index 07b22a5b25..db5d31e513 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h @@ -1,67 +1,72 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef BLIS_GEMM_BF16_PACKB -#define BLIS_GEMM_BF16_PACKB - -#include "lpgemm_kernels.h" - -BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() -{ - // This is the minimum NR' required for use in bf16bf16f32 kernels. The idea - // here is that since k needs to be a multiple of 2 (BF16 instr), NR'=16 - // results in total of 2 * NR' = 64 bytes to be loaded, which fits in 1 ZMM - // register. Thus the smallest n fringe kernel dimension has n=16, and thus - // any rounding for buffer sizes should be to 16. - return 16; -} - -void get_packb_nr64_bf16bf16f32of32_strides - ( - dim_t* rs_b, - dim_t* cs_b - ); - -void packb_nr64_bf16bf16f32of32 - ( - bfloat16* pack_b_buffer_bf16bf16f32of32, - const bfloat16* b, - const dim_t ldb, - const dim_t NC, - const dim_t KC, - dim_t* rs_b, - dim_t* cs_b - ); - -#endif //BLIS_GEMM_BF16_PACKB +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_BF16_PACKB +#define BLIS_GEMM_BF16_PACKB + +#include "aocl_bf16_type.h" + +BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() +{ + // This is the minimum NR' required for use in bf16bf16f32 kernels. The idea + // here is that since k needs to be a multiple of 2 (BF16 instr), NR'=16 + // results in total of 2 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +typedef void (*packb_bf16) + ( + bfloat16*, + const bfloat16*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* + ); + +void packb_nr64_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_BF16_PACKB diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index 7b73ba27e9..add69df94f 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-23, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -38,54 +38,69 @@ #include "lpgemm_post_ops.h" #include "aocl_bf16_type.h" +typedef void (*lpgemm_m_fringe_f32_ker_ft) + ( + const dim_t k0, + const float* a, + const dim_t rs_a, + const dim_t cs_a, + const float* b, + const dim_t rs_b, + const dim_t cs_b, + float* c, + const dim_t rs_c, + const float alpha, + const float beta, + lpgemm_post_op* post_ops_list, + lpgemm_post_op_attr post_ops_attr + ); + #define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t m0, \ - const dim_t n0, \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const dim_t ps_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const dim_t cs_c, \ - const C_type alpha, \ - const C_type beta, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t m0, \ + const dim_t n0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32); LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64); +LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_6x16m); +LPGEMM_MAIN_KERN(float,float,float,f32f32f32of32_avx512_6x64m); +LPGEMM_MAIN_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x64); +LPGEMM_MAIN_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x32); #define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const C_type alpha, \ - const C_type beta, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64); @@ -104,31 +119,81 @@ LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64); LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64); LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x64); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x48); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x48); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x48); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x48); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x48); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_5x32); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_4x32); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_3x32); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_2x32); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_1x32); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x16); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x16); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x16); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x16); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x16); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x8); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x8); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x8); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x8); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x8); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x4); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x4); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x4); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x4); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x4); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x2); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x2); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x2); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x2); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x2); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_5x1); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_4x1); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_3x1); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_2x1); +LPGEMM_M_FRINGE_KERN(float,float,float,f32f32f32of32_1x1); + +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x64); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x64); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x64); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x64); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x64); + +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x32); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x32); +LPGEMM_M_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x32); + #define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t m0, \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const dim_t ps_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const C_type alpha, \ - const C_type beta, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12x16); LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_9x32); LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48); LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16); @@ -137,55 +202,67 @@ LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16); LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32); LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x48m); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_avx512_6x32m); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x8m); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x4m); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x2m); +LPGEMM_N_FRINGE_KERN(float,float,float,f32f32f32of32_6x1m); + +LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x16); +LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x32); +LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6x48); + +LPGEMM_N_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6x16); + #define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t m0, \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const dim_t ps_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const C_type alpha, \ - const C_type beta, \ - const dim_t n0_rem, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16); +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_12xlt16); LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16); LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16); +LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_6xlt16); + +LPGEMM_N_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_6xlt16); + #define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const C_type alpha, \ - const C_type beta, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16); @@ -224,26 +301,43 @@ LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48); LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48); LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x32); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x32); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x32); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x32); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x32); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5x48); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4x48); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3x48); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2x48); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1x48); + +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2x16); +LPGEMM_MN_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1x16); + #define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ - const dim_t k0, \ - const A_type* a, \ - const dim_t rs_a, \ - const dim_t cs_a, \ - const B_type* b, \ - const dim_t rs_b, \ - const dim_t cs_b, \ - C_type* c, \ - const dim_t rs_c, \ - const C_type alpha, \ - const C_type beta, \ - const dim_t n0_rem, \ - bool is_last_k, \ - dim_t post_op_c_i, \ - dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list, \ - const dim_t rs_c_downscale \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + lpgemm_post_op* post_ops_list, \ + lpgemm_post_op_attr post_ops_attr \ ) \ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16); @@ -262,4 +356,14 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int32_t,s8s8s32os32_1xlt16); + +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(int8_t,int8_t,int16_t,s8s8s16o16_1xlt16); + #endif //BLIS_LPGEMM_KERN_H diff --git a/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h b/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h new file mode 100644 index 0000000000..7849e5a537 --- /dev/null +++ b/addon/aocl_gemm/kernels/lpgemm_utils_kernels.h @@ -0,0 +1,63 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_LPGEMM_UTILS_KERN_H +#define BLIS_LPGEMM_UTILS_KERN_H + +typedef void (*lpgemm_util_l1_op_f32_kernel_t) + ( + const dim_t n, + float* x, + const inc_t incx + ); + +#define LPGEMM_UTIL_L1_OP_KERNEL(V_type,OP_type) \ +void lpgemm_util_ ## OP_type ## _kernel \ + ( \ + const dim_t n, \ + V_type* x, \ + const inc_t incx \ + ) \ + +// AVX512 +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_gelu_tanh_avx512); +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_gelu_erf_avx512); +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_softmax_avx512); + +// AVX2 +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_gelu_tanh_avx2); +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_gelu_erf_avx2); +LPGEMM_UTIL_L1_OP_KERNEL(float,f32_softmax_avx2); + +#endif //BLIS_LPGEMM_UTILS_KERN_H diff --git a/addon/aocl_gemm/kernels/s8s8s16/lpgemm_packb_s8s16.h b/addon/aocl_gemm/kernels/s8s8s16/lpgemm_packb_s8s16.h new file mode 100644 index 0000000000..f3f49e9002 --- /dev/null +++ b/addon/aocl_gemm/kernels/s8s8s16/lpgemm_packb_s8s16.h @@ -0,0 +1,63 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_S8_INT16_PACKB +#define BLIS_GEMM_S8_INT16_PACKB + +typedef void (*packb_s16_s8) + ( + int8_t*, + int16_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* + ); + +void packb_nr32_s8s8s16o16 + ( + int8_t *pack_b_buffer_s8s8s16o16, + int16_t *pack_b_column_sum, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b + ); + +#endif // BLIS_GEMM_S8_INT16_PACKB + diff --git a/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packa_s8.h b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packa_s8.h new file mode 100644 index 0000000000..e31c30c563 --- /dev/null +++ b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packa_s8.h @@ -0,0 +1,61 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKA_S8 +#define BLIS_GEMM_INT8_PACKA_S8 + +typedef void (*packa_s32_s8) + ( + int8_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* + ); + +void packa_k64_s8s8s32os32 + ( + int8_t* pack_a_buffer_s8s8s32o32, + const int8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +#endif //BLIS_GEMM_INT8_PACKA_S8 + diff --git a/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h new file mode 100644 index 0000000000..661c153436 --- /dev/null +++ b/addon/aocl_gemm/kernels/s8s8s32/lpgemm_packb_s8.h @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKB_S8 +#define BLIS_GEMM_INT8_PACKB_S8 + +BLIS_INLINE dim_t get_packb_s8s8s32o32_min_NR() +{ + // This is the minimum NR' required for use in u8s8s32 kernels. The idea + // here is that since k needs to be a multiple of 4 (VNNI instr), NR'=16 + // results in total of 4 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +typedef void (*packb_s32_s8) + ( + int8_t*, + int32_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* + ); + +void packb_nr64_s8s8s32os32 + ( + int8_t* pack_b_buffer_s8s8s32o32, + int32_t* pack_b_column_sum, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_INT8_PACKB_S8 + diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h index b8d73c862c..a8f64c3fe0 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,10 +35,15 @@ #ifndef BLIS_GEMM_INT16_PACKB #define BLIS_GEMM_INT16_PACKB -void get_packb_nr32_u8s8s16o16_strides +typedef void (*packb_s16) ( - dim_t* rs_b, - dim_t* cs_b + int8_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* ); void packb_nr32_u8s8s16o16 @@ -52,4 +57,4 @@ void packb_nr32_u8s8s16o16 dim_t *cs_b ); -#endif // BLIS_GEMM_INT16_PACKB \ No newline at end of file +#endif // BLIS_GEMM_INT16_PACKB diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h deleted file mode 100644 index 00583977f3..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h +++ /dev/null @@ -1,404 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef LPGEMM_S16_KERN_MACROS_H -#define LPGEMM_S16_KERN_MACROS_H -#define S8_MIN (-128) -#define S8_MAX (+127) - -#define RELU_SCALE_OP_S16_AVX2(reg) \ - selector1 = _mm256_setzero_si256();\ - selector1 = _mm256_cmpgt_epi16 ( selector1, reg ); \ - \ - /* Only < 0 elements in b0. */ \ - b0 = _mm256_and_si256 ( selector1, reg ); \ -\ - /* Only >= 0 elements in c_int16_0p0. */ \ - reg = _mm256_andnot_si256( selector1, reg ); \ - \ - /* Only scaling for < 0 elements. */ \ - b0 = _mm256_mullo_epi16( b0, selector2 ); \ - \ - /* Combine the scaled < 0 and >= 0 elements. */ \ - reg = _mm256_or_si256( b0, reg ); \ - \ - -//-------------------------------------------------------------------------- - -#define BLI_MM256_S16_DOWNSCALE(c_int16__p0, c_int16__p1, vec_loc)\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ -\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps (res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ -\ - /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ - store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ -\ - /* Store the result in s8 form */\ - _mm256_storeu_si256((__m256i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc ) ) + post_op_c_j), store_reg);\ -\ - -//-------------------------------------------------------------------------- - -#define BLI_MM256_S16_DOWNSCALE2(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ -\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ -\ - /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ - store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(store_reg, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(store_reg, 1);\ -\ - /* Store the result in s8 form */\ - _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j), temp[0]);\ - _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j), temp[1]);\ -\ - -//-------------------------------------------------------------------------- - -#define BLI_MM256_S16_DOWNSCALE2_LT16(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ -\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ -\ - /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ - store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(store_reg, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(store_reg, 1);\ -\ - /* Store the result in s8 form */\ - _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ - memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j \ - , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ -\ - _mm_storeu_si128((__m128i *)store_buf, temp[1]);\ - memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j \ - , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ -\ - -//-------------------------------------------------------------------------- - -#define BLI_MM256_S16_DOWNSCALE2_EDGE(c_int16__p0, vec_ind)\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ - store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(store_reg, 0);\ -\ - /* Store the result in s8 form */\ - _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j), temp[0]);\ -\ - -//-------------------------------------------------------------------------- - -#define BLI_MM256_S16_DOWNSCALE2_EDGE_LT16(c_int16__p0, vec_ind)\ -\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ - /* Extract the second 128 bits of the register*/\ - temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ -\ - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ -\ - /* Multiply the C matrix by the scale value*/\ - res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ -\ - /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ - res_1 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ - res_2 = _mm256_min_ps(_mm256_max_ps \ - (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ - _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ -\ - /* Convert the clipped float32 scaled rounded value to int32 */\ - temp_32[0] = _mm256_cvtps_epi32(res_1);\ - temp_32[1] = _mm256_cvtps_epi32(res_2);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ - store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ - /* Extract the first 128 bits of the register*/\ - temp[0] = _mm256_extractf128_si256(store_reg, 0);\ -\ - /* Store the result in s8 form */\ - _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ - memcpy( (( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j) \ - ,store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ -\ - -#endif //LPGEMM_S16_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c deleted file mode 100644 index 1674a22bd0..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ /dev/null @@ -1,2362 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include - -#include "blis.h" -#include "lpgemm_kernels.h" -#include "lpgemm_s32_kern_macros.h" - -#ifdef BLIS_KERNELS_ZEN4 -// 5x64 int8o32 kernel -LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_5x64_DISABLE, - &&POST_OPS_BIAS_5x64, - &&POST_OPS_RELU_5x64, - &&POST_OPS_RELU_SCALE_5x64, - &&POST_OPS_DOWNSCALE_5x64 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; - - // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - __m512i c_int32_0p3 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - __m512i c_int32_1p3 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - __m512i c_int32_2p3 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - __m512i c_int32_3p2 = _mm512_setzero_epi32(); - __m512i c_int32_3p3 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - __m512i c_int32_4p1 = _mm512_setzero_epi32(); - __m512i c_int32_4p2 = _mm512_setzero_epi32(); - __m512i c_int32_4p3 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); - c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); - - // Broadcast a[4,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); - c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); - c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); - c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); - c_int32_4p3 = _mm512_mullo_epi32( selector1, c_int32_4p3 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); - - // c[3,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); - - // c[4,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_5x64: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); - - // c[3,48-63] - c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4, 16-31] - c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); - - // c[4,48-63] - c_int32_4p3 = _mm512_add_epi32( a_int32_1, c_int32_4p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_5x64: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); - - // c[3,48-63] - c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); - - // c[4,48-63] - c_int32_4p3 = _mm512_max_epi32( selector1, c_int32_4p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_5x64: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_0p3) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[1, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_1p3) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[2, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_2p3) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_3p2) - - // c[3, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_3p3) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_4p1) - - // c[4, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_4p2) - - // c[4, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_4p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_5x64: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 3 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[0, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[1, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[2, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[3, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); - - // c[3, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[4, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); - - // c[4, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); - - // c[4, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_4p3,a_int32_1,4,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_5x64_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[0,48-63] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[1,48-63] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); - - // c[2,48-63] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); - - // c[3,32-47] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); - - // c[3,48-63] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); - - // c[4,16-31] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); - - // c[4,32-47] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); - - // c[4,48-63] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 3*16 ), c_int32_4p3 ); -} - -// 4x64 int8o32 kernel -LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_4x64_DISABLE, - &&POST_OPS_BIAS_4x64, - &&POST_OPS_RELU_4x64, - &&POST_OPS_RELU_SCALE_4x64, - &&POST_OPS_DOWNSCALE_4x64 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; - - // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - __m512i c_int32_0p3 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - __m512i c_int32_1p3 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - __m512i c_int32_2p3 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - __m512i c_int32_3p2 = _mm512_setzero_epi32(); - __m512i c_int32_3p3 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); - c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); - c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); - c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); - - // c[3,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_4x64: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); - - // c[3,48-63] - c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_4x64: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); - - // c[3,48-63] - c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_4x64: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_0p3) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[1, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_1p3) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[2, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_2p3) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_3p2) - - // c[3, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_3p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_4x64: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 3 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[0, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[1, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[2, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[3, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); - - // c[3, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_4x64_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[0,48-63] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[1,48-63] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); - - // c[2,48-63] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); - - // c[3,32-47] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); - - // c[3,48-63] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); -} - -// 3x64 int8o32 kernel -LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_3x64_DISABLE, - &&POST_OPS_BIAS_3x64, - &&POST_OPS_RELU_3x64, - &&POST_OPS_RELU_SCALE_3x64, - &&POST_OPS_DOWNSCALE_3x64 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; - - // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - __m512i c_int32_0p3 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - __m512i c_int32_1p3 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - __m512i c_int32_2p3 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_3x64: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_3x64: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[2,48-63] - c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_3x64: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_0p3) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[1, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_1p3) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[2, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_2p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_3x64: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 3 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[0, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[1, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[2, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3x64_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[0,48-63] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[1,48-63] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); - - // c[2,48-63] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); -} - -// 2x64 int8o32 kernel -LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_2x64_DISABLE, - &&POST_OPS_BIAS_2x64, - &&POST_OPS_RELU_2x64, - &&POST_OPS_RELU_SCALE_2x64, - &&POST_OPS_DOWNSCALE_2x64 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; - - // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - __m512i c_int32_0p3 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - __m512i c_int32_1p3 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); - - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); - c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_2x64: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_2x64: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[1,48-63] - c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_2x64: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_0p3) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[1, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_1p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_2x64: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 3 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[0, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[1, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_2x64_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[0,48-63] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[1,48-63] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); -} - -// 1x64 int8o32 kernel -LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_1x64_DISABLE, - &&POST_OPS_BIAS_1x64, - &&POST_OPS_RELU_1x64, - &&POST_OPS_RELU_SCALE_1x64, - &&POST_OPS_DOWNSCALE_1x64 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - __m512i b3; - - // A matrix storage. - __m512i a_int32_0; - __m512i a_int32_1; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - __m512i c_int32_0p3 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr] - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); - - // Scale C by beta. - if ( beta != 0) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_1x64: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_1x64: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[0,48-63] - c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1x64: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[0, 48-63] - RELU_SCALE_OP_S32_AVX512(c_int32_0p3) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_1x64: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - a_int32_1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 3 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[0, 48-63] - CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_1x64_DISABLE: - ; - - // Store the accumulated results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[0,48-63] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); -} -#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c deleted file mode 100644 index b202061e6a..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ /dev/null @@ -1,5283 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include - -#include "blis.h" -#include "lpgemm_kernels.h" -#include "lpgemm_s32_kern_macros.h" - -#ifdef BLIS_KERNELS_ZEN4 -// 5xlt16 int8o32 fringe kernel -LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_5xLT16_DISABLE, - &&POST_OPS_BIAS_5xLT16, - &&POST_OPS_RELU_5xLT16, - &&POST_OPS_RELU_SCALE_5xLT16, - &&POST_OPS_DOWNSCALE_5xLT16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // For corner cases. - int32_t buf0[16]; - int32_t buf1[16]; - int32_t buf2[16]; - int32_t buf3[16]; - int32_t buf4[16]; - - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf4, ( c + ( rs_c * 4 ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( buf1 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( buf2 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( buf3 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( buf4 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_5xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_5xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_5xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_5xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_5xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( buf1, c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( buf2, c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( buf3, c_int32_3p0 ); - - // c[4,0-15] - _mm512_storeu_epi32( buf4, c_int32_4p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); - - // c[2,0-15] - memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); - - // c[3,0-15] - memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); - - // c[4,0-15] - memcpy( c + ( rs_c * 4 ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); - } -} - -// 4xlt16 int8o32 fringe kernel -LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_4xLT16_DISABLE, - &&POST_OPS_BIAS_4xLT16, - &&POST_OPS_RELU_4xLT16, - &&POST_OPS_RELU_SCALE_4xLT16, - &&POST_OPS_DOWNSCALE_4xLT16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // For corner cases. - int32_t buf0[16]; - int32_t buf1[16]; - int32_t buf2[16]; - int32_t buf3[16]; - - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( buf1 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( buf2 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( buf3 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_4xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_4xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_4xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_4xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_4xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( buf1, c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( buf2, c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( buf3, c_int32_3p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); - - // c[2,0-15] - memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); - - // c[3,0-15] - memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); - } -} - -// 3xlt16 int8o32 fringe kernel -LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_3xLT16_DISABLE, - &&POST_OPS_BIAS_3xLT16, - &&POST_OPS_RELU_3xLT16, - &&POST_OPS_RELU_SCALE_3xLT16, - &&POST_OPS_DOWNSCALE_3xLT16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // For corner cases. - int32_t buf0[16]; - int32_t buf1[16]; - int32_t buf2[16]; - - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( buf1 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( buf2 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_3xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_3xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_3xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_3xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( buf1, c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( buf2, c_int32_2p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); - - // c[2,0-15] - memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); - } -} - -// 2xlt16 int8o32 fringe kernel -LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_2xLT16_DISABLE, - &&POST_OPS_BIAS_2xLT16, - &&POST_OPS_RELU_2xLT16, - &&POST_OPS_RELU_SCALE_2xLT16, - &&POST_OPS_DOWNSCALE_2xLT16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // For corner cases. - int32_t buf0[16]; - int32_t buf1[16]; - - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( buf1 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_2xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_2xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_2xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_2xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_2xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( buf1, c_int32_1p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); - } -} - -// 1xlt16 int8o32 fringe kernel -LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_1xLT16_DISABLE, - &&POST_OPS_BIAS_1xLT16, - &&POST_OPS_RELU_1xLT16, - &&POST_OPS_RELU_SCALE_1xLT16, - &&POST_OPS_DOWNSCALE_1xLT16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // For corner cases. - int32_t buf0[16]; - - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_1xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_1xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_1xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_1xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - } -} - -// 5x16 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_5x16_DISABLE, - &&POST_OPS_BIAS_5x16, - &&POST_OPS_RELU_5x16, - &&POST_OPS_RELU_SCALE_5x16, - &&POST_OPS_DOWNSCALE_5x16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_5x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_5x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_5x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_5x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_5x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); -} - -// 4x16 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_4x16_DISABLE, - &&POST_OPS_BIAS_4x16, - &&POST_OPS_RELU_4x16, - &&POST_OPS_RELU_SCALE_4x16, - &&POST_OPS_DOWNSCALE_4x16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_4x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_4x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_4x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_4x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_4x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); -} - -// 3x16 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_3x16_DISABLE, - &&POST_OPS_BIAS_3x16, - &&POST_OPS_RELU_3x16, - &&POST_OPS_RELU_SCALE_3x16, - &&POST_OPS_DOWNSCALE_3x16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_3x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_3x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_3x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_3x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); -} - -// 2x16 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_2x16_DISABLE, - &&POST_OPS_BIAS_2x16, - &&POST_OPS_RELU_2x16, - &&POST_OPS_RELU_SCALE_2x16, - &&POST_OPS_DOWNSCALE_2x16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_2x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_2x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_2x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_2x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_2x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); -} - -// 1x16 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_1x16_DISABLE, - &&POST_OPS_BIAS_1x16, - &&POST_OPS_RELU_1x16, - &&POST_OPS_RELU_SCALE_1x16, - &&POST_OPS_DOWNSCALE_1x16 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_1x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_1x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_1x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_1x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); -} - -// 5x32 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_5x32_DISABLE, - &&POST_OPS_BIAS_5x32, - &&POST_OPS_RELU_5x32, - &&POST_OPS_RELU_SCALE_5x32, - &&POST_OPS_DOWNSCALE_5x32 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - __m512i c_int32_4p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - - // Broadcast a[4,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_5x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4, 16-31] - c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_5x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_5x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_4p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_5x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[4, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_5x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); - - // c[4,16-31] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); -} - -// 4x32 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_4x32_DISABLE, - &&POST_OPS_BIAS_4x32, - &&POST_OPS_RELU_4x32, - &&POST_OPS_RELU_SCALE_4x32, - &&POST_OPS_DOWNSCALE_4x32 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_4x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_4x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_4x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_4x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_4x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); -} - -// 3x32 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_3x32_DISABLE, - &&POST_OPS_BIAS_3x32, - &&POST_OPS_RELU_3x32, - &&POST_OPS_RELU_SCALE_3x32, - &&POST_OPS_DOWNSCALE_3x32 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_3x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_3x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_3x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_3x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); -} - -// 2x32 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_2x32_DISABLE, - &&POST_OPS_BIAS_2x32, - &&POST_OPS_RELU_2x32, - &&POST_OPS_RELU_SCALE_2x32, - &&POST_OPS_DOWNSCALE_2x32 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_2x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_2x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_2x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_2x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_2x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); -} - -// 1x32 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_1x32_DISABLE, - &&POST_OPS_BIAS_1x32, - &&POST_OPS_RELU_1x32, - &&POST_OPS_RELU_SCALE_1x32, - &&POST_OPS_DOWNSCALE_1x32 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_1x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_1x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_1x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_1x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); -} - -// 5x48 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_5x48_DISABLE, - &&POST_OPS_BIAS_5x48, - &&POST_OPS_RELU_5x48, - &&POST_OPS_RELU_SCALE_5x48, - &&POST_OPS_DOWNSCALE_5x48 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - __m512i c_int32_3p2 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - __m512i c_int32_4p1 = _mm512_setzero_epi32(); - __m512i c_int32_4p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - - // Broadcast a[4,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); - c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_5x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4, 16-31] - c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_5x48: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_5x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_3p2) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_4p1) - - // c[4, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_4p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_5x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[3, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[4, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); - - // c[4, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_5x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); - - // c[3,32-47] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); - - // c[4,16-31] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); - - // c[4,32-47] - _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); -} - -// 4x48 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_4x48_DISABLE, - &&POST_OPS_BIAS_4x48, - &&POST_OPS_RELU_4x48, - &&POST_OPS_RELU_SCALE_4x48, - &&POST_OPS_DOWNSCALE_4x48 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - __m512i c_int32_3p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_4x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_4x48: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_4x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_3p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_4x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[3, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_4x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); - - // c[3,32-47] - _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); -} - -// 3x48 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_3x48_DISABLE, - &&POST_OPS_BIAS_3x48, - &&POST_OPS_RELU_3x48, - &&POST_OPS_RELU_SCALE_3x48, - &&POST_OPS_DOWNSCALE_3x48 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_3x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_3x48: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_3x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_3x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_3x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); -} - -// 2x48 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_2x48_DISABLE, - &&POST_OPS_BIAS_2x48, - &&POST_OPS_RELU_2x48, - &&POST_OPS_RELU_SCALE_2x48, - &&POST_OPS_DOWNSCALE_2x48 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_2x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_2x48: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_2x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_2x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_2x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); -} - -// 1x48 int8o32 kernel -LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_1x48_DISABLE, - &&POST_OPS_BIAS_1x48, - &&POST_OPS_RELU_1x48, - &&POST_OPS_RELU_SCALE_1x48, - &&POST_OPS_DOWNSCALE_1x48 - }; - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_1x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_1x48: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_1x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_1x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_1x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); -} -#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c deleted file mode 100644 index 856dc1355e..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ /dev/null @@ -1,2300 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include - -#include "blis.h" -#include "lpgemm_kernels.h" -#include "lpgemm_s32_kern_macros.h" - -#ifdef BLIS_KERNELS_ZEN4 -// 6xlt16 int8o32 fringe kernel -LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6xLT16_DISABLE, - &&POST_OPS_BIAS_6xLT16, - &&POST_OPS_RELU_6xLT16, - &&POST_OPS_RELU_SCALE_6xLT16, - &&POST_OPS_DOWNSCALE_6xLT16 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - - // A matrix storage. - __m512i a_int32_0; - - // For corner cases. - int32_t buf0[16]; - int32_t buf1[16]; - int32_t buf2[16]; - int32_t buf3[16]; - int32_t buf4[16]; - int32_t buf5[16]; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - - __m512i c_int32_5p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 4 rows with 16 extended elements each from B to 1 ZMM - // registers. It is to be noted that the B matrix is packed for use - // in vnni instructions and each load to ZMM register will have 4 - // elements along k direction and 16 elements across n directions, - // so 4x16 elements to a ZMM register. - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - - // Broadcast a[5,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - - // Broadcast a[5,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - - c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); - - // c[0,0-15] - selector1 = _mm512_loadu_epi32( buf0 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( buf1 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( buf2 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( buf3 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( buf4 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - selector1 = _mm512_loadu_epi32( buf5 ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6xLT16: - { - memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6xLT16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6xLT16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[5, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_5p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6xLT16: - { - memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_epi32( buf0 ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); - - // c[5, 0-15] - CVT_MULRND_CVT32_CVT8_LT16(c_int32_5p0,selector1,5,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6xLT16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( buf0, c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( buf1, c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( buf2, c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( buf3, c_int32_3p0 ); - - // c[4,0-15] - _mm512_storeu_epi32( buf4, c_int32_4p0 ); - - // c[5,0-15] - _mm512_storeu_epi32( buf5, c_int32_5p0 ); - - // Memcpy partial parts. - // c[0,0-15] - memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); - - // c[1,0-15] - memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); - - // c[2,0-15] - memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); - - // c[3,0-15] - memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); - - // c[4,0-15] - memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); - - // c[5,0-15] - memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( int32_t ) ) ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_u8s8s32o32_5xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_u8s8s32o32_4xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_u8s8s32o32_3xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_u8s8s32o32_2xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_u8s8s32o32_1xlt16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x16 int8o32 fringe kernel -LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x16_DISABLE, - &&POST_OPS_BIAS_6x16, - &&POST_OPS_RELU_6x16, - &&POST_OPS_RELU_SCALE_6x16, - &&POST_OPS_DOWNSCALE_6x16 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - - // A matrix storage. - __m512i a_int32_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - - __m512i c_int32_5p0 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 4 rows with 16 elements each from B to 1 ZMM registers. It - // is to be noted that the B matrix is packed for use in vnni - // instructions and each load to ZMM register will have 4 elements - // along k direction and 16 elements across n directions, so 4x16 - // elements to a ZMM register. - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - - // Broadcast a[5,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - - // Broadcast a[4,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - - // Broadcast a[5,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - - c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x16: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x16: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x16: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[5, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_5p0) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x16: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[5, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x16_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); - - // c[5,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_u8s8s32o32_5x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_u8s8s32o32_4x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_u8s8s32o32_3x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_u8s8s32o32_2x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_u8s8s32o32_1x16 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x32 int8o32 fringe kernel -LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x32_DISABLE, - &&POST_OPS_BIAS_6x32, - &&POST_OPS_RELU_6x32, - &&POST_OPS_RELU_SCALE_6x32, - &&POST_OPS_DOWNSCALE_6x32 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - - // A matrix storage. - __m512i a_int32_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - __m512i c_int32_4p1 = _mm512_setzero_epi32(); - - __m512i c_int32_5p0 = _mm512_setzero_epi32(); - __m512i c_int32_5p1 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 4 rows with 32 elements each from B to 2 ZMM registers. It - // is to be noted that the B matrix is packed for use in vnni - // instructions and each load to ZMM register will have 4 elements - // along k direction and 16 elements across n directions, so 4x16 - // elements to a ZMM register. - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - - // Broadcast a[5,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - - // Broadcast a[4,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - - // Broadcast a[5,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); - - c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); - c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); - - // c[5,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - // c[5,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x32: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4, 16-31] - c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - // c[5, 16-31] - c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x32: - { - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); - - // c[5,16-31] - c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x32: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_4p1) - - // c[5, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_5p0) - - // c[5, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_5p1) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x32: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[4, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); - - // c[5, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); - - // c[5, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x32_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); - - // c[4,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); - - // c[5,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); - - // c[5,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_u8s8s32o32_5x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_u8s8s32o32_4x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_u8s8s32o32_3x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_u8s8s32o32_2x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_u8s8s32o32_1x32 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} - -// 6x48 int8o32 fringe kernel -LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) -{ - static void* post_ops_labels[] = - { - &&POST_OPS_6x48_DISABLE, - &&POST_OPS_BIAS_6x48, - &&POST_OPS_RELU_6x48, - &&POST_OPS_RELU_SCALE_6x48, - &&POST_OPS_DOWNSCALE_6x48 - }; - dim_t MR = 6; - dim_t m_full_pieces = m0 / MR; - dim_t m_full_pieces_loop_limit = m_full_pieces * MR; - dim_t m_partial_pieces = m0 % MR; - - dim_t k_full_pieces = k0 / 4; - dim_t k_partial_pieces = k0 % 4; - - uint32_t a_kfringe_buf = 0; - - // B matrix storage. - __m512i b0; - __m512i b1; - __m512i b2; - - // A matrix storage. - __m512i a_int32_0; - - for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) - { - // Registers to use for accumulating C. - __m512i c_int32_0p0 = _mm512_setzero_epi32(); - __m512i c_int32_0p1 = _mm512_setzero_epi32(); - __m512i c_int32_0p2 = _mm512_setzero_epi32(); - - __m512i c_int32_1p0 = _mm512_setzero_epi32(); - __m512i c_int32_1p1 = _mm512_setzero_epi32(); - __m512i c_int32_1p2 = _mm512_setzero_epi32(); - - __m512i c_int32_2p0 = _mm512_setzero_epi32(); - __m512i c_int32_2p1 = _mm512_setzero_epi32(); - __m512i c_int32_2p2 = _mm512_setzero_epi32(); - - __m512i c_int32_3p0 = _mm512_setzero_epi32(); - __m512i c_int32_3p1 = _mm512_setzero_epi32(); - __m512i c_int32_3p2 = _mm512_setzero_epi32(); - - __m512i c_int32_4p0 = _mm512_setzero_epi32(); - __m512i c_int32_4p1 = _mm512_setzero_epi32(); - __m512i c_int32_4p2 = _mm512_setzero_epi32(); - - __m512i c_int32_5p0 = _mm512_setzero_epi32(); - __m512i c_int32_5p1 = _mm512_setzero_epi32(); - __m512i c_int32_5p2 = _mm512_setzero_epi32(); - - for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) - { - // Load 4 rows with 48 elements each from B to 3 ZMM registers. It - // is to be noted that the B matrix is packed for use in vnni - // instructions and each load to ZMM register will have 4 elements - // along k direction and 16 elements across n directions, so 4x16 - // elements to a ZMM register. - b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - - // Broadcast a[4,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - - // Broadcast a[5,kr:kr+4]. - a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); - c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); - } - // Handle k remainder. - if ( k_partial_pieces > 0 ) - { - b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - - // Broadcast a[0,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); - c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); - c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); - - // Broadcast a[1,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); - c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); - c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); - - // Broadcast a[2,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); - c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); - c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); - - // Broadcast a[3,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); - c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); - c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); - - // Broadcast a[4,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); - c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); - c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); - - // Broadcast a[5,kr:kr+4]. - memcpy - ( - &a_kfringe_buf, - ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), - ( k_partial_pieces * sizeof( uint8_t ) ) - ); - a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - - // Perform column direction mat-mul with k = 4. - // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] - c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); - c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); - c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); - } - - // Load alpha and beta - __m512i selector1 = _mm512_set1_epi32( alpha ); - __m512i selector2 = _mm512_set1_epi32( beta ); - - // Scale by alpha - c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); - c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); - c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); - - c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); - c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); - c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); - - c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); - c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); - c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); - - c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); - c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); - c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); - - c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); - c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); - c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); - - c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); - c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); - c_int32_5p2 = _mm512_mullo_epi32( selector1, c_int32_5p2 ); - - // Scale C by beta. - if ( beta != 0 ) - { - // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); - - // c[4,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); - - // c[5,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - // c[5,16-31] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); - - // c[5,32-47] - selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); - selector1 = _mm512_mullo_epi32( selector2, selector1 ); - c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); - } - - // Post Ops - lpgemm_post_op* post_ops_list_temp = post_ops_list; - POST_OP_LABEL_LASTK_SAFE_JUMP -POST_OPS_BIAS_6x48: - { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - - // c[0,0-15] - c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - - // c[1, 16-31] - c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); - - // c[2, 16-31] - c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); - - // c[3, 16-31] - c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); - - // c[4, 16-31] - c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); - - // c[5, 16-31] - c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); - - // c[5,32-47] - c_int32_5p2 = _mm512_add_epi32( a_int32_0, c_int32_5p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_6x48: - { - //printf("relu\n"); - selector1 = _mm512_setzero_epi32(); - - // c[0,0-15] - c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); - - // c[0, 16-31] - c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); - - // c[0,32-47] - c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); - - // c[1,0-15] - c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); - - // c[1,16-31] - c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); - - // c[1,32-47] - c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); - - // c[2,0-15] - c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); - - // c[2,16-31] - c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); - - // c[2,32-47] - c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); - - // c[3,0-15] - c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); - - // c[3,16-31] - c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); - - // c[3,32-47] - c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); - - // c[4,0-15] - c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); - - // c[4,16-31] - c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); - - // c[4,32-47] - c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); - - // c[5,0-15] - c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); - - // c[5,16-31] - c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); - - // c[5,32-47] - c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_RELU_SCALE_6x48: - { - selector1 = _mm512_setzero_epi32(); - selector2 = - _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); - - __mmask16 relu_cmp_mask; - - // c[0, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_0p0) - - // c[0, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_0p1) - - // c[0, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_0p2) - - // c[1, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_1p0) - - // c[1, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_1p1) - - // c[1, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_1p2) - - // c[2, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_2p0) - - // c[2, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_2p1) - - // c[2, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_2p2) - - // c[3, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_3p0) - - // c[3, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_3p1) - - // c[3, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_3p2) - - // c[4, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_4p0) - - // c[4, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_4p1) - - // c[4, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_4p2) - - // c[5, 0-15] - RELU_SCALE_OP_S32_AVX512(c_int32_5p0) - - // c[5, 16-31] - RELU_SCALE_OP_S32_AVX512(c_int32_5p1) - - // c[5, 32-47] - RELU_SCALE_OP_S32_AVX512(c_int32_5p2) - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_DOWNSCALE_6x48: - { - selector1 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 1 * 16 ) ); - a_int32_0 = - _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j + ( 2 * 16 ) ); - - // c[0, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); - - // c[0, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); - - // c[0, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); - - // c[1, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); - - // c[1, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); - - // c[1, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); - - // c[2, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); - - // c[2, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); - - // c[2, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); - - // c[3, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); - - // c[3, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); - - // c[3, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); - - // c[4, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); - - // c[4, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); - - // c[4, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); - - // c[5, 0-15] - CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); - - // c[5, 16-31] - CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); - - // c[5, 32-47] - CVT_MULRND_CVT32_CVT8(c_int32_5p2,a_int32_0,5,2); - - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } -POST_OPS_6x48_DISABLE: - ; - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); - - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); - - // c[0,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_int32_0p2 ); - - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); - - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); - - // c[1,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_int32_1p2 ); - - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); - - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); - - // c[2,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_int32_2p2 ); - - // c[3,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); - - // c[3,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); - - // c[3,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_int32_3p2 ); - - // c[4,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); - - // c[4,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); - - // c[4,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_int32_4p2 ); - - // c[5,0-15] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); - - // c[5,16-31] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); - - // c[5,32-47] - _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); - - a = a + ( MR * ps_a ); - post_op_c_i += MR; - } - - if ( m_partial_pieces > 0 ) - { - if ( m_partial_pieces == 5 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); - lpgemm_rowvar_u8s8s32o32_5x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 4 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); - lpgemm_rowvar_u8s8s32o32_4x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 3 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); - lpgemm_rowvar_u8s8s32o32_3x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 2 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); - lpgemm_rowvar_u8s8s32o32_2x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - else if ( m_partial_pieces == 1 ) - { - dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); - lpgemm_rowvar_u8s8s32o32_1x48 - ( - k0, - a, rs_a, cs_a_use, - b, rs_b, cs_b, - ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, - is_last_k, - post_op_c_i, post_op_c_j, - post_ops_list, rs_c_downscale - ); - } - } -} -#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h index b983b0c617..9b1c55046e 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,14 +35,19 @@ #ifndef BLIS_GEMM_INT8_PACKA #define BLIS_GEMM_INT8_PACKA -void get_packa_k64_u8s8s32o32_strides +typedef void (*packa_s32) ( - dim_t* rs_a, - dim_t* cs_a + uint8_t*, + const uint8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* ); void packa_k64_u8s8s32o32 - ( + ( uint8_t* pack_a_buffer_u8s8s32o32, const uint8_t* a, const dim_t lda, diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h index 3f310c0a48..1d69148e3c 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,14 +45,19 @@ BLIS_INLINE dim_t get_packb_u8s8s32o32_min_NR() return 16; } -void get_packb_nr64_u8s8s32o32_strides +typedef void (*packb_s32) ( - dim_t* rs_b, - dim_t* cs_b + int8_t*, + const int8_t*, + const dim_t, + const dim_t, + const dim_t, + dim_t*, + dim_t* ); void packb_nr64_u8s8s32o32 - ( + ( int8_t* pack_b_buffer_u8s8s32o32, const int8_t* b, const dim_t ldb, diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h deleted file mode 100644 index bc3546736c..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h +++ /dev/null @@ -1,103 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef LPGEMM_S32_KERN_MACROS_H -#define LPGEMM_S32_KERN_MACROS_H -#define S8_MIN (-128) -#define S8_MAX (+127) - -#define RELU_SCALE_OP_S32_AVX512(reg) \ - /* Generate indenx of elements <= 0.*/ \ - relu_cmp_mask = _mm512_cmple_epi32_mask( reg, selector1 ); \ - \ - /* Apply scaling on for <= 0 elements.*/ \ - reg = _mm512_mask_mullo_epi32( reg, relu_cmp_mask, reg, selector2 ); \ - -#define CVT_MULRND_CVT32_CVT8(reg,selector,m_ind,n_ind) \ - _mm_storeu_epi8 \ - ( \ - ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ - _mm512_cvtepi32_epi8 \ - ( \ - _mm512_cvtps_epi32 \ - ( \ - _mm512_min_ps \ - ( \ - _mm512_max_ps \ - ( \ - _mm512_mul_round_ps \ - ( \ - _mm512_cvtepi32_ps( reg ), \ - ( __m512 )selector, \ - ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ - ) \ - , _mm512_set1_ps (( float )S8_MIN) \ - ) \ - , _mm512_set1_ps (( float )S8_MAX) \ - ) \ - ) \ - ) \ - ) \ - -#define CVT_MULRND_CVT32_CVT8_LT16(reg,selector,m_ind,n_ind) \ - _mm_storeu_epi8 \ - ( \ - buf0, \ - _mm512_cvtepi32_epi8 \ - ( \ - _mm512_cvtps_epi32 \ - ( \ - _mm512_min_ps \ - ( \ - _mm512_max_ps \ - ( \ - _mm512_mul_round_ps \ - ( \ - _mm512_cvtepi32_ps( reg ), \ - ( __m512 )selector, \ - ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ - ) \ - , _mm512_set1_ps (( float )S8_MIN) \ - ) \ - , _mm512_set1_ps (( float )S8_MAX) \ - ) \ - ) \ - ) \ - ); \ - memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ - ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( int8_t ) ) ); \ - -#endif // LPGEMM_S32_KERN_MACROS_H diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 6e7ee35102..a9b3db1786 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -129,7 +129,7 @@ void DTL_Initialize( #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) - /* Check if DTL logging is requested via envoronment variable */ + /* Check if DTL logging is requested via environment variable */ gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", TRUE ); #endif diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index f520518e9c..7f9934ed24 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -1,195 +1,195 @@ -/*=================================================================== - * File Name : aocldtl.h - * - * Description : This is main interface file for the end user - * It provides defination for all macros to be - * used by user to add debug/trace information. - * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. - * - *==================================================================*/ - -#ifndef _AOCLDTL_H_ -#define _AOCLDTL_H_ - -#include "aocldtlcf.h" -#include "aocltpdef.h" -#include "aoclflist.h" -#include "aoclos.h" - -#define TRACE_TYPE_FENTRY (1) -#define TRACE_TYPE_FEXIT (2) -#define TRACE_TYPE_LOG (3) -#define TRACE_TYPE_RAW (4) - -/* Type definition for printf */ -#define AOCL_DEBUGPRINT printf - -/* Define the AOCL_DTL_INITIALIZE_ENABLE if any of the debug macro - * are defined */ -#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_DUMP_ENABLE || AOCL_DTL_LOG_ENABLE) -#define AOCL_DTL_INITIALIZE_ENABLE -#endif - -#if AOCL_DTL_TRACE_ENABLE -/* Entry macro to trace the flow of control The parameter LogLevel specifies - the log level String will preferably contains the function name in which - this macro is invoked */ -#define AOCL_DTL_TRACE_ENTRY(LogLevel) \ - DTL_Trace(LogLevel, \ - TRACE_TYPE_FENTRY, \ - __FILE__, \ - __FUNCTION__, \ - __LINE__, \ - NULL); -#else -/* Dummy macro definition if the AOCL_DTL_TRACE_ENABLE macro is not enabled */ -#define AOCL_DTL_TRACE_ENTRY(LogLevel) -#endif - -#if AOCL_DTL_TRACE_ENABLE -/* Exit macro to trace the flow of control The parameter LogLevel specifies - log level String will preferably contains the function name in which this - macro is invoked */ -#define AOCL_DTL_TRACE_EXIT(LogLevel) \ - DTL_Trace(LogLevel, \ - TRACE_TYPE_FEXIT, \ - __FILE__, \ - __FUNCTION__, \ - __LINE__, \ - NULL); - -#define AOCL_DTL_TRACE_EXIT_ERR(LogLevel, Message) \ - DTL_Trace(LogLevel, \ - TRACE_TYPE_FEXIT, \ - __FILE__, \ - __FUNCTION__, \ - __LINE__, \ - Message); -#else -/* Dummy macro definition if the AOCL_DTL_TRACE_ENABLE macro is not enabled */ -#define AOCL_DTL_TRACE_EXIT(LogLevel) -#define AOCL_DTL_TRACE_EXIT_ERR(LogLevel, Message) -#endif - -#if AOCL_DTL_DUMP_ENABLE -/* Macro to Dump the DATA The parameters Buffer contains the data to be - dumped BufferSize specifies the no. of bytes to be dumped DataType - specifies the data type of Buffer */ -#define AOCL_DTL_DUMP(LogLevel, Buffer, BufferSize, DataType, String, OutputType) \ - /* Call the Dump function to Dump the DATA */ \ - DTL_DumpData(LogLevel, \ - Buffer, \ - BufferSize, \ - DataType, \ - String, \ - OutputType); -#else -/* Dummy macro definition if the AOCL_DTL_DUMP_ENABLE macro is not enabled */ -#define AOCL_DTL_DUMP(Buffer, BufferSize, DataType, String, OutputType) - -#endif - -#if AOCL_DTL_LOG_ENABLE -/* Macro to log the Data */ -#define AOCL_DTL_LOG(LogLevel, Message) \ - DTL_Trace(LogLevel, \ - TRACE_TYPE_LOG, \ - __FILE__, \ - __FUNCTION__, \ - __LINE__, \ - Message); -#else -/* Dummy macro definition if the AOCL_DTL_LOG_ENABLE macro is not enabled */ -#define AOCL_DTL_LOG(LogLevel, Message) -#endif - -#if AOCL_DTL_LOG_ENABLE - -void AOCL_DTL_start_perf_timer(void); -uint64 AOCL_DTL_get_time_spent(void); - -/* - * Logging of inputs can be enabled by two methods: - * - * 1. Using environment variable AOCL_VERBOSE. - * 2. APIs - * - * The API takes precedence over environment variable. - * - * The global flag is maintain in the code to track the final - * state of the logging feature. - */ -extern Bool gbIsLoggingEnabled; - -/* API to enable logging at runtime */ -#define AOCL_DTL_Enable_Logs() \ - /* Initialize DTL if not alredy done so */ \ - AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ - gbIsLoggingEnabled = TRUE; - -/* API to disable logging at runtime */ -#define AOCL_DTL_Disable_Logs() \ - /* Initialize DTL if not alredy done so */ \ - AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ - gbIsLoggingEnabled = FALSE; - -/* Macro to log the Data */ -#define AOCL_DTL_START_PERF_TIMER() \ - AOCL_DTL_start_perf_timer() -#else -/* Dummy macro definition if the AOCL_DTL_LOG_ENABLE macro is not enabled */ -#define AOCL_DTL_START_PERF_TIMER() -#endif - -/* Macro to initialize the prerequisite for debuging */ -#ifdef AOCL_DTL_INITIALIZE_ENABLE -#define AOCL_DTL_INITIALIZE(CURRENT_LOG_LEVEL) \ - DTL_Initialize(CURRENT_LOG_LEVEL); -#else -/* Dummy macro definition if the AOCL_DTL_INITIALIZE macro is not enabled */ -#define AOCL_DTL_INITIALIZE(CURRENT_LOG_LEVEL) -#endif - -/* Macro for uninitializing the prerequisite */ -#ifdef AOCL_DTL_INITIALIZE_ENABLE -#define AOCL_DTL_UNINITIALIZE() \ - DTL_Uninitialize(); -#else -/* Dummy macro definition if the AOCL_DTL_INITIALIZE macro is not enabled */ -#define AOCL_DTL_UNINITIALIZE() -#endif - -#ifdef AOCL_DTL_INITIALIZE_ENABLE -/* Prototypes for initializing and uninitializing the debug functions */ -void DTL_Initialize( - uint32 ui32CurrentLogLevel); -void DTL_Uninitialize(void); -#endif - -#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_LOG_ENABLE) -/* Debug trace Function protoypes */ -void DTL_Trace( - uint8 ui8LogLevel, - uint8 ui8LogType, - const int8 *pi8FileName, - const int8 *pi8FunctionName, - uint32 ui32LineNumber, - const int8 *pi8Message); - -#endif - -#if AOCL_DTL_DUMP_ENABLE -/* Function Prototype for dumping the data */ -void DTL_DumpData( - uint8 ui8LogLevel, - void *pvBuffer, - uint32 ui32BufferSize, - uint8 ui8DataType, - int8 *pi8Message, - int8 i8OutputType); -#endif - -#endif /* _AOCLDTL_H_ */ - -/* --------------- End of aocldtl.h ----------------- */ +/*=================================================================== + * File Name : aocldtl.h + * + * Description : This is main interface file for the end user + * It provides defination for all macros to be + * used by user to add debug/trace information. + * + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * + *==================================================================*/ + +#ifndef _AOCLDTL_H_ +#define _AOCLDTL_H_ + +#include "aocldtlcf.h" +#include "aocltpdef.h" +#include "aoclflist.h" +#include "aoclos.h" + +#define TRACE_TYPE_FENTRY (1) +#define TRACE_TYPE_FEXIT (2) +#define TRACE_TYPE_LOG (3) +#define TRACE_TYPE_RAW (4) + +/* Type definition for printf */ +#define AOCL_DEBUGPRINT printf + +/* Define the AOCL_DTL_INITIALIZE_ENABLE if any of the debug macro + * are defined */ +#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_DUMP_ENABLE || AOCL_DTL_LOG_ENABLE) +#define AOCL_DTL_INITIALIZE_ENABLE +#endif + +#if AOCL_DTL_TRACE_ENABLE +/* Entry macro to trace the flow of control The parameter LogLevel specifies + the log level String will preferably contains the function name in which + this macro is invoked */ +#define AOCL_DTL_TRACE_ENTRY(LogLevel) \ + DTL_Trace(LogLevel, \ + TRACE_TYPE_FENTRY, \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + NULL); +#else +/* Dummy macro definition if the AOCL_DTL_TRACE_ENABLE macro is not enabled */ +#define AOCL_DTL_TRACE_ENTRY(LogLevel) +#endif + +#if AOCL_DTL_TRACE_ENABLE +/* Exit macro to trace the flow of control The parameter LogLevel specifies + log level String will preferably contains the function name in which this + macro is invoked */ +#define AOCL_DTL_TRACE_EXIT(LogLevel) \ + DTL_Trace(LogLevel, \ + TRACE_TYPE_FEXIT, \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + NULL); + +#define AOCL_DTL_TRACE_EXIT_ERR(LogLevel, Message) \ + DTL_Trace(LogLevel, \ + TRACE_TYPE_FEXIT, \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + Message); +#else +/* Dummy macro definition if the AOCL_DTL_TRACE_ENABLE macro is not enabled */ +#define AOCL_DTL_TRACE_EXIT(LogLevel) +#define AOCL_DTL_TRACE_EXIT_ERR(LogLevel, Message) +#endif + +#if AOCL_DTL_DUMP_ENABLE +/* Macro to Dump the DATA The parameters Buffer contains the data to be + dumped BufferSize specifies the no. of bytes to be dumped DataType + specifies the data type of Buffer */ +#define AOCL_DTL_DUMP(LogLevel, Buffer, BufferSize, DataType, String, OutputType) \ + /* Call the Dump function to Dump the DATA */ \ + DTL_DumpData(LogLevel, \ + Buffer, \ + BufferSize, \ + DataType, \ + String, \ + OutputType); +#else +/* Dummy macro definition if the AOCL_DTL_DUMP_ENABLE macro is not enabled */ +#define AOCL_DTL_DUMP(Buffer, BufferSize, DataType, String, OutputType) + +#endif + +#if AOCL_DTL_LOG_ENABLE +/* Macro to log the Data */ +#define AOCL_DTL_LOG(LogLevel, Message) \ + DTL_Trace(LogLevel, \ + TRACE_TYPE_LOG, \ + __FILE__, \ + __FUNCTION__, \ + __LINE__, \ + Message); +#else +/* Dummy macro definition if the AOCL_DTL_LOG_ENABLE macro is not enabled */ +#define AOCL_DTL_LOG(LogLevel, Message) +#endif + +#if AOCL_DTL_LOG_ENABLE + +void AOCL_DTL_start_perf_timer(void); +uint64 AOCL_DTL_get_time_spent(void); + +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + */ +extern Bool gbIsLoggingEnabled; + +/* API to enable logging at runtime */ +#define AOCL_DTL_Enable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = TRUE; + +/* API to disable logging at runtime */ +#define AOCL_DTL_Disable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = FALSE; + +/* Macro to log the Data */ +#define AOCL_DTL_START_PERF_TIMER() \ + AOCL_DTL_start_perf_timer() +#else +/* Dummy macro definition if the AOCL_DTL_LOG_ENABLE macro is not enabled */ +#define AOCL_DTL_START_PERF_TIMER() +#endif + +/* Macro to initialize the prerequisite for debuging */ +#ifdef AOCL_DTL_INITIALIZE_ENABLE +#define AOCL_DTL_INITIALIZE(CURRENT_LOG_LEVEL) \ + DTL_Initialize(CURRENT_LOG_LEVEL); +#else +/* Dummy macro definition if the AOCL_DTL_INITIALIZE macro is not enabled */ +#define AOCL_DTL_INITIALIZE(CURRENT_LOG_LEVEL) +#endif + +/* Macro for uninitializing the prerequisite */ +#ifdef AOCL_DTL_INITIALIZE_ENABLE +#define AOCL_DTL_UNINITIALIZE() \ + DTL_Uninitialize(); +#else +/* Dummy macro definition if the AOCL_DTL_INITIALIZE macro is not enabled */ +#define AOCL_DTL_UNINITIALIZE() +#endif + +#ifdef AOCL_DTL_INITIALIZE_ENABLE +/* Prototypes for initializing and uninitializing the debug functions */ +void DTL_Initialize( + uint32 ui32CurrentLogLevel); +void DTL_Uninitialize(void); +#endif + +#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_LOG_ENABLE) +/* Debug trace Function protoypes */ +void DTL_Trace( + uint8 ui8LogLevel, + uint8 ui8LogType, + const int8 *pi8FileName, + const int8 *pi8FunctionName, + uint32 ui32LineNumber, + const int8 *pi8Message); + +#endif + +#if AOCL_DTL_DUMP_ENABLE +/* Function Prototype for dumping the data */ +void DTL_DumpData( + uint8 ui8LogLevel, + void *pvBuffer, + uint32 ui32BufferSize, + uint8 ui8DataType, + int8 *pi8Message, + int8 i8OutputType); +#endif + +#endif /* _AOCLDTL_H_ */ + +/* --------------- End of aocldtl.h ----------------- */ diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 1f44f54405..408f38c516 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -1,77 +1,77 @@ -/*=================================================================== - * File Name : aocldtlcf.h - * - * Description : This is configuration file for debug and trace - * libaray, all debug features (except auto trace) - * can be enabled/disabled in this file. - * - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. - * - *==================================================================*/ - -#ifndef _AOCLDTLCF_H_ -#define _AOCLDTLCF_H_ - -/* Macro for tracing the log If the user wants to enable tracing he has to - enable this macro by making it to 1 else 0 */ -#define AOCL_DTL_TRACE_ENABLE 0 - -/* Macro for dumping the log If the user wants to enable dumping he has to - enable this macro by making it to 1 else 0 */ -#define AOCL_DTL_DUMP_ENABLE 0 - -/* Macro for dumping the log If the user wants to enable input logs he has to - enable this macro by making it to 1 else 0 */ -#define AOCL_DTL_LOG_ENABLE 0 - -/* Select the trace level till which you want to log the data */ -/* By default it will log for all levels */ -#define AOCL_DTL_TRACE_LEVEL AOCL_DTL_LEVEL_TRACE_5 - -/* user has to explicitly use the below macros to identify - ciriticality of the logged message */ -#define AOCL_DTL_LEVEL_ALL (15) -#define AOCL_DTL_LEVEL_TRACE_9 (14) -#define AOCL_DTL_LEVEL_TRACE_8 (13) -#define AOCL_DTL_LEVEL_TRACE_7 (12) /* Kernels */ -#define AOCL_DTL_LEVEL_TRACE_6 (11) -#define AOCL_DTL_LEVEL_TRACE_5 (10) -#define AOCL_DTL_LEVEL_TRACE_4 (9) -#define AOCL_DTL_LEVEL_TRACE_3 (8) -#define AOCL_DTL_LEVEL_TRACE_2 (7) -#define AOCL_DTL_LEVEL_TRACE_1 (6) /* BLIS/BLAS API */ -#define AOCL_DTL_LEVEL_VERBOSE (5) -#define AOCL_DTL_LEVEL_INFO (4) -#define AOCL_DTL_LEVEL_MINOR (3) -#define AOCL_DTL_LEVEL_MAJOR (2) -#define AOCL_DTL_LEVEL_CRITICAL (1) - - -#define AOCL_DTL_TRACE_FILE "aocldtl_trace.txt" -#define AOCL_DTL_AUTO_TRACE_FILE "aocldtl_auto_trace.rawfile" -#define AOCL_DTL_LOG_FILE "aocldtl_log.txt" - -/* The use can use below three macros for different data type while dumping data - * or specify the size of data type in bytes macro for character data type */ -#define AOCL_CHAR_DATA_TYPE (1) - -/* macro for short data type */ -#define AOCL_UINT16_DATA_TYPE (2) - -/* macro for String data type */ -#define AOCL_STRING_DATA_TYPE (3) - -/* macro for uint32 data type */ -#define AOCL_UINT32_DATA_TYPE (4) - -/* macro for printing Hex values */ -#define AOCL_LOG_HEX_VALUE ('x') - -/* macro for printing Decimal values */ -#define AOCL_LOG_DECIMAL_VALUE ('d') - - - -#endif /* _AOCLDTLCF_H_ */ - -/* --------------- End of aocldtlcf.h ----------------- */ +/*=================================================================== + * File Name : aocldtlcf.h + * + * Description : This is configuration file for debug and trace + * libaray, all debug features (except auto trace) + * can be enabled/disabled in this file. + * + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + * + *==================================================================*/ + +#ifndef _AOCLDTLCF_H_ +#define _AOCLDTLCF_H_ + +/* Macro for tracing the log If the user wants to enable tracing he has to + enable this macro by making it to 1 else 0 */ +#define AOCL_DTL_TRACE_ENABLE 0 + +/* Macro for dumping the log If the user wants to enable dumping he has to + enable this macro by making it to 1 else 0 */ +#define AOCL_DTL_DUMP_ENABLE 0 + +/* Macro for dumping the log If the user wants to enable input logs he has to + enable this macro by making it to 1 else 0 */ +#define AOCL_DTL_LOG_ENABLE 0 + +/* Select the trace level till which you want to log the data */ +/* By default it will log for all levels */ +#define AOCL_DTL_TRACE_LEVEL AOCL_DTL_LEVEL_TRACE_5 + +/* user has to explicitly use the below macros to identify + ciriticality of the logged message */ +#define AOCL_DTL_LEVEL_ALL (15) +#define AOCL_DTL_LEVEL_TRACE_9 (14) +#define AOCL_DTL_LEVEL_TRACE_8 (13) +#define AOCL_DTL_LEVEL_TRACE_7 (12) /* Kernels */ +#define AOCL_DTL_LEVEL_TRACE_6 (11) +#define AOCL_DTL_LEVEL_TRACE_5 (10) +#define AOCL_DTL_LEVEL_TRACE_4 (9) +#define AOCL_DTL_LEVEL_TRACE_3 (8) +#define AOCL_DTL_LEVEL_TRACE_2 (7) +#define AOCL_DTL_LEVEL_TRACE_1 (6) /* BLIS/BLAS API */ +#define AOCL_DTL_LEVEL_VERBOSE (5) +#define AOCL_DTL_LEVEL_INFO (4) +#define AOCL_DTL_LEVEL_MINOR (3) +#define AOCL_DTL_LEVEL_MAJOR (2) +#define AOCL_DTL_LEVEL_CRITICAL (1) + + +#define AOCL_DTL_TRACE_FILE "aocldtl_trace.txt" +#define AOCL_DTL_AUTO_TRACE_FILE "aocldtl_auto_trace.rawfile" +#define AOCL_DTL_LOG_FILE "aocldtl_log.txt" + +/* The use can use below three macros for different data type while dumping data + * or specify the size of data type in bytes macro for character data type */ +#define AOCL_CHAR_DATA_TYPE (1) + +/* macro for short data type */ +#define AOCL_UINT16_DATA_TYPE (2) + +/* macro for String data type */ +#define AOCL_STRING_DATA_TYPE (3) + +/* macro for uint32 data type */ +#define AOCL_UINT32_DATA_TYPE (4) + +/* macro for printing Hex values */ +#define AOCL_LOG_HEX_VALUE ('x') + +/* macro for printing Decimal values */ +#define AOCL_LOG_DECIMAL_VALUE ('d') + + + +#endif /* _AOCLDTLCF_H_ */ + +/* --------------- End of aocldtlcf.h ----------------- */ diff --git a/aocl_dtl/aoclfal.c b/aocl_dtl/aoclfal.c index a317e69cbd..1eadf99b49 100644 --- a/aocl_dtl/aoclfal.c +++ b/aocl_dtl/aoclfal.c @@ -1,265 +1,265 @@ -/*=================================================================== - * File Name : aoclfal.c - * - * Description : Platform/os independed file handling API's - * - * Copyright (C) 2020, Advanced Micro Devices, Inc - * - *==================================================================*/ - -#include "aocltpdef.h" -#include "aocldtl.h" -#include "aoclfal.h" - - - -/* Disable instrumentation for following function, since they are called from - * Auto Generated execution trace handlers. */ - -/* The FAL function declaration */ -int32 AOCL_FAL_Close( - AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); - -int32 AOCL_FAL_Error( - AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); - -AOCL_FAL_FILE *AOCL_FAL_Open( - const int8 *pchFileName, - const int8 *pchMode) __attribute__((no_instrument_function)); - -int32 AOCL_FAL_Read( - void *pvBuffer, - int32 i32Size, - int32 i32Count, - AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); - -int32 AOCL_FAL_Write( - const void *pvBuffer, - int32 i32Size, - int32 iCount, - AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); - -/*============================================================================= -* Function Name : AOCL_FAL_Open -* Description : Used for opening a file specified by name -* Input Parameter(s) : int8 *pchFileName - Stores the file name (path) -* int8 *pchMode - Specify the mode for opening file -* Output Parameter(s) : None -* Return parameter(s) : AOCL_FAL_FILE - If the file is opened successfully -* NULL - If there is any error while opening file -*============================================================================*/ -AOCL_FAL_FILE *AOCL_FAL_Open( - const int8 *pchFileName, - const int8 *pchMode) -{ - AOCL_FAL_FILE *fpFileOpen = NULL; - /* Open the file with provided by specified path and mode in which it should - be opened. Refer to FILE I/O operation help for getting mode types */ - fpFileOpen = fopen(pchFileName, pchMode); - /* If the file is not opened then NULL value should be returned */ - if (NULL == fpFileOpen) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Cannot open file: AOCL_FAL_Open()"); - } - return fpFileOpen; -} /* end of AOCL_FAL_Open */ - -/*============================================================================= -* Function Name : AOCL_FAL_Close -* Description : Used for closing a file specified by file pointer -* Input Parameter(s) : AOCL_FAL_FILE *fpFilePointer - File pointer -* Output Parameter(s) : None -* Return parameter(s) : 0 - If the file is closed successfully -* AOCL_FAL_CLOSE_ERROR - For any error while closing file -* -*============================================================================*/ -int32 AOCL_FAL_Close( - AOCL_FAL_FILE *fpFilePointer) -{ - /* Return value for the file close */ - int32 i32RetVal; - i32RetVal = AOCL_FAL_CLOSE_ERROR; - - /* Check whether the file pointer passed is valid or not */ - if (NULL == fpFilePointer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not close file: AOCL_FAL_Close()"); - return i32RetVal; - } - - /* Close the file using the FILE pointer passed */ - i32RetVal = fclose(fpFilePointer); - - /* If the return value is non zero then it indicates an error */ - if (i32RetVal) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, - "Can't close file, Invalid file pointer passed"); - return i32RetVal; - } - - /* On successful closing of the file, function should return 0 */ - return i32RetVal; - -} /* End of AOCL_FAL_Close */ - -/*============================================================================= -* Function Name : AOCL_FAL_Read -* Description : Used for reading a file specified by file pointer. -* This function reads the specified number of bytes -* from the file into the buffer specified. The bytes -* read are returned by this function. -* Input Parameter(s) : int32 i32Size - Item size in bytes -* int32 i32Count - Maximum number of items to be read -* AOCL_FAL_FILE *fpFilePointer - File ptr to read from -* Output Parameter(s) : void *pvBuffer - Storage location of data -* Return parameter(s) : i32RetVal - Number of bytes read if successful -* AOCL_FAL_READ_ERROR - In case of error while reading -*============================================================================*/ -int32 AOCL_FAL_Read( - void *pvBuffer, - int32 i32Size, - int32 i32Count, - AOCL_FAL_FILE *fpFilePointer) -{ - /* Return value for the file read */ - int32 i32RetVal; - i32RetVal = AOCL_FAL_READ_ERROR; - - /* Check pointer used for pointing the storage location data is valid */ - if (NULL == pvBuffer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, - "Can not read the file, Buffer pointer is NULL"); - return i32RetVal; - } - - /* Check whether file pointer passed is valid */ - if (NULL == fpFilePointer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, - "Can not read the file, Buffer pointer is NULL"); - return i32RetVal; - } - - /* Read the file using file pointer */ - i32RetVal = fread(pvBuffer, i32Size, i32Count, fpFilePointer); - - if (i32RetVal != i32Count) - { - /* Check whether this is an end of file The AOCL_FAL_Error() will return - non-zero value to indicate an error */ - if (AOCL_FAL_Error(fpFilePointer)) /* AOCL_FAL_EndOfFile (fpFilePointer) */ - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, - "There is an error condition while file read"); - i32RetVal = AOCL_FAL_READ_ERROR; - } - /* This is condition where file read has encountered an end of file */ - else - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "End of file..."); - } - } - - /* The number of bytes read by the file read operation. - * This value may be less than the actual count, due to end of file - * or an error while reading the file */ - return i32RetVal; - -} /* End of AOCL_FAL_Read */ - -/*============================================================================= -* Function Name : AOCL_FAL_Write -* Description : Used for writing data to a file specified by file -* pointer. The number of bytes written to file are -* written by this function. -* Input Parameter(s) : const void *pvBuffer - Pointer to data location from -* where the data to be copied - int32 i32Size - Item size in bytes -* int32 i32Count - Maximum number of items to be -* written -* AOCL_FAL_FILE *fpFilePointer - File pointer to write to -* Output Parameter(s) : None -* Return parameter(s) : i32RetVal - Number of bytes written if successful -* AOCL_FAL_WRITE_ERROR - In case of error while writing -*============================================================================*/ -int32 AOCL_FAL_Write( - const void *pvBuffer, - int32 i32Size, - int32 iCount, - AOCL_FAL_FILE *fpFilePointer) -{ - /* Return value for write operation */ - int32 i32RetVal; - i32RetVal = AOCL_FAL_WRITE_ERROR; - /* Check pointer used for pointing the storage location data is valid */ - if (NULL == pvBuffer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not perform file write"); - return i32RetVal; - } - - /* Check whether the file pointer passed is valid or not */ - if (NULL == fpFilePointer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not perform file write"); - return i32RetVal; - } - - /* Write into the file specified by the file pointer */ - i32RetVal = fwrite(pvBuffer, i32Size, iCount, fpFilePointer); - - /* If the number of bytes written into the file are less than specified - * bytes then it is an error while file writing */ - if (i32RetVal != iCount) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "File write operation error"); - i32RetVal = AOCL_FAL_WRITE_ERROR; - } - - /* The return value of the file write operation */ - return i32RetVal; - -} /* End of AOCL_FAL_Write */ - -/*============================================================================= -* Function Name : AOCL_FAL_Error -* Description : Used for testing an error on the file specified -* Input Parameter(s) : AOCL_FAL_FILE *fpFilePointer - File pointer -* Output Parameter(s) : None -* Return parameter(s) : non-zero - Indicates an end of file -* 0 - Indicates that function is successful -* non-zero - Indicates that there is some error -* AOCL_FAL_ERROR - Indicates error during the operation -*============================================================================*/ -int32 AOCL_FAL_Error( - AOCL_FAL_FILE *fpFilePointer) -{ - /* Used for storing the return value for ferror function */ - int32 i32RetVal; - i32RetVal = AOCL_FAL_FERROR; - - /* Check whether the file pointer is NULL */ - if (NULL == fpFilePointer) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Invalid file pointer is passed"); - return i32RetVal; - } - - /* Call the ferror function to get an error on the file */ - i32RetVal = ferror(fpFilePointer); - - /* Check for the return value, it non-zero there is an error */ - if (i32RetVal) - { - AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "The file has some error"); - i32RetVal = AOCL_FAL_FERROR; - } - - /* In case of success, this function should return 0 */ - return i32RetVal; - -} /* End of AOCL_FAL_Error */ - -/* ------------------- End of aoclfal.c ----------------------- */ +/*=================================================================== + * File Name : aoclfal.c + * + * Description : Platform/os independed file handling API's + * + * Copyright (C) 2020, Advanced Micro Devices, Inc + * + *==================================================================*/ + +#include "aocltpdef.h" +#include "aocldtl.h" +#include "aoclfal.h" + + + +/* Disable instrumentation for following function, since they are called from + * Auto Generated execution trace handlers. */ + +/* The FAL function declaration */ +int32 AOCL_FAL_Close( + AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); + +int32 AOCL_FAL_Error( + AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); + +AOCL_FAL_FILE *AOCL_FAL_Open( + const int8 *pchFileName, + const int8 *pchMode) __attribute__((no_instrument_function)); + +int32 AOCL_FAL_Read( + void *pvBuffer, + int32 i32Size, + int32 i32Count, + AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); + +int32 AOCL_FAL_Write( + const void *pvBuffer, + int32 i32Size, + int32 iCount, + AOCL_FAL_FILE *fpFilePointer) __attribute__((no_instrument_function)); + +/*============================================================================= +* Function Name : AOCL_FAL_Open +* Description : Used for opening a file specified by name +* Input Parameter(s) : int8 *pchFileName - Stores the file name (path) +* int8 *pchMode - Specify the mode for opening file +* Output Parameter(s) : None +* Return parameter(s) : AOCL_FAL_FILE - If the file is opened successfully +* NULL - If there is any error while opening file +*============================================================================*/ +AOCL_FAL_FILE *AOCL_FAL_Open( + const int8 *pchFileName, + const int8 *pchMode) +{ + AOCL_FAL_FILE *fpFileOpen = NULL; + /* Open the file with provided by specified path and mode in which it should + be opened. Refer to FILE I/O operation help for getting mode types */ + fpFileOpen = fopen(pchFileName, pchMode); + /* If the file is not opened then NULL value should be returned */ + if (NULL == fpFileOpen) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Cannot open file: AOCL_FAL_Open()"); + } + return fpFileOpen; +} /* end of AOCL_FAL_Open */ + +/*============================================================================= +* Function Name : AOCL_FAL_Close +* Description : Used for closing a file specified by file pointer +* Input Parameter(s) : AOCL_FAL_FILE *fpFilePointer - File pointer +* Output Parameter(s) : None +* Return parameter(s) : 0 - If the file is closed successfully +* AOCL_FAL_CLOSE_ERROR - For any error while closing file +* +*============================================================================*/ +int32 AOCL_FAL_Close( + AOCL_FAL_FILE *fpFilePointer) +{ + /* Return value for the file close */ + int32 i32RetVal; + i32RetVal = AOCL_FAL_CLOSE_ERROR; + + /* Check whether the file pointer passed is valid or not */ + if (NULL == fpFilePointer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not close file: AOCL_FAL_Close()"); + return i32RetVal; + } + + /* Close the file using the FILE pointer passed */ + i32RetVal = fclose(fpFilePointer); + + /* If the return value is non zero then it indicates an error */ + if (i32RetVal) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, + "Can't close file, Invalid file pointer passed"); + return i32RetVal; + } + + /* On successful closing of the file, function should return 0 */ + return i32RetVal; + +} /* End of AOCL_FAL_Close */ + +/*============================================================================= +* Function Name : AOCL_FAL_Read +* Description : Used for reading a file specified by file pointer. +* This function reads the specified number of bytes +* from the file into the buffer specified. The bytes +* read are returned by this function. +* Input Parameter(s) : int32 i32Size - Item size in bytes +* int32 i32Count - Maximum number of items to be read +* AOCL_FAL_FILE *fpFilePointer - File ptr to read from +* Output Parameter(s) : void *pvBuffer - Storage location of data +* Return parameter(s) : i32RetVal - Number of bytes read if successful +* AOCL_FAL_READ_ERROR - In case of error while reading +*============================================================================*/ +int32 AOCL_FAL_Read( + void *pvBuffer, + int32 i32Size, + int32 i32Count, + AOCL_FAL_FILE *fpFilePointer) +{ + /* Return value for the file read */ + int32 i32RetVal; + i32RetVal = AOCL_FAL_READ_ERROR; + + /* Check pointer used for pointing the storage location data is valid */ + if (NULL == pvBuffer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, + "Can not read the file, Buffer pointer is NULL"); + return i32RetVal; + } + + /* Check whether file pointer passed is valid */ + if (NULL == fpFilePointer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, + "Can not read the file, Buffer pointer is NULL"); + return i32RetVal; + } + + /* Read the file using file pointer */ + i32RetVal = fread(pvBuffer, i32Size, i32Count, fpFilePointer); + + if (i32RetVal != i32Count) + { + /* Check whether this is an end of file The AOCL_FAL_Error() will return + non-zero value to indicate an error */ + if (AOCL_FAL_Error(fpFilePointer)) /* AOCL_FAL_EndOfFile (fpFilePointer) */ + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, + "There is an error condition while file read"); + i32RetVal = AOCL_FAL_READ_ERROR; + } + /* This is condition where file read has encountered an end of file */ + else + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "End of file..."); + } + } + + /* The number of bytes read by the file read operation. + * This value may be less than the actual count, due to end of file + * or an error while reading the file */ + return i32RetVal; + +} /* End of AOCL_FAL_Read */ + +/*============================================================================= +* Function Name : AOCL_FAL_Write +* Description : Used for writing data to a file specified by file +* pointer. The number of bytes written to file are +* written by this function. +* Input Parameter(s) : const void *pvBuffer - Pointer to data location from +* where the data to be copied + int32 i32Size - Item size in bytes +* int32 i32Count - Maximum number of items to be +* written +* AOCL_FAL_FILE *fpFilePointer - File pointer to write to +* Output Parameter(s) : None +* Return parameter(s) : i32RetVal - Number of bytes written if successful +* AOCL_FAL_WRITE_ERROR - In case of error while writing +*============================================================================*/ +int32 AOCL_FAL_Write( + const void *pvBuffer, + int32 i32Size, + int32 iCount, + AOCL_FAL_FILE *fpFilePointer) +{ + /* Return value for write operation */ + int32 i32RetVal; + i32RetVal = AOCL_FAL_WRITE_ERROR; + /* Check pointer used for pointing the storage location data is valid */ + if (NULL == pvBuffer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not perform file write"); + return i32RetVal; + } + + /* Check whether the file pointer passed is valid or not */ + if (NULL == fpFilePointer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Can not perform file write"); + return i32RetVal; + } + + /* Write into the file specified by the file pointer */ + i32RetVal = fwrite(pvBuffer, i32Size, iCount, fpFilePointer); + + /* If the number of bytes written into the file are less than specified + * bytes then it is an error while file writing */ + if (i32RetVal != iCount) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "File write operation error"); + i32RetVal = AOCL_FAL_WRITE_ERROR; + } + + /* The return value of the file write operation */ + return i32RetVal; + +} /* End of AOCL_FAL_Write */ + +/*============================================================================= +* Function Name : AOCL_FAL_Error +* Description : Used for testing an error on the file specified +* Input Parameter(s) : AOCL_FAL_FILE *fpFilePointer - File pointer +* Output Parameter(s) : None +* Return parameter(s) : non-zero - Indicates an end of file +* 0 - Indicates that function is successful +* non-zero - Indicates that there is some error +* AOCL_FAL_ERROR - Indicates error during the operation +*============================================================================*/ +int32 AOCL_FAL_Error( + AOCL_FAL_FILE *fpFilePointer) +{ + /* Used for storing the return value for ferror function */ + int32 i32RetVal; + i32RetVal = AOCL_FAL_FERROR; + + /* Check whether the file pointer is NULL */ + if (NULL == fpFilePointer) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "Invalid file pointer is passed"); + return i32RetVal; + } + + /* Call the ferror function to get an error on the file */ + i32RetVal = ferror(fpFilePointer); + + /* Check for the return value, it non-zero there is an error */ + if (i32RetVal) + { + AOCL_DTL_LOG(AOCL_DTL_LEVEL_MAJOR, "The file has some error"); + i32RetVal = AOCL_FAL_FERROR; + } + + /* In case of success, this function should return 0 */ + return i32RetVal; + +} /* End of AOCL_FAL_Error */ + +/* ------------------- End of aoclfal.c ----------------------- */ diff --git a/aocl_dtl/aoclfal.h b/aocl_dtl/aoclfal.h index 9b8074528d..401ed4c355 100644 --- a/aocl_dtl/aoclfal.h +++ b/aocl_dtl/aoclfal.h @@ -1,50 +1,50 @@ -/*=================================================================== - * File Name : aoclfal.h - * - * Description : Interfaces for platform/os independed file - * handling API's - * - * Copyright (C) 2020, Advanced Micro Devices, Inc - * - *==================================================================*/ - -#ifndef _AOCL_FAL_H_ -#define _AOCL_FAL_H_ - -/* The possible error values of FAL */ -#define AOCL_FAL_SUCCESS 0 -#define AOCL_FAL_CLOSE_ERROR -1 -#define AOCL_FAL_READ_ERROR -2 -#define AOCL_FAL_WRITE_ERROR -3 -#define AOCL_FAL_EOF_ERROR -6 -#define AOCL_FAL_FERROR -7 - -/* The type definition for FILE */ -#define AOCL_FAL_FILE FILE - -/* The FAL function declaration */ -int32 AOCL_FAL_Close( - AOCL_FAL_FILE *fpFilePointer); - -int32 AOCL_FAL_Error( - AOCL_FAL_FILE *fpFilePointer); - -AOCL_FAL_FILE *AOCL_FAL_Open( - const int8 *pchFileName, - const int8 *pchMode); - -int32 AOCL_FAL_Read( - void *pvBuffer, - int32 i32Size, - int32 i32Count, - AOCL_FAL_FILE *fpFilePointer); - -int32 AOCL_FAL_Write( - const void *pvBuffer, - int32 i32Size, - int32 iCount, - AOCL_FAL_FILE *fpFilePointer); - -#endif /* _AOCL_FAL_H_ */ - -/* --------------- End of aoclfal.h ----------------- */ +/*=================================================================== + * File Name : aoclfal.h + * + * Description : Interfaces for platform/os independed file + * handling API's + * + * Copyright (C) 2020, Advanced Micro Devices, Inc + * + *==================================================================*/ + +#ifndef _AOCL_FAL_H_ +#define _AOCL_FAL_H_ + +/* The possible error values of FAL */ +#define AOCL_FAL_SUCCESS 0 +#define AOCL_FAL_CLOSE_ERROR -1 +#define AOCL_FAL_READ_ERROR -2 +#define AOCL_FAL_WRITE_ERROR -3 +#define AOCL_FAL_EOF_ERROR -6 +#define AOCL_FAL_FERROR -7 + +/* The type definition for FILE */ +#define AOCL_FAL_FILE FILE + +/* The FAL function declaration */ +int32 AOCL_FAL_Close( + AOCL_FAL_FILE *fpFilePointer); + +int32 AOCL_FAL_Error( + AOCL_FAL_FILE *fpFilePointer); + +AOCL_FAL_FILE *AOCL_FAL_Open( + const int8 *pchFileName, + const int8 *pchMode); + +int32 AOCL_FAL_Read( + void *pvBuffer, + int32 i32Size, + int32 i32Count, + AOCL_FAL_FILE *fpFilePointer); + +int32 AOCL_FAL_Write( + const void *pvBuffer, + int32 i32Size, + int32 iCount, + AOCL_FAL_FILE *fpFilePointer); + +#endif /* _AOCL_FAL_H_ */ + +/* --------------- End of aoclfal.h ----------------- */ diff --git a/aocl_dtl/aocltpdef.h b/aocl_dtl/aocltpdef.h index 7c08455369..d842fffbac 100644 --- a/aocl_dtl/aocltpdef.h +++ b/aocl_dtl/aocltpdef.h @@ -1,42 +1,42 @@ - -/*=================================================================== - * File Name : aocltpdef.h - * - * Description : Abstraction for various datatypes used by DTL. - * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. - * - *==================================================================*/ -#ifndef AOCL_TYPEDEF_H_ -#define AOCL_TYPEDEF_H_ - -#include -#include -#include -#include -#include -#ifndef _WIN32 -#include -#else -typedef int pid_t; -#endif - -typedef double Double; -typedef float Float; -typedef void Void; -typedef unsigned char uint8; -typedef unsigned short int uint16; -typedef unsigned int uint32; -typedef unsigned long uint64; -typedef uint8 *STRING; -typedef unsigned char Bool; -typedef char int8; -typedef signed long int int32; -typedef short int int16; - -typedef Void *AOCL_HANDLE; -typedef pid_t AOCL_TID; - -#endif /*AOCL_TYPEDEF_H_ */ - -/* --------------- End of aocltpdef.h ----------------- */ + +/*=================================================================== + * File Name : aocltpdef.h + * + * Description : Abstraction for various datatypes used by DTL. + * + * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * + *==================================================================*/ +#ifndef AOCL_TYPEDEF_H_ +#define AOCL_TYPEDEF_H_ + +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#else +typedef int pid_t; +#endif + +typedef double Double; +typedef float Float; +typedef void Void; +typedef unsigned char uint8; +typedef unsigned short int uint16; +typedef unsigned int uint32; +typedef unsigned long uint64; +typedef uint8 *STRING; +typedef unsigned char Bool; +typedef char int8; +typedef signed long int int32; +typedef short int int16; + +typedef Void *AOCL_HANDLE; +typedef pid_t AOCL_TID; + +#endif /*AOCL_TYPEDEF_H_ */ + +/* --------------- End of aocltpdef.h ----------------- */ diff --git a/aocl_dtl/test_dtl.c b/aocl_dtl/test_dtl.c index 978f4ac44b..08ff3296c3 100644 --- a/aocl_dtl/test_dtl.c +++ b/aocl_dtl/test_dtl.c @@ -1,96 +1,96 @@ -/*=================================================================== - * File Name : test_dtl.c - * - * Description : Unit test cases for dtl. - * - * Copyright (C) 2020, Advanced Micro Devices, Inc - * - *==================================================================*/ - -#if 0 // Disable this for normal build. - -#include "aocltpdef.h" -#include "aocldtl.h" - -int aocl_allocate(double**A, double** B, double** C, int N) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - *A = (double*)malloc(sizeof(double) * N); - if (*A == NULL) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to A"); - return 1; - } - - *B = (double*)malloc(sizeof(double) * N); - if (*B == NULL) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to B"); - return 1; - } - - *C = (double*)malloc(sizeof(double) * N); - if (*C == NULL) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to C"); - return 1; - } - - for (int i = 0; i < N; i++) - { - (*A)[i] = (double)((i + 1) * 1.0); - (*B)[i] = (double)((i - 1) * 1.0); - (*C)[i] = (double)((i) * 1.0); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO, " aocl_allocate()"); - return 0; -} - -void sumV(double* A, double* B, double* C, int N) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - if ((A == NULL) || (B == NULL) || (C == NULL)) - { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Invalid Pointers"); - return; - } - for (int i = 0; i < N; i++) - { - C[i] += A[i] + B[i]; - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); -} - -int main(void) -{ - int status = 0; - double* A = NULL; - double* B = NULL; - double* C = NULL; - - printf("Initializing\n"); - AOCL_DTL_INITIALIZE(AOCL_DTL_LEVEL_ALL); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - status = aocl_allocate(&A, &B, &C, 120); - if (status != 0) - { - printf("Error allocating memory\n"); - - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_CRITICAL, "Error in function aocl_allocate()"); - AOCL_DTL_UNINITIALIZE(); - exit(1); - } - - sumV(A, B, C, 120); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - AOCL_DTL_UNINITIALIZE(); - - return 0; -} -#endif +/*=================================================================== + * File Name : test_dtl.c + * + * Description : Unit test cases for dtl. + * + * Copyright (C) 2020, Advanced Micro Devices, Inc + * + *==================================================================*/ + +#if 0 // Disable this for normal build. + +#include "aocltpdef.h" +#include "aocldtl.h" + +int aocl_allocate(double**A, double** B, double** C, int N) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + *A = (double*)malloc(sizeof(double) * N); + if (*A == NULL) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to A"); + return 1; + } + + *B = (double*)malloc(sizeof(double) * N); + if (*B == NULL) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to B"); + return 1; + } + + *C = (double*)malloc(sizeof(double) * N); + if (*C == NULL) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Error allocating memory to C"); + return 1; + } + + for (int i = 0; i < N; i++) + { + (*A)[i] = (double)((i + 1) * 1.0); + (*B)[i] = (double)((i - 1) * 1.0); + (*C)[i] = (double)((i) * 1.0); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO, " aocl_allocate()"); + return 0; +} + +void sumV(double* A, double* B, double* C, int N) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if ((A == NULL) || (B == NULL) || (C == NULL)) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_MAJOR, "Invalid Pointers"); + return; + } + for (int i = 0; i < N; i++) + { + C[i] += A[i] + B[i]; + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); +} + +int main(void) +{ + int status = 0; + double* A = NULL; + double* B = NULL; + double* C = NULL; + + printf("Initializing\n"); + AOCL_DTL_INITIALIZE(AOCL_DTL_LEVEL_ALL); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + status = aocl_allocate(&A, &B, &C, 120); + if (status != 0) + { + printf("Error allocating memory\n"); + + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_CRITICAL, "Error in function aocl_allocate()"); + AOCL_DTL_UNINITIALIZE(); + exit(1); + } + + sumV(A, B, C, 120); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_UNINITIALIZE(); + + return 0; +} +#endif diff --git a/bench/Makefile b/bench/Makefile index 0203d5a5b0..751f7129a5 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -303,4 +303,4 @@ endif clean: cleanx cleanx: - - $(RM_F) *.o *.x \ No newline at end of file + - $(RM_F) *.o *.x diff --git a/bench/bench_amaxv.c b/bench/bench_amaxv.c index 2a0e578975..eb37319b6f 100644 --- a/bench/bench_amaxv.c +++ b/bench/bench_amaxv.c @@ -247,4 +247,4 @@ int main( int argc, char** argv ) fclose(fout); return 0; -} \ No newline at end of file +} diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile index 91b3a7b587..897a982ba3 100755 --- a/bench/bench_aocl_gemm/Makefile +++ b/bench/bench_aocl_gemm/Makefile @@ -106,7 +106,8 @@ CFLAGS += -I$(TEST_SRC_PATH) all: blis blis: \ - bench_lpgemm_blis.x + bench_lpgemm_blis.x \ + bench_lpgemm_utils_blis.x # --Object file rules -- diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index d8b8226a13..9034a0d550 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,3 +1,199 @@ +u r p 480 20 2050 2050 20 20 +u r p 481 20 2050 2050 20 20 +u r p 482 20 2050 2050 20 20 +u r p 483 20 2050 2050 20 20 +u r R 484 20 2050 2050 20 20 +u r R 485 20 2050 2050 20 20 +u r R 480 39 2050 2050 39 39 +u r R 481 39 2050 2050 39 39 +u r R 482 39 2050 2050 39 39 +u r R 483 39 2050 2050 39 39 +u r R 484 39 2050 2050 39 39 +u r p 485 39 2050 2050 39 39 +u r p 480 50 2050 2050 50 50 +u r p 481 50 2050 2050 50 50 +u r p 482 50 2050 2050 50 50 +u r p 483 50 2050 2050 50 50 +u r p 484 50 2050 2050 50 50 +u r p 485 50 2050 2050 50 50 +u r R 480 1108 2050 2050 1108 1108 +u r R 481 1108 2050 2050 1108 1108 +u r R 482 1108 2050 2050 1108 1108 +u r R 483 1108 2050 2050 1108 1108 +u r R 484 1108 2050 2050 1108 1108 +u r R 485 1108 2050 2050 1108 1108 +u r R 480 1127 2050 2050 1127 1127 +u r R 481 1127 2050 2050 1127 1127 +u r R 482 1127 2050 2050 1127 1127 +u r R 483 1127 2050 2050 1127 1127 +u r p 484 1127 2050 2050 1127 1127 +u r p 485 1127 2050 2050 1127 1127 +u r p 480 1138 2050 2050 1138 1138 +u r p 481 1138 2050 2050 1138 1138 +u r p 482 1138 2050 2050 1138 1138 +u r p 483 1138 2050 2050 1138 1138 +u r p 484 1138 2050 2050 1138 1138 +u r p 485 1138 2050 2050 1138 1138 +u r p 1 1 3 3 1 1 +u r p 1 9 3 3 9 9 +u r p 1 2048 3 3 2048 2048 +u r p 1 2048 5192 5192 2048 2048 +u r p 9 1 3 3 1 1 +u r p 576 1 3500 3500 1 1 +u r p 1 1 1 1 1 1 +u r p 102 1088 1024 1024 1088 1088 +u r p 102 2048 1024 1024 2048 2048 +u r p 485 656 1024 1024 656 656 +u r p 483 656 1024 1024 656 656 +u r p 81 128 3 3 128 128 +u r p 1022 512 515 515 512 512 +u r p 74 512 515 515 512 512 +u r p 253 2048 515 515 2048 2048 +u r p 8192 1040 515 515 1040 1040 +u r p 10 1029 515 515 1029 1029 +u r p 24 1040 2050 2050 1040 1040 +u r p 1024 1029 2050 2050 1029 1029 +u r p 480 660 2050 2050 660 660 +u r p 481 660 2050 2050 660 660 +u r p 482 660 2050 2050 660 660 +u r p 483 660 2050 2050 660 660 +u r p 484 660 2050 2050 660 660 +u r p 485 660 2050 2050 660 660 +u r p 480 679 2050 2050 679 679 +u r p 481 679 2050 2050 679 679 +u r p 482 679 2050 2050 679 679 +u r p 483 679 2050 2050 679 679 +u r p 484 679 2050 2050 679 679 +u r p 485 679 2050 2050 679 679 +u r p 480 690 2050 2050 690 690 +u r p 481 690 2050 2050 690 690 +u r p 482 690 2050 2050 690 690 +u r p 483 690 2050 2050 690 690 +u r p 484 690 2050 2050 690 690 +u r p 485 690 2050 2050 690 690 +u r p 480 660 2048 2048 660 660 +u r p 481 660 2048 2048 660 660 +u r p 482 660 2048 2048 660 660 +u r p 483 660 2048 2048 660 660 +u r p 484 660 2048 2048 660 660 +u r p 485 660 2048 2048 660 660 +u r p 480 679 2048 2048 679 679 +u r p 481 679 2048 2048 679 679 +u r p 482 679 2048 2048 679 679 +u r p 483 679 2048 2048 679 679 +u r p 484 679 2048 2048 679 679 +u r p 485 679 2048 2048 679 679 +u r p 480 690 2048 2048 690 690 +u r p 481 690 2048 2048 690 690 +u r p 482 690 2048 2048 690 690 +u r p 483 690 2048 2048 690 690 +u r p 484 690 2048 2048 690 690 +u r p 485 690 2048 2048 690 690 +u r p 480 656 1024 1024 656 656 +u r p 480 128 3 3 128 128 +u r p 1024 512 515 515 512 512 +u r p 1024 2048 1024 1024 2048 2048 +u r p 1024 2048 515 515 2048 2048 +u r p 1024 1040 515 515 1040 1040 +u r p 5 1029 515 515 1029 1029 +u r p 1024 1029 515 515 1029 1029 +u r p 1024 1040 2050 2050 1040 1040 +u r p 1029 1029 2050 2050 1029 1029 +u r R 480 646 2050 2050 646 646 +u r R 481 646 2050 2050 646 646 +u r R 482 646 2050 2050 646 646 +u r R 483 646 2050 2050 646 646 +u r R 484 646 2050 2050 646 646 +u r R 485 646 2050 2050 646 646 +u r R 481 656 2050 2050 656 656 +u r R 482 656 2050 2050 656 656 +u r R 483 656 2050 2050 656 656 +u r R 484 656 2050 2050 656 656 +u r p 485 656 2050 2050 656 656 +u r p 480 672 2050 2050 672 672 +u r p 481 672 2050 2050 672 672 +u r p 482 672 2050 2050 672 672 +u r p 483 672 2050 2050 672 672 +u r p 484 672 2050 2050 672 672 +u r p 485 672 2050 2050 672 672 +u r p 480 688 2050 2050 688 688 +u r p 481 688 2050 2050 688 688 +u r r 482 688 2050 2050 688 688 +u r r 483 688 2050 2050 688 688 +u r r 484 688 2050 2050 688 688 +u r r 485 688 2050 2050 688 688 +u r r 1024 512 64 64 512 512 +u r r 16 256 512 512 256 256 +u r r 480 640 512 512 640 640 +u r r 64 768 512 512 768 768 +u r r 128 128 128 128 128 128 +u r r 1024 64 512 512 64 64 +u r r 1024 256 32 32 256 256 +u r r 1024 512 64 64 512 512 +u r r 480 640 512 512 640 640 +u r p 1024 32 256 256 32 32 +u r P 1024 64 512 512 64 64 +u r P 64 800 320 320 800 800 +u r P 64 768 512 512 768 768 +u r P 16 256 512 512 256 256 +u r P 128 128 128 128 128 128 +u r P 256 512 256 256 512 512 +u r P 1024 1024 1024 1024 1024 1024 +u r P 480 640 1024 1024 640 640 +u r P 480 640 256 256 640 640 +u r P 8 64 32 32 64 64 +u r P 9 64 32 32 64 64 +u r P 10 128 64 64 128 128 +u r P 8 8 8 8 8 8 +u r P 12 12 12 12 12 12 +u r P 25 25 25 25 25 25 +u r P 25 25 20 20 25 25 +u r r 4096 256 5 5 256 256 +u r r 3000 256 128 128 256 256 +u r r 4096 1024 512 512 1024 1024 +u r r 144 256 5 5 256 256 +u r r 144 256 128 128 256 256 +u r r 144 1024 512 512 1024 1024 +u r r 480 688 256 256 688 688 +u r r 480 640 512 512 640 640 +u r r 480 640 1024 1024 640 640 +u r r 64 800 320 320 800 800 +u r r 64 768 512 512 768 768 +u r r 16 256 512 512 256 256 +u r r 128 128 128 128 128 128 +u r r 256 512 256 256 512 512 +u r r 1024 1024 1024 1024 1024 1024 +u r r 1024 32 256 256 32 32 +u r r 1024 64 512 512 64 64 +u r r 1024 256 32 32 256 256 +u r r 1024 512 64 64 512 512 +u r r 512 32 256 256 32 32 +u r r 512 768 512 512 768 768 +u r r 512 256 32 32 256 256 +u r r 512 512 64 64 512 512 +u r r 512 256 768 768 256 256 +u r r 768 768 1024 1024 768 768 +u r r 768 768 768 768 768 768 +u r r 2048 2048 2048 2048 2048 2048 +u r r 4096 4096 4096 4096 4096 4096 +f c p 2482 1127 2050 2482 2050 2482 +f c p 2483 1127 2050 2483 2050 2483 +f c p 2484 1127 2050 2484 2050 2484 +f c p 2485 1127 2050 2485 2050 2485 +f c p 480 1138 2050 480 2050 480 +f c p 481 1138 2050 481 2050 481 +f c p 482 1138 2050 482 2050 482 +f c p 483 1138 2050 483 2050 483 +f c p 484 1138 2050 484 2050 484 +f c p 485 1138 2050 485 2050 485 +f c p 1 1 3 1 3 1 +f c p 1 9 3 1 3 1 +f c p 1 2048 3 1 3 1 +f c p 1 2048 5192 1 5192 1 +f c p 9 1 3 9 3 9 +f c p 576 1 3500 576 3500 576 +f c p 1 1 1 1 1 1 +f c p 102 1088 1024 102 1024 102 b r r 480 20 2050 2050 20 20 b r r 481 20 2050 2050 20 20 b r r 482 20 2050 2050 20 20 diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 92b7a7a1a6..7dd049b159 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -1,3 +1,37 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + #include #include #include @@ -5,6 +39,7 @@ #include #include #include +#include #include "blis.h" @@ -21,11 +56,39 @@ int32_t global_n_repeat = 0; char global_dscale_out = 'n'; +dim_t num_eltwise = 0; // To keep track of eltwise operations. + #define _XSTR(str) #str #define XSTR(str) _XSTR(str) #define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype +inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +{ + /*Set offset 2 to copy most significant 2 bytes of float + to convert float values to bf16 values*/ + memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); +} + +inline float bf16_to_float + ( + bfloat16 bf16_val + ) +{ + int32_t inter_temp = *( ( int16_t* ) &bf16_val ); + inter_temp = inter_temp << 16; + float float_value = *( float* ) ( &inter_temp ); + return float_value; +} + +inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) +{ + for (int i=0; i< size; i++) + { + float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + } +} + #define GEN_FILL_ARRAY_FUNC(ctype) \ void fill_array_ ## ctype ( void* arr, dim_t size ) \ { \ @@ -38,21 +101,21 @@ void fill_array_ ## ctype ( void* arr, dim_t size ) \ GEN_FILL_ARRAY_FUNC(uint8_t) GEN_FILL_ARRAY_FUNC(int8_t) +GEN_FILL_ARRAY_FUNC(int16_t) GEN_FILL_ARRAY_FUNC(float) GEN_FILL_ARRAY_FUNC(int32_t) -inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +void fill_array_bfloat16( void* arr, dim_t size ) { - /*Set offset 2 to copy most significant 2 bytes of float - to convert float values to bf16 values*/ - memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); -} - -inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) -{ - for (int i=0; i< size; i++) + float* c_float = ( float* ) bli_malloc_user( sizeof( float ) * size ); + for ( dim_t i = 0; i < size; ++i ) { - float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + c_float[i] = 2.0; + } + convert_float_arr_to_bf16( c_float, arr, size ); + if ( c_float != NULL ) + { + bli_free_user( c_float ); } } @@ -178,6 +241,10 @@ GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32) +GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) +GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) double get_gflops ( @@ -234,7 +301,7 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ { \ if ( bench_mode == 'a' ) \ { \ - memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ } \ \ struct timespec tstart={0,0}, tend={0,0}; \ @@ -269,6 +336,10 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32) +GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) int max (int a, int b) { @@ -280,33 +351,32 @@ int min (int a, int b) return ( a < b ? a : b ); } -#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ -inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ +#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ +inline ACCUM_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ (\ ACCUM_type temp_accum,\ - C_type out_temp_accum, \ aocl_post_op* post_op, \ dim_t j \ )\ {\ - out_temp_accum = ( C_type ) min ( max ( nearbyintf( ( SCALE_type )temp_accum * \ + ACCUM_type out_temp_accum = ( ACCUM_type ) min ( max ( nearbyintf( ( SCALE_type )temp_accum * \ ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ), S8_MIN ), S8_MAX ) ; \ return out_temp_accum; \ }\ -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8) -GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int16_t,float,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int32_t,float,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int32_t,float,s8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int16_t,float,s8s8s16os8) -inline bfloat16 mat_mul_accuracy_check_downscale_bf16bf16f32obf16 +inline float mat_mul_accuracy_check_downscale_bf16bf16f32obf16 ( - float temp_accum, - bfloat16 out_temp_accum, - aocl_post_op* post_op, + float temp_accum, + aocl_post_op* post_op, dim_t j ) { - float_to_bf16( ( &temp_accum ), ( &out_temp_accum ) ); - return out_temp_accum; + return temp_accum; } #define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ @@ -345,77 +415,167 @@ GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32) +GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int32_t,s8s8s32os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int32_t,int32_t,s8s8s32os32) +GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int8_t,int16_t,s8s8s16os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(int8_t,int8_t,int16_t,int16_t,s8s8s16os16) -inline float bf16_to_float - ( - bfloat16 bf16_val - ) -{ - int32_t inter_temp = *( ( int16_t* ) &bf16_val ); - inter_temp = inter_temp << 16; - float float_value = *( float* ) ( &inter_temp ); - return float_value; -} - inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 ( - bfloat16* a, - bfloat16* b, - float* c_ref, + bfloat16* a, + bfloat16* b, + float* c_ref, float temp_accum, - float alpha, - float beta, - dim_t rs_a, + float alpha, + float beta, + dim_t rs_a, dim_t rs_b, - dim_t cs_a, - dim_t cs_b, + dim_t cs_a, + dim_t cs_b, dim_t rs_c_ref, - dim_t cs_c_ref, - dim_t i, - dim_t j, - dim_t k + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k ) { - for ( dim_t p = 0; p < k; ++p) - { - float a_float = bf16_to_float( *( a + i * rs_a + p * cs_a ) ); - float b_float = bf16_to_float( *( b + p * rs_b + j * cs_b ) ); - temp_accum += ( ( a_float ) * ( b_float ) ); - } - temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) - + ( alpha * temp_accum ); - return temp_accum; + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i * rs_a + p * cs_a ) ); + float b_float = bf16_to_float( *( b + p * rs_b + j * cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) + + ( alpha * temp_accum ); + return temp_accum; } inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 ( - bfloat16* a, - bfloat16* b, - bfloat16* c_ref, + bfloat16* a, + bfloat16* b, + bfloat16* c_ref, float temp_accum, - float alpha, - float beta, - dim_t rs_a, + float alpha, + float beta, + dim_t rs_a, dim_t rs_b, - dim_t cs_a, - dim_t cs_b, + dim_t cs_a, + dim_t cs_b, dim_t rs_c_ref, - dim_t cs_c_ref, - dim_t i, - dim_t j, - dim_t k + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k ) { - for ( dim_t p = 0; p < k; ++p) - { - float a_float = bf16_to_float( *( a + i*rs_a + p*cs_a ) ); - float b_float = bf16_to_float( *( b + p*rs_b + j*cs_b ) ); - temp_accum += ( ( a_float ) * ( b_float ) ); - } - float c_ref_float = bf16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ) ); - temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); - - return temp_accum; + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i*rs_a + p*cs_a ) ); + float b_float = bf16_to_float( *( b + p*rs_b + j*cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + float c_ref_float = bf16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ) ); + temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); + + return temp_accum; +} + +#define GEN_GELU_TANH_POSTOP_INT(ACCUM_type,BLAS_SFX) \ +inline ACCUM_type GELU_TANH_post_op_ ## BLAS_SFX \ + (\ + ACCUM_type temp_accum \ + )\ +{\ + float gelu_reference = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ +}\ + +GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os8) +GEN_GELU_TANH_POSTOP_INT(int16_t,u8s8s16os16) +GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os8) +GEN_GELU_TANH_POSTOP_INT(int32_t,u8s8s32os32) +GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os8) +GEN_GELU_TANH_POSTOP_INT(int32_t,s8s8s32os32) +GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os8) +GEN_GELU_TANH_POSTOP_INT(int16_t,s8s8s16os16) + +#define GEN_GELU_TANH_POSTOP_FLOAT(BLAS_SFX) \ +inline float GELU_TANH_post_op_ ## BLAS_SFX \ + (\ + float temp_accum \ + )\ +{\ + temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); \ + return temp_accum; \ +}\ + +GEN_GELU_TANH_POSTOP_FLOAT(f32f32f32of32) +GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32of32) +GEN_GELU_TANH_POSTOP_FLOAT(bf16bf16f32obf16) + +#define GEN_GELU_ERF_POSTOP_INT(ACCUM_type,BLAS_SFX) \ +inline ACCUM_type GELU_ERF_post_op_ ## BLAS_SFX \ + (\ + ACCUM_type temp_accum \ + )\ +{\ + float gelu_reference = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + temp_accum = round (gelu_reference); \ + return temp_accum; \ +}\ + +GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os8) +GEN_GELU_ERF_POSTOP_INT(int16_t,u8s8s16os16) +GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os8) +GEN_GELU_ERF_POSTOP_INT(int32_t,u8s8s32os32) +GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os8) +GEN_GELU_ERF_POSTOP_INT(int32_t,s8s8s32os32) +GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os8) +GEN_GELU_ERF_POSTOP_INT(int16_t,s8s8s16os16) + +#define GEN_GELU_ERF_POSTOP_FLOAT(BLAS_SFX) \ +inline float GELU_ERF_post_op_ ## BLAS_SFX \ + (\ + float temp_accum \ + )\ +{\ + temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); \ + return temp_accum; \ +}\ + +GEN_GELU_ERF_POSTOP_FLOAT(f32f32f32of32) +GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32of32) +GEN_GELU_ERF_POSTOP_FLOAT(bf16bf16f32obf16) + +#define GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(C_type, ACCUM_type) \ +void mat_mul_get_output_type_val ## ACCUM_type ## C_type \ + ( \ + C_type* out_temp_accum, \ + ACCUM_type* temp_accum \ + ) \ +{ \ + ( *out_temp_accum ) = ( C_type )( *temp_accum ); \ +} \ + +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int32_t,int32_t) +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int32_t) +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int16_t,int16_t) +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(int8_t,int16_t) +GEN_MAT_MUL_GET_OUTPUT_TYPE_VALUE(float,float) + +void mat_mul_get_output_type_valfloatbfloat16 + ( + bfloat16* out_temp_accum, + float* temp_accum + ) +{ + float_to_bf16( temp_accum, out_temp_accum ); } #define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \ @@ -472,59 +632,76 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ \ if ( post_op != NULL ) \ { \ - /* Apply bias followed by relu. */ \ - if ( post_op->seq_vector[0] == BIAS ) \ + dim_t ele_i = 0; \ + for ( dim_t op_id = 0; op_id < post_op->seq_length; ++op_id ) \ { \ - if ( post_op->seq_length >= 1 ) \ + if ( post_op->seq_vector[op_id] == BIAS ) \ { \ temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ } \ - if ( ( post_op->seq_length > 1 ) && \ - ( post_op->seq_vector[1] == ELTWISE ) ) \ + else if ( post_op->seq_vector[op_id] == ELTWISE ) \ { \ - if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + PRELU ) /* PReLU*/ \ { \ temp_accum = ( temp_accum > 0 ) ? \ temp_accum : \ ( temp_accum * \ - *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ + *( ( ACCUM_type* ) ( post_op->eltwise + ele_i )->algo.alpha ) ); \ + ele_i += 1; \ } \ - else \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_TANH ) /* TANH GeLU*/ \ { \ - temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + temp_accum = GEN_FUNC_NAME(GELU_TANH_post_op_,BLAS_SFX) (temp_accum);\ + ele_i += 1; \ } \ - } \ - } \ - else if ( post_op->seq_vector[0] == ELTWISE ) \ - { \ - if ( post_op->seq_length >= 1 ) \ - { \ - if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + GELU_ERF ) /* ERF GeLU*/ \ { \ - temp_accum = ( temp_accum > 0 ) ? \ - temp_accum : \ - ( temp_accum * *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ + temp_accum = GEN_FUNC_NAME(GELU_ERF_post_op_,BLAS_SFX) (temp_accum);\ + ele_i += 1; \ } \ - else \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + RELU ) /* ReLU*/ \ { \ temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + ele_i += 1; \ + } \ + else if ( ( post_op->eltwise + ele_i )->algo.algo_type == \ + CLIP ) /* CLIP*/ \ + { \ + temp_accum = \ + min \ + ( \ + max \ + ( \ + temp_accum, \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.alpha ) \ + ), \ + *( ( ACCUM_type* ) \ + ( post_op->eltwise + ele_i )->algo.beta) \ + ); \ + ele_i += 1; \ } \ + else \ + {} \ } \ - if ( ( post_op->seq_length > 1 ) && ( post_op->seq_vector[1] == BIAS ) ) \ + else if ( post_op->seq_vector[op_id] == SCALE ) \ { \ - temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ + temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ + (temp_accum, post_op, j); \ } \ + else \ + {} \ } \ } \ - if ( global_dscale_out == 'y' ) \ - { \ - out_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ - (temp_accum, out_temp_accum, post_op, j); \ - } \ - else \ - { \ - out_temp_accum = ( C_type )temp_accum; \ - } \ + /* Need to convert to downscaled type if required.*/ \ + mat_mul_get_output_type_val ## ACCUM_type ## C_type \ + ( \ + &out_temp_accum, &temp_accum \ + ); \ \ if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ { \ @@ -535,7 +712,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ fflush( fout ); \ } \ - printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k ); \ + printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ goto cleanup_acc; \ } \ } \ @@ -550,7 +727,11 @@ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8 GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int32_t,int32_t,float,s8s8s32os32,s8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int32_t,float,s8s8s32os8,s8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int16_t,int16_t,float,s8s8s16os16,s8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(int8_t,int8_t,int8_t,int16_t,float,s8s8s16os8,s8s8s16os8) /* Only supports bias followed by RELU and vice versa for now.*/ \ #define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,DSCALE_type,BLAS_SFX) \ @@ -569,8 +750,8 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ return NULL; \ } \ \ - /* Only supporting 3 post ops at max for now.*/ \ - dim_t max_post_ops_seq_length = 3; \ + /* Only supporting 5 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 5; \ post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ malloc \ ( \ @@ -587,30 +768,79 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ /* Parse post ops list.*/ \ dim_t cur_op_index = 0; \ /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ - post_ops->eltwise.algo.alpha = NULL; \ + post_ops->eltwise = NULL; \ post_ops->bias.bias = NULL; \ post_ops->sum.scale_factor = NULL; \ if ( post_ops_str != NULL ) \ { \ char* ops_tok = strtok(post_ops_str, ", " ); \ + bool is_relu = FALSE; \ bool is_param_relu = FALSE; \ + bool is_gelu_tanh = FALSE; \ + bool is_gelu_erf = FALSE; \ + bool is_clip = FALSE; \ + dim_t activator_idx = 0; \ + dim_t clip_idx = 0; \ + \ + /* Ensure only one activator is used as an eltwise post-op.*/ \ + bool is_activator_set = FALSE; \ + num_eltwise = 0; \ while ( ops_tok ) \ { \ if ( strcmp( ops_tok, "bias") == 0 ) \ { \ post_ops->seq_vector[cur_op_index] = BIAS; \ + cur_op_index++; \ } \ - else if ( strcmp( ops_tok, "relu") == 0 ) \ + else if ( ( strcmp( ops_tok, "relu") == 0 ) && \ + ( is_activator_set == FALSE ) ) \ { \ post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ } \ - else if ( strcmp( ops_tok, "prelu") == 0 ) \ + else if ( ( strcmp( ops_tok, "prelu") == 0 ) && \ + ( is_activator_set == FALSE ) ) \ { \ post_ops->seq_vector[cur_op_index] = ELTWISE; \ is_param_relu = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_tanh") == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_tanh = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( ( strcmp( ops_tok, "gelu_erf") == 0 ) && \ + ( is_activator_set == FALSE ) ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_gelu_erf = TRUE; \ + is_activator_set = TRUE; \ + num_eltwise += 1; \ + activator_idx = cur_op_index; \ + cur_op_index++; \ + } \ + else if ( strcmp( ops_tok, "clip") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_clip = TRUE; \ + num_eltwise += 1; \ + clip_idx = cur_op_index; \ + cur_op_index++; \ } \ ops_tok = strtok( NULL, ", " ); \ - cur_op_index++; \ } \ \ /* Allocate bias buffer, return early if alloc fails.*/ \ @@ -623,17 +853,80 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ } \ GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ \ - post_ops->eltwise.is_power_of_2 = FALSE; \ - post_ops->eltwise.scale_factor = NULL; \ - post_ops->eltwise.algo.alpha = NULL; \ - post_ops->eltwise.algo.algo_type = RELU; \ - if ( is_param_relu == TRUE ) \ + post_ops->eltwise = malloc( num_eltwise * sizeof( aocl_post_op_eltwise ) ); \ + if ( post_ops->eltwise == NULL ) \ + { \ + free( post_ops->bias.bias ); \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + \ + if ( num_eltwise > 0 ) \ + { \ + if ( num_eltwise > 1 ) \ + { \ + if ( activator_idx < clip_idx ) \ + { \ + activator_idx = 0; \ + clip_idx = 1; \ + } \ + else \ + { \ + activator_idx = 1; \ + clip_idx = 0; \ + } \ + } \ + else \ + { \ + activator_idx = 0; \ + clip_idx = 0; \ + } \ + } \ + /* Only one of relu,prelu,gelu_tanh,gelu_erf allowed as an activator.*/ \ + if ( is_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = RELU; \ + } \ + else if ( is_param_relu == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) ( post_ops->eltwise + activator_idx )->algo.alpha ) = ( C_type )6; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = PRELU; \ + } \ + else if ( is_gelu_tanh == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_TANH; \ + } \ + else if ( is_gelu_erf == TRUE ) \ + { \ + ( post_ops->eltwise + activator_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + activator_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.alpha = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.beta = NULL; \ + ( post_ops->eltwise + activator_idx )->algo.algo_type = GELU_ERF; \ + } \ + if ( is_clip == TRUE ) \ { \ - post_ops->eltwise.algo.alpha = malloc( sizeof( C_type ) ); \ - *( ( C_type* ) post_ops->eltwise.algo.alpha ) = ( C_type )6; \ - post_ops->eltwise.algo.algo_type = PRELU; \ + ( post_ops->eltwise + clip_idx )->is_power_of_2 = FALSE; \ + ( post_ops->eltwise + clip_idx )->scale_factor = NULL; \ + ( post_ops->eltwise + clip_idx )->algo.alpha = malloc( sizeof( C_type ) ); \ + ( post_ops->eltwise + clip_idx )->algo.beta = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.alpha ) = ( C_type ) ( -64 ); \ + *( ( C_type* ) ( post_ops->eltwise + clip_idx )->algo.beta ) = ( C_type ) ( 3 ); \ + ( post_ops->eltwise + clip_idx )->algo.algo_type = CLIP; \ } \ - post_ops->eltwise.algo.beta = NULL; \ } \ \ if ( global_dscale_out == 'y' ) \ @@ -651,6 +944,7 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ if ( post_ops->sum.scale_factor == NULL ) \ { \ + free ( post_ops->eltwise ); \ free ( post_ops->bias.bias ); \ free( post_ops->seq_vector ); \ free( post_ops ); \ @@ -672,8 +966,10 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,u8s8s16os16) GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,u8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(float,float,bf16bf16f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,float,bf16bf16f32of32) GEN_MAT_MUL_POST_OPS_CREATOR(float,float,f32f32f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,s8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,s8s8s16os16) void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) { @@ -682,9 +978,20 @@ void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) return; } - if ( post_ops->eltwise.algo.alpha != NULL ) + if ( post_ops->eltwise != NULL ) { - free( post_ops->eltwise.algo.alpha ); + for ( dim_t i = 0; i < num_eltwise; ++i ) + { + if ( ( post_ops->eltwise + i )->algo.alpha != NULL ) + { + free( ( post_ops->eltwise + i )->algo.alpha ); + } + if ( ( post_ops->eltwise + i )->algo.beta != NULL ) + { + free( ( post_ops->eltwise + i )->algo.beta ); + } + } + free( post_ops->eltwise ); } if ( post_ops->sum.scale_factor != NULL ) { @@ -740,6 +1047,15 @@ void mat_mul_bench_main_ ## BLAS_SFX \ \ C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ + GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ + GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( m * n ) ); \ + } \ \ C_type alpha; \ C_type beta; \ @@ -753,9 +1069,6 @@ void mat_mul_bench_main_ ## BLAS_SFX \ alpha = 2; \ beta = 9; \ } \ - \ - GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ - GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ \ aocl_post_op* post_op = NULL; \ if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ @@ -846,6 +1159,10 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8,u8s8s16os16) GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32,u8s8s32os32) GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8,u8s8s32os32) GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32,f32f32f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int32_t,s8s8s32os32,s8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,s8s8s32os8,s8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int16_t,s8s8s16os16,s8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(int8_t,int8_t,int8_t,s8s8s16os8,s8s8s16os16) #define GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(C_type, BLAS_SFX) \ void mat_mul_bench_main_ ## BLAS_SFX \ @@ -897,6 +1214,12 @@ void mat_mul_bench_main_ ## BLAS_SFX \ \ C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(fill_array_,C_type)( c, ( m * n ) ); \ + GEN_FUNC_NAME(fill_array_,C_type)( c_ref, ( m * n ) ); \ + } \ \ float alpha; \ float beta; \ @@ -945,7 +1268,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ bfloat16* b_reorder = ( bfloat16* ) bli_malloc_user( b_reorder_buf_siz_req ); \ aocl_reorder_bf16bf16f32of32( 'B', b, b_reorder, k, n, stride_b ); \ \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ ( \ stor_order, op_t, n_repeats, m, n, k, \ alpha, \ @@ -957,7 +1280,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ ); \ } \ \ -if ( bench_mode == 'a' ) \ + if ( bench_mode == 'a' ) \ { \ printf(" Running accuracy check.\n"); \ GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ @@ -1009,16 +1332,36 @@ int main( int argc, char** argv ) FILE* fin = NULL; if ( argc < 5 ) { - printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 -o op1,op2.. >" \ - "\nMode is either a or p. a is used for accuracy test, " \ - "whereas p is used for performance benchmarking." \ - "\nn_repeats can be set optionally using -n arg." \ - "\nPost ops can be executed optionaly by providing a " \ - "coma separated list of ops after -o arg.\nCurrently " \ - "bias and relu/prelu is supported and can be specified " \ - "as a single post op or combination of the same. eg: -o bias,relu ; -o prelu." \ - "\nDownscaled version of an API can be enabled by using -d arg. " \ - "downscale is used to enable- u8s8s32os8, u8s8s16os8 or bf16bf16f32obf16 \n" ); + printf + ( + "Usage: ./bench_lpgemm -i input.txt -m mode < -n 100 -o op1,op2 >\n" \ + "--Mode is either a or p.\n" \ + "\ta is used for accuracy testing.\n" \ + "\tp is used for performance benchmarking.\n" \ + "--n_repeats can be set optionally using -n arg.\n" \ + "--Post ops can be executed optionaly by providing a coma separated\n" \ + " list of post-ops after -o arg. Following post-ops are supported:\n" \ + " 1. bias\n" \ + " 2. 4 activators\n" \ + " a. relu\n" \ + " b. prelu\n" \ + " c. gelu_tanh\n" \ + " d. gelu_erf\n" \ + " 3.clip\n" \ + " Atleast one post-op needs to be specified if the -o arg is used.\n" \ + " eg: -o gelu_tanh; -o bias,relu ; -o clip,prelu,bias.\n" \ + " It is to be noted only one activator can be used at a time.\n" \ + " If more than one activator is used, only the first activator is\n" \ + " applied and the other activators are ignored.\n" \ + "--Downscaled version of an API is enabled by using -d arg.\n" \ + " Downscaled api's are used to enable quantization workflows.\n" \ + " Following downscaled api's are supported:\n" \ + " 1. u8s8s32os32 -d = u8s8s32os8.\n" \ + " 2. u8s8s16os16 -d = u8s8s16os8.\n" \ + " 3. bf16bf16f32obf32 -d = bf16bf16f32obf16.\n" \ + " 4. s8s8s32os32 -d = s8s8s32os8.\n" \ + " 5. s8s8s16os16 -d = s8s8s16os8.\n" \ + ); exit( 1 ); } @@ -1055,7 +1398,9 @@ int main( int argc, char** argv ) if ( post_ops_str != NULL ) { - post_ops_str_dest = strdup( post_ops_str ); + post_ops_str_dest = ( char* )malloc \ + ( ( strlen( post_ops_str) + 1 )* sizeof( char ) ); + strcpy( post_ops_str_dest, post_ops_str ); } if ( bench_mode == 'p' ) @@ -1081,9 +1426,9 @@ int main( int argc, char** argv ) } FILE* fout = NULL; - + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); - + char op_type_char; char op_t; char stor_order; @@ -1110,7 +1455,7 @@ int main( int argc, char** argv ) #ifdef BLIS_ENABLE_OPENMP omp_set_num_threads( list_omp_cores_for_testing[core_index] ); #endif - printf( "Accuracy test using %ld threads.\n", + printf( "Accuracy test using %ld threads.\n", list_omp_cores_for_testing[core_index] ); core_index++; @@ -1160,7 +1505,7 @@ int main( int argc, char** argv ) ( fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, - NULL + post_ops_str_dest ); } else if ((op_type_char == 's') || (op_type_char == 'S')) @@ -1184,7 +1529,7 @@ int main( int argc, char** argv ) ); } } - if ((op_type_char == 'b') || (op_type_char == 'B')) + else if ((op_type_char == 'b') || (op_type_char == 'B')) { if ( global_dscale_out == 'n' ) { @@ -1203,7 +1548,49 @@ int main( int argc, char** argv ) m, n, k, stride_a, stride_b, stride_c, post_ops_str_dest ); - } + } + } + else if ( ( op_type_char == 'u' ) || ( op_type_char == 'U' ) ) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os32) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s32os8) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + } + else if ( ( op_type_char == 'v' ) || ( op_type_char == 'V' ) ) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os16) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,s8s8s16os8) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } } if ( post_ops_str != NULL ) { diff --git a/bench/bench_aocl_gemm/bench_lpgemm_utils.c b/bench/bench_aocl_gemm/bench_lpgemm_utils.c new file mode 100644 index 0000000000..dbbdce6703 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm_utils.c @@ -0,0 +1,392 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "blis.h" + +// Mode can be one of the follwoing: +// 1. p - performance, used for benchmarks. +// 2. a - accuracy, used to test accuracy/correctness. +// Default value is p, can be modified by passing command line arg. +char bench_mode = 'p'; + +int32_t global_n_repeat = 0; + +#define _XSTR(str) #str +#define XSTR(str) _XSTR(str) + +#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype + +#define GEN_FILL_ARRAY_FUNC(ctype) \ +void fill_array_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 10 ); \ + } \ +} \ + +GEN_FILL_ARRAY_FUNC(float) + +void print_result + ( + const char* msg, + int32_t n_repeats, + dim_t n, + dim_t incx, + double runtime + ) +{ + printf("%s n: %ld, incx: %ld, runtime: %f s, n_repeats: %d\n", \ + msg, n, incx, runtime, n_repeats); +} + +#define GEN_GELU_BENCH_DRV_FN(V_type,GELU_SFX) \ +void gelu_bench_driver_ ## GELU_SFX \ + ( \ + int32_t n_repeats, \ + dim_t n, \ + V_type* x, \ + inc_t incx \ + ) \ +{ \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx) ); \ + } \ + \ + GEN_FUNC_NAME(aocl_,GELU_SFX) \ + ( \ + n, x, incx \ + ); \ + \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ + \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ + \ + print_result( XSTR(GELU_SFX), n_repeats, n, incx, min_time_diff); \ +} \ + +GEN_GELU_BENCH_DRV_FN(float,gelu_tanh_f32) +GEN_GELU_BENCH_DRV_FN(float,gelu_erf_f32) + +#define GEN_SOFTMAX_BENCH_DRV_FN(V_type,SOFTMAX_SFX) \ +void softmax_bench_driver_ ## SOFTMAX_SFX \ + ( \ + int32_t n_repeats, \ + dim_t n, \ + V_type* x, \ + inc_t incx \ + ) \ +{ \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx) ); \ + } \ + \ + GEN_FUNC_NAME(aocl_,SOFTMAX_SFX) \ + ( \ + n, x, incx \ + ); \ + \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ + \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ + \ + print_result( XSTR(SOFTMAX_SFX), n_repeats, n, incx, min_time_diff); \ +} \ + +GEN_SOFTMAX_BENCH_DRV_FN(float,softmax_f32) + +inline float gelu_tanh_f32 + ( + float temp_accum + ) +{ + temp_accum = 0.5 *(double)temp_accum * (1 + tanhf( 0.797884 * ( (double)temp_accum + \ + ( 0.044715 * ((double)temp_accum * (double)temp_accum * \ + (double)temp_accum ) ) ) ) ); + return temp_accum; +}\ + +inline float gelu_erf_f32 + ( + float temp_accum + ) +{ + temp_accum = 0.5 *(double)temp_accum * (1 + erff( (double)temp_accum * 0.707107 )); + return temp_accum; +} + +#define GEN_GELU_ACC_CHK_FN(V_type,GELU_SFX) \ +void gelu_acc_check_ ## GELU_SFX \ + ( \ + FILE* fout, \ + dim_t n, \ + V_type* x, \ + V_type* ref_x, \ + inc_t incx \ + ) \ +{ \ + for ( dim_t idx = 0; idx < ( n * incx ); idx += incx ) \ + { \ + V_type temp_acc = GELU_SFX( *( ref_x + idx ) ); \ + if ( temp_acc != *( x + idx ) ) \ + { \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input n: %ld, incx: %ld, idx: %ld \n", \ + XSTR(GELU_SFX), n, incx, ( idx / incx ) ); \ + fflush( fout ); \ + } \ + printf("%s failure, n: %ld, incx: %ld, idx: %ld, ref: %f, calc: %f\n", \ + XSTR(GELU_SFX), n, incx, ( idx / incx ), temp_acc, *(x + idx)); \ + goto cleanup_acc; \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_GELU_ACC_CHK_FN(float,gelu_tanh_f32) +GEN_GELU_ACC_CHK_FN(float,gelu_erf_f32) + +#define GEN_SOFTMAX_ACC_CHK_FN(V_type,SOFTMAX_SFX) \ +void softmax_acc_check_ ## SOFTMAX_SFX \ + ( \ + FILE* fout, \ + dim_t n, \ + V_type* x, \ + V_type* ref_x, \ + inc_t incx \ + ) \ +{ \ + double exp_sum = 0.0; \ + for ( dim_t idx = 0; idx < ( n * incx ); idx += incx )\ + { \ + exp_sum += ( double )expf( *(ref_x + idx ) ); \ + } \ + for ( dim_t idx = 0; idx < ( n * incx ); idx += incx ) \ + { \ + V_type temp_acc = ( V_type )( ( ( double )*( ref_x + idx ) ) / exp_sum ); \ + if ( temp_acc != *( x + idx ) ) \ + { \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input n: %ld, incx: %ld, idx: %ld \n", \ + XSTR(SOFTMAX_SFX), n, incx, ( idx / incx ) ); \ + fflush( fout ); \ + } \ + printf("%s failure, n: %ld, incx: %ld, idx: %ld, ref: %.10f, calc: %.10f\n", \ + XSTR(SOFTMAX_SFX), n, incx, ( idx / incx ), temp_acc, *(x + idx)); \ + goto cleanup_acc; \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_SOFTMAX_ACC_CHK_FN(float,softmax_f32) + +#define GEN_GELU_BENCH_MAIN_FN(V_type,GELU_SFX) \ +void gelu_bench_main_ ## GELU_SFX \ + ( \ + FILE* fout, \ + dim_t n, \ + inc_t incx \ + ) \ +{ \ + int32_t n_repeats = 1000; \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx ) ); \ + \ + V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + GEN_FUNC_NAME(fill_array_,V_type)( ref_x, ( n * incx ) ); \ + \ + GEN_FUNC_NAME(gelu_bench_driver_,GELU_SFX)(n_repeats,n,x,incx); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(gelu_acc_check_,GELU_SFX)(fout,n,x,ref_x,incx); \ + } \ +} \ + +GEN_GELU_BENCH_MAIN_FN(float,gelu_tanh_f32) +GEN_GELU_BENCH_MAIN_FN(float,gelu_erf_f32) + +#define GEN_SOFTMAX_BENCH_MAIN_FN(V_type,SOFTMAX_SFX) \ +void softmax_bench_main_ ## SOFTMAX_SFX \ + ( \ + FILE* fout, \ + dim_t n, \ + inc_t incx \ + ) \ +{ \ + int32_t n_repeats = 1000; \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + V_type* x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + GEN_FUNC_NAME(fill_array_,V_type)( x, ( n * incx ) ); \ + \ + V_type* ref_x = ( V_type* ) bli_malloc_user( sizeof( V_type ) * n * incx ); \ + GEN_FUNC_NAME(fill_array_,V_type)( ref_x, ( n * incx ) ); \ + \ + GEN_FUNC_NAME(softmax_bench_driver_,SOFTMAX_SFX)(n_repeats,n,x,incx); \ + \ + if ( bench_mode == 'a' ) \ + { \ + GEN_FUNC_NAME(softmax_acc_check_,SOFTMAX_SFX)(fout,n,x,ref_x,incx); \ + } \ +} \ + +GEN_SOFTMAX_BENCH_MAIN_FN(float,softmax_f32) + +int main( int argc, char** argv ) +{ + FILE* fin = NULL; + if ( argc < 5 ) + { + printf( "Usage: ./bench_lpgemm_utils -i input.txt -m mode < -n 1000 >" \ + "\nMode is either a or p. a is used for accuracy test, " \ + "whereas p is used for performance benchmarking." \ + "\nn_repeats can be set optionally using -n arg.\n" ); + exit( 1 ); + } + + char* file_name = NULL; + + // Parse CLI arguments. + opterr = 0; + int opt_val; + while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + { + switch ( opt_val ) + { + case 'i': + file_name = optarg; + break; + case 'm': + bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + break; + default: + break; + } + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); + + char l1_op_type[128]; + dim_t n; + inc_t incx; + while ( fscanf( fin, "%s %ld %ld\n", l1_op_type, &n, &incx ) == 3 ) + { + if ( strcmp( l1_op_type, "f32_gelu_tanh" ) == 0 ) + { + gelu_bench_main_gelu_tanh_f32( fout, n, incx ); + } + else if ( strcmp( l1_op_type, "f32_gelu_erf" ) == 0 ) + { + gelu_bench_main_gelu_erf_f32( fout, n, incx ); + } + else if ( strcmp( l1_op_type, "f32_softmax" ) == 0 ) + { + softmax_bench_main_softmax_f32( fout, n, incx ); + } + } + + return 0; +} diff --git a/bench/bench_aocl_gemm/bench_utils_input.txt b/bench/bench_aocl_gemm/bench_utils_input.txt new file mode 100644 index 0000000000..af9051b6a4 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_utils_input.txt @@ -0,0 +1,33 @@ +f32_softmax 1 1 +f32_softmax 2 1 +f32_softmax 4 1 +f32_softmax 21 1 +f32_softmax 64 1 +f32_gelu_tanh 1 1 +f32_gelu_tanh 2 1 +f32_gelu_tanh 8 1 +f32_gelu_tanh 16 1 +f32_gelu_tanh 21 1 +f32_gelu_tanh 64 1 +f32_gelu_tanh 1029 1 +f32_gelu_erf 1 1 +f32_gelu_erf 2 1 +f32_gelu_erf 8 1 +f32_gelu_erf 16 1 +f32_gelu_erf 21 1 +f32_gelu_erf 64 1 +f32_gelu_erf 1029 1 +f32_gelu_tanh 1 9 +f32_gelu_tanh 2 9 +f32_gelu_tanh 8 9 +f32_gelu_tanh 16 1024 +f32_gelu_tanh 21 1024 +f32_gelu_tanh 64 1024 +f32_gelu_tanh 1029 512 +f32_gelu_erf 1 9 +f32_gelu_erf 2 9 +f32_gelu_erf 8 9 +f32_gelu_erf 16 1024 +f32_gelu_erf 21 1024 +f32_gelu_erf 64 1024 +f32_gelu_erf 1029 512 diff --git a/bench/bench_aocl_gemm/test_small.txt b/bench/bench_aocl_gemm/test_small.txt new file mode 100644 index 0000000000..13f47bce81 --- /dev/null +++ b/bench/bench_aocl_gemm/test_small.txt @@ -0,0 +1,54 @@ +i r r 4 3 204 204 3 3 +i r r 6 5 204 204 5 5 +i r r 6 7 204 204 7 7 +i r r 6 9 204 204 9 9 +i r r 78402 8 190 190 8 8 +i r r 78402 9 190 190 9 9 +i r r 78402 10 190 190 10 10 +i r r 78402 11 190 190 11 11 +i r r 78402 12 190 190 12 12 +i r r 78402 13 190 190 13 13 +i r r 78402 14 190 190 14 14 +i r r 78402 15 190 190 15 15 +i r r 78403 8 190 190 8 8 +i r r 78403 9 190 190 9 9 +i r r 78403 10 190 190 10 10 +i r r 78403 11 190 190 11 11 +i r r 78403 12 190 190 12 12 +i r r 78403 13 190 190 13 13 +i r r 78403 14 190 190 14 14 +i r r 78403 15 190 190 15 15 +i r r 78404 8 190 190 8 8 +i r r 78404 9 190 190 9 9 +i r r 78404 10 190 190 10 10 +i r r 78404 11 190 190 11 11 +i r r 78404 12 190 190 12 12 +i r r 78404 13 190 190 13 13 +i r r 78404 14 190 190 14 14 +i r r 78404 15 190 190 15 15 +i r r 78405 8 190 190 8 8 +i r r 78405 9 190 190 9 9 +i r r 78405 10 190 190 10 10 +i r r 78405 11 190 190 11 11 +i r r 78405 12 190 190 12 12 +i r r 78405 13 190 190 13 13 +i r r 78405 14 190 190 14 14 +i r r 78405 15 190 190 15 15 +i r r 78406 8 190 190 8 8 +i r r 78406 9 190 190 9 9 +i r r 78406 10 190 190 10 10 +i r r 78406 11 190 190 11 11 +i r r 78406 12 190 190 12 12 +i r r 78406 13 190 190 13 13 +i r r 78406 14 190 190 14 14 +i r r 78406 15 190 190 15 15 +i r r 78407 8 190 190 8 8 +i r r 78407 9 190 190 9 9 +i r r 78407 10 190 190 10 10 +i r r 78407 11 190 190 11 11 +i r r 78407 12 190 190 12 12 +i r r 78407 13 190 190 13 13 +i r r 78407 14 190 190 14 14 +i r r 78407 15 190 190 15 15 +### +i r r 78402 16 190 190 16 16 diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c index c962079dd6..db62ead33e 100644 --- a/bench/bench_axpbyv.c +++ b/bench/bench_axpbyv.c @@ -262,4 +262,4 @@ int main( int argc, char** argv ) } return 0; -} \ No newline at end of file +} diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c index 908ce0fca5..d9dc523e92 100755 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -109,6 +109,10 @@ int main( int argc, char** argv ) printf("Error opening output file %s\n", argv[2]); exit(1); } + if (argc > 3) + { + n_repeats = atoi(argv[3]); + } fprintf(fout, "Dt transa transb m n k alphaR alphaI lda ldb betaR betaI ldc gflops\n"); diff --git a/bench/bench_swapv.c b/bench/bench_swapv.c index 34af6b7975..6f2c8fd90e 100644 --- a/bench/bench_swapv.c +++ b/bench/bench_swapv.c @@ -248,4 +248,4 @@ int main( int argc, char** argv ) fclose(fout); return 0; -} \ No newline at end of file +} diff --git a/bench/bench_trsv.c b/bench/bench_trsv.c index ddf3ea187a..425f61f1d0 100644 --- a/bench/bench_trsv.c +++ b/bench/bench_trsv.c @@ -395,4 +395,4 @@ int main( int argc, char** argv ) // bli_finalize(); return 0; -} \ No newline at end of file +} diff --git a/bench/inputnrm2.txt b/bench/inputnrm2.txt index 567d6e4691..517f5eac41 100644 --- a/bench/inputnrm2.txt +++ b/bench/inputnrm2.txt @@ -39,4 +39,4 @@ dnrm2:171: D 8192 5 dnrm2:171: D 16384 11 dnrm2:171: D 20976 3 dnrm2:171: D 56841 19 -dnrm2:171: D 65536 6 \ No newline at end of file +dnrm2:171: D 65536 6 diff --git a/blastest/f2c/CMakeLists.txt b/blastest/f2c/CMakeLists.txt index 00d8291164..87ec3b6a5b 100644 --- a/blastest/f2c/CMakeLists.txt +++ b/blastest/f2c/CMakeLists.txt @@ -56,4 +56,4 @@ target_sources("${F2C_LIB}" ${CMAKE_CURRENT_SOURCE_DIR}/wrtfmt.c ${CMAKE_CURRENT_SOURCE_DIR}/wsfe.c ${CMAKE_CURRENT_SOURCE_DIR}/wsle.c - ) \ No newline at end of file + ) diff --git a/blastest/f2c/endfile.c b/blastest/f2c/endfile.c index def9988d12..8be2d826b5 100644 --- a/blastest/f2c/endfile.c +++ b/blastest/f2c/endfile.c @@ -21,6 +21,10 @@ other tortious action, arising out of or in connection with the use or performance of this software. ****************************************************************/ +/* + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. +*/ + #include #include "f2c.h" #include "fio.h" @@ -43,7 +47,9 @@ integer f_end(alist *a) if(a->aunit>=MXUNIT || a->aunit<0) err(a->aerr,101,"endfile"); b = &f__units[a->aunit]; if(b->ufd==NULL) { - char nbuf[10]; + /* Increased buffer size from 10 to 17 to eliminate + warning message from gcc. */ + char nbuf[17]; sprintf(nbuf,"fort.%ld",(long)a->aunit); if (tf = fopen(nbuf, f__w_mode[0])) fclose(tf); diff --git a/blastest/f2c/f2c_config.h b/blastest/f2c/f2c_config.h index af39bbe5d0..303cddc517 100644 --- a/blastest/f2c/f2c_config.h +++ b/blastest/f2c/f2c_config.h @@ -188,4 +188,4 @@ #ifdef _MSC_VER #define NON_UNIX_STDIO 1 -#endif \ No newline at end of file +#endif diff --git a/build/auto_config.py b/build/auto_config.py index 4221ca637d..1ce3989e4e 100644 --- a/build/auto_config.py +++ b/build/auto_config.py @@ -1,73 +1,73 @@ -"""Copyright (C) 2020, Advanced Micro Devices, Inc. All Rights Reserved""" - -import subprocess -import sys - -def config_check(): - # Execute wmic shell command with sub-process - result = subprocess.run(['wmic', 'cpu', 'get', 'caption'], stdout=subprocess.PIPE, text=True).stdout - - # Replace the newline character with empty char - result=result.replace('\n', '') - - # parse the string into list of string - parse_string=result.split(" ") - - # Strip the empty strings from list - parse_string=[list for list in parse_string if list.strip()] - - vendor=parse_string[1] - family=hex(int(parse_string[3])) - model=hex(int(parse_string[5])) - stepping=hex(int(parse_string[7])) - - # AMD family numbers - # Zen/ Zen+/Zen2 family number - zen_family="0x17" - # Bulldozer / Piledriver / Steamroller / Excavator family number - amd_family="0x15" - - # AMD CPUID model numbers - zen_model=["0x30", "0xff"] - zen2_model=["0x00", "0xff"] - excavator_model=["0x60","0x7f"] - steamroller_model=["0x30", "0x3f"] - piledriver_model=["0x02", "0x10", "0x1f"] - bulldozer_model=["0x00", "0x01"] - - # Check the CPU configuration Intel64/AMD64 - if vendor.count("Intel64"): - return - elif vendor.count("AMD64"): - # Check the AMD family name - if family == zen_family: - if (zen_model[0] <= model and model <= zen_model[1]) : - family="zen2" - elif (zen2_model[0] <= model and model <= zen2_model[1]) : - family="zen" - else: - print("Unknown model number") - elif family == amd_family: - # check for specific models of excavator family - if (excavator_model[0] <= model and model <= excavator_model[1]) : - family="excavator" - # check for specific models of steamroller family - elif (steamroller_model[0] <= model and model <= steamroller_model[1]) : - family="steamroller" - # check for specific models of piledriver family - elif (model == piledriver_model[0] or (piledriver_model[1] <= model and model <= piledriver_model[2])) : - family="piledriver" - # check for specific models of bulldozer family - elif (model == bulldozer_model[0] or model == bulldozer_model[1]) : - family="bulldozer" - else: - print("Unknown model number") - else: - print("Unknown family") - else: - print("UNKNOWN CPU") - return family - -# Function call for config family names -FAMILY=config_check() -print(FAMILY) \ No newline at end of file +"""Copyright (C) 2020, Advanced Micro Devices, Inc. All Rights Reserved""" + +import subprocess +import sys + +def config_check(): + # Execute wmic shell command with sub-process + result = subprocess.run(['wmic', 'cpu', 'get', 'caption'], stdout=subprocess.PIPE, text=True).stdout + + # Replace the newline character with empty char + result=result.replace('\n', '') + + # parse the string into list of string + parse_string=result.split(" ") + + # Strip the empty strings from list + parse_string=[list for list in parse_string if list.strip()] + + vendor=parse_string[1] + family=hex(int(parse_string[3])) + model=hex(int(parse_string[5])) + stepping=hex(int(parse_string[7])) + + # AMD family numbers + # Zen/ Zen+/Zen2 family number + zen_family="0x17" + # Bulldozer / Piledriver / Steamroller / Excavator family number + amd_family="0x15" + + # AMD CPUID model numbers + zen_model=["0x30", "0xff"] + zen2_model=["0x00", "0xff"] + excavator_model=["0x60","0x7f"] + steamroller_model=["0x30", "0x3f"] + piledriver_model=["0x02", "0x10", "0x1f"] + bulldozer_model=["0x00", "0x01"] + + # Check the CPU configuration Intel64/AMD64 + if vendor.count("Intel64"): + return + elif vendor.count("AMD64"): + # Check the AMD family name + if family == zen_family: + if (zen_model[0] <= model and model <= zen_model[1]) : + family="zen2" + elif (zen2_model[0] <= model and model <= zen2_model[1]) : + family="zen" + else: + print("Unknown model number") + elif family == amd_family: + # check for specific models of excavator family + if (excavator_model[0] <= model and model <= excavator_model[1]) : + family="excavator" + # check for specific models of steamroller family + elif (steamroller_model[0] <= model and model <= steamroller_model[1]) : + family="steamroller" + # check for specific models of piledriver family + elif (model == piledriver_model[0] or (piledriver_model[1] <= model and model <= piledriver_model[2])) : + family="piledriver" + # check for specific models of bulldozer family + elif (model == bulldozer_model[0] or model == bulldozer_model[1]) : + family="bulldozer" + else: + print("Unknown model number") + else: + print("Unknown family") + else: + print("UNKNOWN CPU") + return family + +# Function call for config family names +FAMILY=config_check() +print(FAMILY) diff --git a/build/bli_config.h.in b/build/bli_config.h.in index 6c17fc5e74..ba0c16100b 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -132,6 +132,16 @@ #endif #endif +// If the CBLAS compatibility layer was enabled while the BLAS layer +// was not enabled, we must enable the BLAS layer here. Also undefine +// BLIS_DISABLE_BLAS to ensure consistency. +#ifdef BLIS_ENABLE_CBLAS +#ifndef BLIS_ENABLE_BLAS +#define BLIS_ENABLE_BLAS +#endif +#undef BLIS_DISABLE_BLAS +#endif // BLIS_ENABLE_CBLAS + #ifndef BLIS_ENABLE_MIXED_DT #ifndef BLIS_DISABLE_MIXED_DT #if @enable_mixed_dt@ @@ -196,8 +206,10 @@ #if @disable_blis_arch_type@ #define DISABLE_BLIS_ARCH_TYPE +#define DISABLE_BLIS_MODEL_TYPE #endif #define __blis_arch_type_name "@rename_blis_arch_type@" +#define __blis_model_type_name "@rename_blis_model_type@" #endif diff --git a/build/bli_win_config.h.in b/build/bli_win_config.h.in index 24e1fc3d59..4645b5cf95 100644 --- a/build/bli_win_config.h.in +++ b/build/bli_win_config.h.in @@ -1,54 +1,58 @@ -/* - * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All Rights Reserved - */ - -#ifndef BLIS_CONFIG_H -#define BLIS_CONFIG_H - -#cmakedefine AOCL_DYNAMIC - -#cmakedefine AOCL_BLIS_ZEN - -#cmakedefine BLIS_ENABLE_OPENMP - -#cmakedefine BLIS_ENABLE_JRIR_SLAB - -#cmakedefine BLIS_ENABLE_JRIR_RR - -#cmakedefine BLIS_ENABLE_PBA_POOLS - -#cmakedefine BLIS_ENABLE_SBA_POOLS - -#cmakedefine BLIS_ENABLE_MEM_TRACING - -#cmakedefine BLIS_INT_TYPE_SIZE @INT_TYPE_SIZE@ - -#cmakedefine BLIS_BLAS_INT_TYPE_SIZE @BLAS_INT_TYPE_SIZE@ - -#cmakedefine BLIS_ENABLE_BLAS - -#cmakedefine BLIS_ENABLE_CBLAS - -#cmakedefine BLIS_ENABLE_MIXED_DT - -#cmakedefine BLIS_ENABLE_MIXED_DT_EXTRA_MEM - -#cmakedefine BLIS_ENABLE_SUP_HANDLING - -#cmakedefine BLIS_ENABLE_MEMKIND - -#cmakedefine BLIS_ENABLE_TRSM_PREINVERSION - -#cmakedefine BLIS_ENABLE_PRAGMA_OMP_SIMD - -#cmakedefine BLIS_ENABLE_SANDBOX - -#cmakedefine BLIS_ENABLE_SHARED - -#cmakedefine BLIS_ENABLE_COMPLEX_RETURN_INTEL - -#cmakedefine DISABLE_BLIS_ARCH_TYPE - -#cmakedefine __blis_arch_type_name "@rename_blis_arch_type@" - -#endif +/* + * Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. + */ + +#ifndef BLIS_CONFIG_H +#define BLIS_CONFIG_H + +#cmakedefine AOCL_DYNAMIC + +#cmakedefine AOCL_BLIS_ZEN + +#cmakedefine BLIS_ENABLE_OPENMP + +#cmakedefine BLIS_ENABLE_JRIR_SLAB + +#cmakedefine BLIS_ENABLE_JRIR_RR + +#cmakedefine BLIS_ENABLE_PBA_POOLS + +#cmakedefine BLIS_ENABLE_SBA_POOLS + +#cmakedefine BLIS_ENABLE_MEM_TRACING + +#cmakedefine BLIS_INT_TYPE_SIZE @INT_TYPE_SIZE@ + +#cmakedefine BLIS_BLAS_INT_TYPE_SIZE @BLAS_INT_TYPE_SIZE@ + +#cmakedefine BLIS_ENABLE_BLAS + +#cmakedefine BLIS_ENABLE_CBLAS + +#cmakedefine BLIS_ENABLE_MIXED_DT + +#cmakedefine BLIS_ENABLE_MIXED_DT_EXTRA_MEM + +#cmakedefine BLIS_ENABLE_SUP_HANDLING + +#cmakedefine BLIS_ENABLE_MEMKIND + +#cmakedefine BLIS_ENABLE_TRSM_PREINVERSION + +#cmakedefine BLIS_ENABLE_PRAGMA_OMP_SIMD + +#cmakedefine BLIS_ENABLE_SANDBOX + +#cmakedefine BLIS_ENABLE_SHARED + +#cmakedefine BLIS_ENABLE_COMPLEX_RETURN_INTEL + +#cmakedefine DISABLE_BLIS_ARCH_TYPE + +#cmakedefine DISABLE_BLIS_MODEL_TYPE + +#cmakedefine __blis_arch_type_name "@rename_blis_arch_type@" + +#cmakedefine __blis_model_type_name "@rename_blis_model_type@" + +#endif diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index 834de1cee9..f49d101ae7 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -1,4 +1,18 @@ -"""Copyright (C) 2021, Advanced Micro Devices, Inc. All Rights Reserved""" +"""Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All Rights Reserved""" + +################################################################################ +# This file is used to mirroring the refkernels folder data into to zen, zen2, # +# zen3, zen4 and generic folder. # +# Rename all .c files by adding zen, zen2, zen3, zen4 and generic for the # +# corresponding folder .c files and update the corresponding CMakeLists.txt # +# file for amdzen (dynamic dispatcher) config option. # +# # +# Usage: # +# python blis_ref_kernel_mirror.py # +# # +# Author: Chandrashekara K R # +# # +################################################################################ import os import shutil import subprocess @@ -95,10 +109,25 @@ def write_to_file(filename, data): fd.write(data + '\n') +def update_cmakelists_contents(cmakefiles, replacement_str): + for cmakefile in cmakefiles: + if os.path.exists(cmakefile): + # Updating the modified .c files name in CMakeLists.txt + with open(cmakefile, 'r') as fd: + file_content = fd.read() + file_content = file_content.replace( + 'ref.c', replacement_str + '_ref.c') + with open(cmakefile, 'w') as fd: + fd.write(file_content) + + def add_macro_to_cfiles(cfiles, macro): for cfile in cfiles: if os.path.exists(cfile): write_to_file(cfile, macro) + # Renaming the .c files name to incorporate with linux + os.rename(cfile, cfile.split('ref.c')[0] + macro.split(' ')[ + -1].split('\n')[0][1:] + '_ref.c') if __name__ == '__main__': @@ -109,6 +138,7 @@ def add_macro_to_cfiles(cfiles, macro): if os.path.exists(dest_path): remove_folder(dest_path) + # Creating all the required folders temp = os.path.join(cwd, 'temp') create_folder(temp) execute_and_check('XCOPY {} {} /E'.format(source_path, temp)) @@ -117,6 +147,7 @@ def add_macro_to_cfiles(cfiles, macro): create_folder(os.path.join(dest_path, 'zen3')) create_folder(os.path.join(dest_path, 'zen4')) create_folder(os.path.join(dest_path, 'generic')) + # Mirroring refkernels folder data to zen, zen2, zen3, zen4 and generic folder execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen'))) execute_and_check('XCOPY {} {} /E'.format( @@ -144,23 +175,53 @@ def add_macro_to_cfiles(cfiles, macro): cfiles_in_generic = cfiles_in_generic.split('\r\n') add_macro_to_cfiles(cfiles_in_generic, '\n#define BLIS_CNAME_INFIX _generic\n') + # Listing all CMakelists.txt file from generic folder and updating them. + cmake_files_in_generic = execute_and_check( + 'cd {} && dir / s / b / o: gn CMakeLists.txt'.format( + os.path.join(dest_path, 'generic'))) + cmake_files_in_generic = cmake_files_in_generic.split('\r\n') + update_cmakelists_contents(cmake_files_in_generic, 'generic') cfiles_in_zen = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'zen'))) cfiles_in_zen = cfiles_in_zen.split('\r\n') add_macro_to_cfiles(cfiles_in_zen, '\n#define BLIS_CNAME_INFIX _zen\n') + # Listing all CMakelists.txt file from zen folder and updating them. + cmake_files_in_zen = execute_and_check( + 'cd {} && dir / s / b / o: gn CMakeLists.txt'.format( + os.path.join(dest_path, 'zen'))) + cmake_files_in_zen = cmake_files_in_zen.split('\r\n') + update_cmakelists_contents(cmake_files_in_zen, 'zen') cfiles_in_zen2 = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'zen2'))) cfiles_in_zen2 = cfiles_in_zen2.split('\r\n') add_macro_to_cfiles(cfiles_in_zen2, '\n#define BLIS_CNAME_INFIX _zen2\n') + # Listing all CMakelists.txt file from zen2 folder and updating them. + cmake_files_in_zen2 = execute_and_check( + 'cd {} && dir / s / b / o: gn CMakeLists.txt'.format( + os.path.join(dest_path, 'zen2'))) + cmake_files_in_zen2 = cmake_files_in_zen2.split('\r\n') + update_cmakelists_contents(cmake_files_in_zen2, 'zen2') cfiles_in_zen3 = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'zen3'))) cfiles_in_zen3 = cfiles_in_zen3.split('\r\n') add_macro_to_cfiles(cfiles_in_zen3, '\n#define BLIS_CNAME_INFIX _zen3\n') + # Listing all CMakelists.txt file from zen3 folder and updating them. + cmake_files_in_zen3 = execute_and_check( + 'cd {} && dir / s / b / o: gn CMakeLists.txt'.format( + os.path.join(dest_path, 'zen3'))) + cmake_files_in_zen3 = cmake_files_in_zen3.split('\r\n') + update_cmakelists_contents(cmake_files_in_zen3, 'zen3') cfiles_in_zen4 = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'zen4'))) cfiles_in_zen4 = cfiles_in_zen4.split('\r\n') add_macro_to_cfiles(cfiles_in_zen4, '\n#define BLIS_CNAME_INFIX _zen4\n') + # Listing all CMakelists.txt file from zen4 folder and updating them. + cmake_files_in_zen4 = execute_and_check( + 'cd {} && dir / s / b / o: gn CMakeLists.txt'.format( + os.path.join(dest_path, 'zen4'))) + cmake_files_in_zen4 = cmake_files_in_zen4.split('\r\n') + update_cmakelists_contents(cmake_files_in_zen4, 'zen4') diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index 7429ff42ee..3a5925a306 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -25,4 +25,4 @@ add_subdirectory(haswell) else(${TARGET_ARCH} STREQUAL generic) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(generic) -endif() \ No newline at end of file +endif() diff --git a/config/amdzen/bli_family_amdzen.h b/config/amdzen/bli_family_amdzen.h index 0cf46d5a4e..7e4d460d13 100644 --- a/config/amdzen/bli_family_amdzen.h +++ b/config/amdzen/bli_family_amdzen.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -72,6 +72,8 @@ */ BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); +BLIS_EXPORT_BLIS void bli_zen4_override_gemmt_blkszs (cntx_t* cntx); + /* * Restore the block sizes to default values needed for zen4 context. * diff --git a/config/haswell/CMakeLists.txt b/config/haswell/CMakeLists.txt index a16f3ef51b..a43bfe2b23 100644 --- a/config/haswell/CMakeLists.txt +++ b/config/haswell/CMakeLists.txt @@ -18,4 +18,4 @@ if(FILES) #Install our source files install(FILES ${FILES} DESTINATION ${RELATIVE_PATH}) -endif() \ No newline at end of file +endif() diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 9d4197712e..83ce2cf8b6 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -103,7 +103,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 26, + 29, // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -136,6 +136,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, // swapv BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, @@ -144,10 +145,14 @@ void bli_cntx_init_zen( cntx_t* cntx ) // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -251,7 +256,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 28, + 30, //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, @@ -276,9 +281,11 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, cntx diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index b4153fcbfb..59fc7b0a67 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -33,8 +33,9 @@ # # -# FLAGS specific to zen architecture are added here. -# FLAGS that are common for all the AMD architectures are present in amd_config.mk +# FLAGS that are specific to the 'zen' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. # Declare the name of the current configuration and add it to the # running list of configurations included by common.mk. @@ -46,10 +47,27 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. + +CPPROCFLAGS := +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) + CDBGFLAGS := -g +endif + ifeq ($(DEBUG_TYPE),noopt) -COPTFLAGS := -O0 + COPTFLAGS := -O0 else -COPTFLAGS := -O3 + COPTFLAGS := -O3 endif # @@ -61,16 +79,21 @@ endif # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) -CKVECFLAGS += -march=znver1 -endif + CKVECFLAGS += -march=znver1 + GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + endif +endif# gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) + CRVECFLAGS := $(CKVECFLAGS) else -CRVECFLAGS := $(CKVECFLAGS) + CRVECFLAGS := $(CKVECFLAGS) endif # Store all of the variables here to new variables containing the diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 3ce2fced92..42eae35d95 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,7 +115,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 26, + 29, // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -148,6 +148,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, @@ -156,10 +157,14 @@ void bli_cntx_init_zen2( cntx_t* cntx ) //copy BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -247,7 +252,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 28, + 30, //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, @@ -267,15 +272,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, - BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, cntx diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index 3b87d35b00..180c201b06 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -42,6 +42,11 @@ THIS_CONFIG := zen2 #CONFIGS_INCL += $(THIS_CONFIG) +# Include file containing common flags for all AMD architectures +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + # # --- Determine the C compiler and related flags --- # @@ -56,48 +61,68 @@ CPICFLAGS := CWARNFLAGS := ifneq ($(DEBUG_TYPE),off) -CDBGFLAGS := -g + CDBGFLAGS := -g endif ifeq ($(DEBUG_TYPE),noopt) -COPTFLAGS := -O0 + COPTFLAGS := -O0 else -COPTFLAGS := -O3 + COPTFLAGS := -O3 endif # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer + +# gcc or clang version must be at least 4.0 ifeq ($(CC_VENDOR),gcc) -GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) -#gcc or clang version must be atleast 4.0 -# gcc 9.0 or later: -ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -else -# If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 -# as the fallback option. -CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -endif -else + GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) + # gcc 9.0 or later + CKVECFLAGS += -march=znver2 + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + else + # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + endif +endif # gcc + ifeq ($(CC_VENDOR),clang) -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 -else -#if compiling with clang -VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) -CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) -#clang 9.0 or later: -ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -else -CKVECFLAGS += -march=znver1 -endif -endif -endif -endif + # AOCC clang has various formats for the version line + + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + + # For our purpose we just want to know if it version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) + CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) + + ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) + # AOCC version 4x we will enable znver2 + CKVECFLAGS += -march=znver2 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) + # AOCC version 3x we will enable znver2 + CKVECFLAGS += -march=znver2 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) + # AOCC version 2x we will enable znver2 + CKVECFLAGS += -march=znver2 + else ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) + # LLVM clang 9.0 or later + CKVECFLAGS += -march=znver2 + else + CKVECFLAGS += -march=znver1 + endif +endif # clang # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 779bb7277c..31a9ff5957 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,7 +115,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 26, + 29, // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -148,6 +148,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, @@ -156,10 +157,14 @@ void bli_cntx_init_zen3( cntx_t* cntx ) //copy BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -243,7 +248,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 28, + 30, //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, @@ -268,9 +273,11 @@ void bli_cntx_init_zen3( cntx_t* cntx ) BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, cntx diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index 8522a1e956..7ec1ee32e9 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2019-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -42,6 +42,11 @@ THIS_CONFIG := zen3 #CONFIGS_INCL += $(THIS_CONFIG) +# Include file containing common flags for all AMD architectures +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + # # --- Determine the C compiler and related flags --- # @@ -56,69 +61,77 @@ CPICFLAGS := CWARNFLAGS := ifneq ($(DEBUG_TYPE),off) -CDBGFLAGS := -g + CDBGFLAGS := -g endif ifeq ($(DEBUG_TYPE),noopt) -COPTFLAGS := -O0 + COPTFLAGS := -O0 else -COPTFLAGS := -O3 + COPTFLAGS := -O3 endif # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer + +# gcc or clang version must be at least 4.0 ifeq ($(CC_VENDOR),gcc) -GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) -# gcc or clang version must be atleast 4.0 -# gcc 9.0 or later: -ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 -else -ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -else -# If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 -# as the fallback option. -CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -endif # GCC 9 -endif # GCC 11 -else + GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) + # gcc 11.0 or later + CKVECFLAGS += -march=znver3 + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) + # gcc 9.0 or later + CKVECFLAGS += -march=znver2 + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize -fno-gcse + else + # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + endif +endif # gcc + ifeq ($(CC_VENDOR),clang) + # AOCC clang has various formats for the version line -# AOCC clang has various formats for the version line + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) -# AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) -# AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) -# AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) -# AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) -# AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # For our purpose we just want to know if it version 2x or 3x or 4x -# For our prupose we just want to know if it version 2x or 3x + # But also set these in case we are using upstream LLVM clang + VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) + CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) -# for version 3x we will enable znver3 -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 -else -# for version 2x we will enable znver2 -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 -else -#if compiling with clang -VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) -CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) -#clang 9.0 or later: -ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -else -CKVECFLAGS += -march=znver1 -endif # ge 9 -endif # aocc 2 -endif # aocc 3 + ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) + # AOCC version 4x we will enable znver3 + CKVECFLAGS += -march=znver3 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) + # AOCC version 3x we will enable znver3 + CKVECFLAGS += -march=znver3 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) + # AOCC version 2x we will enable znver2 + CKVECFLAGS += -march=znver2 + else ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) + # LLVM clang 9.0 or later + CKVECFLAGS += -march=znver2 + else + CKVECFLAGS += -march=znver1 + endif endif # clang -endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index ac9875abf6..8dda84ccce 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,16 +39,29 @@ * Converted it to macro as this list is used at multiple places in this file. */ -#define BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs) \ +#define BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs) \ /* s d c z */ \ - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 240, 144, 18 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 566, \ - 480, 320, 256, 566 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4004, 4080, 256 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 128, 144, 60 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ + 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4002, 4080, 2004 ); \ \ - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ + + +#define BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs) \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 32, 3, 12 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 6, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 64, 144, 60 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 512, \ + 480, 320, 256, 160 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 3600, 4080, 2004 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); \ bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ @@ -68,19 +81,21 @@ void bli_cntx_init_zen4( cntx_t* cntx ) 10, // gemm BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, - BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, - BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_32x6, FALSE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + /*bli_zgemm_zen4_asm_12x4 is a column preferred kernel*/ + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_zen4_asm_12x4, FALSE, - BLIS_GEMM_AVX2_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_AVX2_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + // Different GEMM kernels are used for TRSM for zen4 architecture + BLIS_GEMM_FOR_TRSM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_FOR_TRSM_UKR, BLIS_DOUBLE, bli_dgemm_zen4_asm_8x24, TRUE, // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen_asm_16x14, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen4_asm_8x24, TRUE, // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen_asm_16x14, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen4_asm_8x24, TRUE, cntx ); @@ -88,7 +103,9 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with architecture specific threshold functions bli_cntx_set_l3_thresh_funcs ( - 2, + 3, + // GEMM + BLIS_GEMM, bli_cntx_gemmsup_thresh_is_met_zen4, // GEMMT BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, // SYRK @@ -99,15 +116,18 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // packm kernels bli_cntx_set_packm_kers ( - 8, + 11, BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, - BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_8xk, + BLIS_PACKM_24XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_24xk, + BLIS_PACKM_32XK_KER, BLIS_DOUBLE, bli_dpackm_zen4_asm_32xk, BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, - BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + BLIS_PACKM_12XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_12xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_zen4_asm_4xk, cntx ); @@ -133,7 +153,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 28, // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, @@ -146,24 +166,26 @@ void bli_cntx_init_zen4( cntx_t* cntx ) BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int_avx512, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int_avx512, BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, // dotv - BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, - BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int_avx512, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int_avx512, BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, // scalv - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int_avx512, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, @@ -172,10 +194,14 @@ void bli_cntx_init_zen4( cntx_t* cntx ) //copy BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + BLIS_COPYV_KER, BLIS_DCOMPLEX, bli_zcopyv_zen_int, //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + + // scal2v + BLIS_SCAL2V_KER, BLIS_DCOMPLEX, bli_zscal2v_zen_int, cntx ); @@ -183,8 +209,15 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // // These are reference block sizes and may be overridden based on // number of threads used at runtime. - - BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); + + if ( bli_init_model_query_id() == BLIS_MODEL_BERGAMO ) + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs); + } + else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_GENOA and BLIS_MODEL_GENOA_X + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs); + } // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. @@ -205,8 +238,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // ------------------------------------------------------------------------- // Initialize sup thresholds with architecture-appropriate values. s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); - bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_MT ], 682, 1000, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 1000, 256, 128 ); bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); // Initialize the context with the sup thresholds. @@ -231,48 +264,49 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized small/unpacked gemm kernels. bli_cntx_set_l3_sup_kers ( - 28, - //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, - BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, - BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, - BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, - BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, - BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, - BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, - BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, - BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, - BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + 30, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_zen4_asm_24x8m, FALSE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64m_avx512, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64m_avx512, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x64n_avx512, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x64n_avx512, TRUE, BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, - BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, - BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, - BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_cv_zen4_asm_12x4m, FALSE, cntx ); // Initialize level-3 sup blocksize objects with architecture-specific // values. // s d c z - bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, - 9, 9, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 24, 3, 12, + 6, 9, 3, 12 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 64, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 192, 144, 72, 48 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 480, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8064, 4080, 2040, 1020 ); // Update the context with the current architecture's register and cache // blocksizes for small/unpacked level-3 problems. @@ -292,21 +326,21 @@ void bli_cntx_init_zen4( cntx_t* cntx ) * Override the block sizes in the context to the block sizes used * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default * GEMM kernels are AVX512 based and uses different block sizes. - * + * * This function should be called in TRSM path before performing - * any packing operations. - * - * Also the context must be restored to default values by calling + * any packing operations. + * + * Also the context must be restored to default values by calling * bli_zen4_restore_default_blkszs() before exiting TRSM Path */ void bli_zen4_override_trsm_blkszs (cntx_t* cntx) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 16, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 14, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 8, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 24, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 120, 144, 72 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4004, 4080, 4080 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4008, 4080, 4080 ); // Update the context with the current architecture's register and cache @@ -324,21 +358,88 @@ void bli_zen4_override_trsm_blkszs (cntx_t* cntx) ); } + +// Since the output of syrk/gemmt is a triangular matrix, +// near-to-square shaped kernel performs better than +// skewed/rectangular shaped kernel. +// Hence we are overriding blocksizes and kernel +// function pointers for gemmt/syrk with avx2 specific ones +void bli_zen4_override_gemmt_blkszs (cntx_t* cntx) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_l3_sup_blkszs + ( + 4, + // level-3 + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + bli_cntx_set_l3_sup_kers + ( + 24, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRC, BLIS_DCOMPLEX, bli_zgemmsup_rd_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); +} + /* * Restore the block sizes to default values needed for zen4 context. * * This function should be called to restore the block sizes to there * default values if they where overriden by calling - * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the * TRSM path. - * + * */ void bli_zen4_restore_default_blkszs (cntx_t* cntx) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); - + if ( bli_init_model_query_id() == BLIS_MODEL_BERGAMO ) + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_BERGAMO(blkszs); + } + else // BLIS_MODEL_DEFAULT choice, also currently used for BLIS_MODEL_GENOA and BLIS_MODEL_GENOA_X + { + BLI_CNTX_DEFAULT_BLKSZ_LIST_GENOA(blkszs); + } + // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. bli_cntx_set_blkszs diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index b21d1582f7..a1666ea9d3 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -64,23 +64,25 @@ * Override the block sizes in the context to the block sizes used * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default * GEMM kernels are AVX512 based and uses different block sizes. - * + * * This function should be called in TRSM path before performing - * any packing operations. - * - * Also the context must be restored to default values by calling + * any packing operations. + * + * Also the context must be restored to default values by calling * bli_zen4_restore_default_blkszs() before exiting TRSM Path */ BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); +BLIS_EXPORT_BLIS void bli_zen4_override_gemmt_blkszs (cntx_t* cntx); + /* * Restore the block sizes to default values needed for zen4 context. * * This function should be called to restore the block sizes to there * default values if they where overriden by calling - * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the * TRSM path. - * + * */ BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 062e680910..5a058e2fbc 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -41,6 +41,11 @@ THIS_CONFIG := zen4 #CONFIGS_INCL += $(THIS_CONFIG) +# Include file containing common flags for all AMD architectures +AMD_CONFIG_FILE := amd_config.mk +AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen +-include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + # # --- Determine the C compiler and related flags --- # @@ -55,105 +60,105 @@ CPICFLAGS := CWARNFLAGS := ifneq ($(DEBUG_TYPE),off) -CDBGFLAGS := -g + CDBGFLAGS := -g endif ifeq ($(DEBUG_TYPE),noopt) -COPTFLAGS := -O0 + COPTFLAGS := -O0 else -COPTFLAGS := -O3 + COPTFLAGS := -O3 endif # Flags specific to optimized kernels. # NOTE: The -fomit-frame-pointer option is needed for some kernels because # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer -ifeq ($(CC_VENDOR),gcc) -GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) - -# gcc 11.0 or later: -ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -# Update CKOPTFLAGS for gcc 11+ to use O3 optimization without -# -ftree-partial-pre flag. This flag results in suboptimal code -# generation for instrinsics based kernels. -ifneq ($(DEBUG_TYPE),noopt) -CKOPTFLAGS := -O2 -fgcse-after-reload -fipa-cp-clone -floop-interchange -floop-unroll-and-jam -fpeel-loops -fpredictive-commoning -fsplit-loops -fsplit-paths -ftree-loop-distribution -funswitch-loops -fvect-cost-model=dynamic -fversion-loops-for-strides -fomit-frame-pointer -endif +# gcc or clang version must be at least 4.0 +ifeq ($(CC_VENDOR),gcc) + GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) + + ifeq ($(shell test $(GCC_VERSION) -ge 13; echo $$?),0) + # gcc 13.0 or later + CKVECFLAGS += -march=znver4 + CRVECFLAGS += -march=znver4 + # Update CKOPTFLAGS for gcc to use O3 optimization without + # -ftree-pre and -ftree-partial-pre flag. These flag results + # in suboptimal code generation for instrinsic based kernels. + # The -ftree-loop-vectorize results in inefficient code gen + # for amd optimized l1 kernels based on instrinsics. + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) + # gcc 11.0 or later + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 + CRVECFLAGS += -march=znver3 + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) + # gcc 9.0 or later + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CRVECFLAGS += -march=znver2 + CKOPTFLAGS += -fno-tree-partial-pre -fno-tree-pre -fno-tree-loop-vectorize + else ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) + # gcc 8.0 or later + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CRVECFLAGS += -march=znver1 + else + # If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 + # as the fallback option. + CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store + endif +endif # gcc -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse -CRVECFLAGS += -march=znver3 -else -# gcc 9.0 or later: -ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse -CRVECFLAGS += -march=znver2 -else -ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) -CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse -CRVECFLAGS += -march=znver1 -else -# If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 -# as the fallback option. -CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -endif # GCC 8 -endif # GCC 9 -endif # GCC 11 -else ifeq ($(CC_VENDOR),clang) - -# AOCC clang has various formats for the version line - -# AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) -# AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) -# AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) -# AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) -# AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) -# AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) - -# For our prupose we just want to know if it version 2x or 3x or 4x - -# for version 4x we will enable znver4 -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) -CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512bf16 -mfpmath=sse -CRVECFLAGS += -march=znver4 -else -# for version 3x we will enable znver3 -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse -CRVECFLAGS += -march=znver3 -else -# for version 2x we will enable znver2 -ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse -CRVECFLAGS += -march=znver2 -else -#if compiling with clang -VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) -CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) -#clang 9.0 or later: -ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -CRVECFLAGS += -march=znver2 -else -CKVECFLAGS += -march=znver1 -CRVECFLAGS += -march=znver1 -endif # ge 9 -endif # aocc 2 -endif # aocc 3 -endif # aocc 4 + # AOCC clang has various formats for the version line + + # AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) + # AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) + # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) + # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) + # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + # AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) + + # For our purpose we just want to know if it version 2x or 3x or 4x + + # But also set these in case we are using upstream LLVM clang + VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) + CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) + + ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) + # AOCC version 4x we will enable znver4 + CKVECFLAGS += -march=znver4 -falign-loops=64 + CRVECFLAGS += -march=znver4 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) + # AOCC version 3x we will enable znver3 + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CRVECFLAGS += -march=znver3 + else ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) + # AOCC version 2x we will enable znver2 + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni + CRVECFLAGS += -march=znver2 + else ifeq ($(shell test $(CC_MAJOR) -ge 16; echo $$?),0) + # LLVM clang 16.0 or later + CKVECFLAGS += -march=znver4 -falign-loops=64 + CRVECFLAGS += -march=znver4 + else ifeq ($(shell test $(CC_MAJOR) -ge 13; echo $$?),0) + # LLVM clang 13.0 or later + CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CRVECFLAGS += -march=znver3 + else ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) + # LLVM clang 9.0 or later + CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -falign-loops=64 + CRVECFLAGS += -march=znver2 + else + CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -falign-loops=64 + CRVECFLAGS += -march=znver1 + endif endif # clang -endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) - -# Flags specific to reference kernels. -# Note: We use AVX2 for reference kernels because, as Jeff Hammond says, -# reference kernel code "is not going to achieve high enough SIMD utilization -# to overcome the AVX-512 frequency drop". (Issue #187) -CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) # Store all of the variables here to new variables containing the # configuration name. diff --git a/configure b/configure index 73dc8cc358..a165c1ad51 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2023, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -215,7 +215,7 @@ print_usage() echo " " echo " Set the size (in bits) of internal BLIS integers and" echo " integer types used in native BLIS interfaces. The" - echo " default inteter type size is architecture dependent." + echo " default integer type size is architecture dependent." echo " (Hint: You can always find this value printed at the" echo " beginning of the testsuite output.)" echo " " @@ -355,17 +355,22 @@ print_usage() echo " " echo " --enable-blis-arch-type, --disable-blis-arch-type" echo " " - echo " Disable (Enabled by default) support for BLIS_ARCH_TYPE" - echo " environment variable, which allows user to select" - echo " architecture-specific code path at runtime." + echo " Disable (Enabled by default) support for BLIS_ARCH_TYPE and BLIS_MODEL_TYPE" + echo " environment variables, which allows user to select" + echo " architecture specific code path and optimizations at runtime." echo " If disabled, in builds with multiple code paths, BLIS" - echo " will still select path automatically." + echo " will still select path and optimizations automatically." echo " " echo " --rename-blis-arch-type=STRING" echo " " - echo " Change environment variable used to select architecture-specific" + echo " Change environment variable used to select architecture specific" echo " code path from BLIS_ARCH_TYPE to STRING" echo " " + echo " --rename-blis-model-type=STRING" + echo " " + echo " Change environment variable used to select architecture model specific" + echo " optimizations from BLIS_MODEL_TYPE to STRING" + echo " " echo " -q, --quiet Suppress informational output. By default, configure" echo " is verbose. (NOTE: -q is not yet implemented)" echo " " @@ -1163,6 +1168,7 @@ auto_detect() cmd="${cc} ${config_defines} \ -DBLIS_CONFIGURETIME_CPUID \ -D__blis_arch_type_name=${double_quote_open}${rename_blis_arch_type}${double_quote_close} \ + -D__blis_model_type_name=${double_quote_open}${rename_blis_model_type}${double_quote_close} \ ${c_hdr_paths} \ -std=c99 -D_GNU_SOURCE \ ${cflags} \ @@ -2043,6 +2049,7 @@ main() complex_return='default' disable_blis_arch_type='no' rename_blis_arch_type='BLIS_ARCH_TYPE' + rename_blis_model_type='BLIS_MODEL_TYPE' # The addon flag and names. addon_flag='' @@ -2281,6 +2288,9 @@ main() rename-blis-arch-type=*) rename_blis_arch_type=${OPTARG#*=} ;; + rename-blis-model-type=*) + rename_blis_model_type=${OPTARG#*=} + ;; *) print_usage ;; @@ -3076,13 +3086,6 @@ main() echo "${script_name}: compiler appears to not support #pragma omp simd." enable_pragma_omp_simd_01=0 fi - if [ "x${enable_blas}" = "xyes" ]; then - echo "${script_name}: the BLAS compatibility layer is enabled." - enable_blas_01=1 - else - echo "${script_name}: the BLAS compatibility layer is disabled." - enable_blas_01=0 - fi if [ "x${enable_cblas}" = "xyes" ]; then echo "${script_name}: the CBLAS compatibility layer is enabled." enable_cblas_01=1 @@ -3092,6 +3095,13 @@ main() echo "${script_name}: the CBLAS compatibility layer is disabled." enable_cblas_01=0 fi + if [ "x${enable_blas}" = "xyes" ]; then + echo "${script_name}: the BLAS compatibility layer is enabled." + enable_blas_01=1 + else + echo "${script_name}: the BLAS compatibility layer is disabled." + enable_blas_01=0 + fi if [ "x${enable_mixed_dt}" = "xyes" ]; then echo "${script_name}: mixed datatype support is enabled." @@ -3257,7 +3267,7 @@ main() fi if [ "x${disable_blis_arch_type}" = "xyes" ]; then - echo "${script_name}: user selection of code path using BLIS_ARCH_TYPE env var is disabled." + echo "${script_name}: user selection of code path using BLIS_ARCH_TYPE and BLIS_MODEL_TYPE env vars is disabled." disable_blis_arch_type_01='1' else disable_blis_arch_type_01='0' @@ -3267,6 +3277,10 @@ main() if [ "x${rename_blis_arch_type}" != "xBLIS_ARCH_TYPE" ]; then echo "${script_name}: configuring with BLIS_ARCH_TYPE env var renamed to '${rename_blis_arch_type}'." fi + # Check if the user requested a custom env var name to replace BLIS_MODEL_TYPE. + if [ "x${rename_blis_model_type}" != "xBLIS_MODEL_TYPE" ]; then + echo "${script_name}: configuring with BLIS_MODEL_TYPE env var renamed to '${rename_blis_model_type}'." + fi echo "${script_name}: configuring complex return type as \"${complex_return}\"." @@ -3482,6 +3496,7 @@ main() | sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \ | sed -e "s/@disable_blis_arch_type@/${disable_blis_arch_type_01}/g" \ | sed -e "s/@rename_blis_arch_type@/${rename_blis_arch_type}/g" \ + | sed -e "s/@rename_blis_model_type@/${rename_blis_model_type}/g" \ > "${bli_config_h_out_path}" # -- Instantiate bli_addon.h file from template ---------------------------- diff --git a/docs/BLISTypedAPI.md b/docs/BLISTypedAPI.md index 7d6e92edac..e495aa00a8 100644 --- a/docs/BLISTypedAPI.md +++ b/docs/BLISTypedAPI.md @@ -1891,7 +1891,7 @@ Possible microkernel types (ie: the return values for `bli_info_get_*_ukr_impl_s ### Operation implementation type query -The following routines allow the caller to obtain a string that identifies the implementation (`ind_t`) that is currently active (ie: implemented and enabled) for each level-3 operation. Possible implementation types are listed in the section above covering [microkernel implemenation query](BLISTypedAPI.md#microkernel-implementation-type-query). +The following routines allow the caller to obtain a string that identifies the implementation (`ind_t`) that is currently active (ie: implemented and enabled) for each level-3 operation. Possible implementation types are listed in the section above covering [microkernel implementation query](BLISTypedAPI.md#microkernel-implementation-type-query). ```c char* bli_info_get_gemm_impl_string( num_t dt ); char* bli_info_get_hemm_impl_string( num_t dt ); diff --git a/docs/Doxyfile b/docs/Doxyfile new file mode 100644 index 0000000000..36ae286238 --- /dev/null +++ b/docs/Doxyfile @@ -0,0 +1,2842 @@ +# Doxyfile 1.9.6 + +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = AOCL-BLIS + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + +PROJECT_LOGO = ./styling/AMD_Logo.png + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = ./ + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + +ABBREVIATE_BRIEF = "The $name class" \ + "The $name widget" \ + "The $name file" \ + is \ + provides \ + specifies \ + contains \ + represents \ + a \ + an \ + the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:^^" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) + +ALIASES = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. + +EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See https://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = NO + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + +DISTRIBUTE_GROUP_DOC = NO + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + +GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = NO + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = NO + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# declarations. If set to NO, these declarations will be included in the +# documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. + +CASE_SENSE_NAMES = SYSTEM + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + +SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# Possible values are: NO, YES and FAIL_ON_WARNINGS. +# The default value is: NO. + +WARN_AS_ERROR = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + +INPUT = ../ + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING +# The default value is: UTF-8. + +INPUT_ENCODING = UTF-8 + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. + +INPUT_FILE_ENCODING = + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. + +FILE_PATTERNS = *.c \ + *.cc \ + *.cxx \ + *.cpp \ + *.c++ \ + *.java \ + *.ii \ + *.ixx \ + *.ipp \ + *.i++ \ + *.inl \ + *.idl \ + *.ddl \ + *.odl \ + *.h \ + *.hh \ + *.hxx \ + *.hpp \ + *.h++ \ + *.l \ + *.cs \ + *.d \ + *.php \ + *.php4 \ + *.php5 \ + *.phtml \ + *.inc \ + *.m \ + *.markdown \ + *.md \ + *.mm \ + *.dox \ + *.py \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f18 \ + *.f \ + *.for \ + *.vhd \ + *.vhdl \ + *.ucf \ + *.qsf \ + *.ice + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = YES + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + +EXCLUDE = ../addon \ + ../aocl_dtl \ + ../bench \ + ../blastest \ + ../build \ + ../config \ + ../examples \ + ../include \ + ../gtestsuite \ + ../kernels \ + ../lib \ + ../mpi_test \ + ../ref_kernels \ + ../sandbox \ + ../test \ + ../testsuite \ + ../travis \ + ../vendor \ + ../windows \ + ../frame/0 \ + ../frame/1 \ + ../frame/1d \ + ../frame/1f \ + ../frame/1m \ + ../frame/2 \ + ../frame/3 \ + ../frame/base \ + ../frame/include \ + ../frame/ind \ + ../frame/thread \ + ../frame/util \ + ../bli_addon.h \ + ../bli_config.h \ + ../configure \ + ../CONTRIBUTING.md \ + ../INSTALL \ + ../LICENSE \ + ../Makefile \ + ../README.md \ + ../RELEASING \ + ../docs/Addons.md \ + ../docs/BLISObjectAPI.md \ + ../docs/BLISTypedAPI.md \ + ../docs/BuildSystem.md \ + ../docs/CodingConventions.md \ + ../docs/ConfigurationHowTo.md \ + ../docs/Doxyfile \ + ../docs/FAQ.md \ + ../docs/HardwareSupport.md \ + ../docs/KernelsHowTo.md \ + ../docs/MixedDatatypes.md \ + ../docs/Multithreading.md \ + ../docs/Performance.md \ + ../docs/PerformanceSmall.md \ + ../docs/ReleaseNotes.md \ + ../docs/Sandboxes.md \ + ../docs/Testsuite.md + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# ANamespace::AClass, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + +EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + +IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. +# +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + +INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. + +FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = + +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. + +FORTRAN_COMMENT_AFTER = 72 + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# entity all documented functions referencing it will be listed. +# The default value is: NO. + +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see https://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + +VERBATIM_HEADERS = YES + +# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the +# clang parser (see: +# http://clang.llvm.org/) for more accurate parsing at the cost of reduced +# performance. This can be particularly helpful with template rich C++ code for +# which doxygen's built-in parser lacks the necessary type information. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. +# The default value is: NO. + +CLANG_ASSISTED_PARSING = NO + +# If the CLANG_ASSISTED_PARSING tag is set to YES and the CLANG_ADD_INC_PATHS +# tag is set to YES then doxygen will add the directory of each input to the +# include path. +# The default value is: YES. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + +CLANG_ADD_INC_PATHS = YES + +# If clang assisted parsing is enabled you can provide the compiler with command +# line options that you would normally use when invoking the compiler. Note that +# the include paths will already be set by doxygen for the files and directories +# specified with INPUT and INCLUDE_PATH. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + +CLANG_OPTIONS = + +# If clang assisted parsing is enabled you can provide the clang parser with the +# path to the directory containing a file called compile_commands.json. This +# file is the compilation database (see: +# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) containing the +# options used when the source files were built. This is equivalent to +# specifying the -p option to a clang tool, such as clang-check. These options +# will then be passed to the parser. Any options specified with CLANG_OPTIONS +# will be added as well. +# Note: The availability of this option depends on whether or not doxygen was +# generated with the -Duse_libclang=ON option for CMake. + +CLANG_DATABASE_PATH = + +#--------------------------------------------------------------------------- +# Configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + +ALPHABETICAL_INDEX = YES + +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = ./styling/header.html + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = ./styling/footer.html + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = ./styling/doxygen-awesome.css + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = ./styling/AMD_Logo.png \ + ./styling/doxygen-fragment-copy-button.js \ + ./styling/doxygen-interactive-toc.js + +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. +# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow to user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = AUTO_LIGHT + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use gray-scales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_MENUS = YES + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the main .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + +ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_TREEVIEW = NO + +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + +ENUM_VALUES_PER_LINE = 4 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + +TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +EXT_LINKS_IN_WINDOW = NO + +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FORMULA_FORMAT = png + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_FONTSIZE = 10 + +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. + +FORMULA_MACROFILE = + +# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see +# https://www.mathjax.org) which uses client side JavaScript for the rendering +# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX +# installed or if you want to formulas look prettier in the HTML output. When +# enabled you may also need to install MathJax separately and configure the path +# to it using the MATHJAX_RELPATH option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +USE_MATHJAX = YES + +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + +# When MathJax is enabled you can set the default output format to be used for +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). +# Possible values are: HTML-CSS (which is slower, but has the best +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. +# The default value is: HTML-CSS. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_FORMAT = HTML-CSS + +# When MathJax is enabled you need to specify the location relative to the HTML +# output directory using the MATHJAX_RELPATH option. The destination directory +# should contain the MathJax.js script. For instance, if the mathjax directory +# is located at the same level as the HTML output directory, then +# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax +# Content Delivery Network so you can quickly see the result without installing +# MathJax. However, it is strongly recommended to install a local copy of +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_RELPATH = + +# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax +# extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): +# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_EXTENSIONS = + +# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces +# of code that will be used on startup of the MathJax code. See the MathJax site +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an +# example see the documentation. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_CODEFILE = + +# When the SEARCHENGINE tag is enabled doxygen will generate a search box for +# the HTML output. The underlying search engine uses javascript and DHTML and +# should work on any modern browser. Note that when using HTML help +# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) +# there is already a search function so this one should typically be disabled. +# For large projects the javascript based search engine can be slow, then +# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to +# search using the keyboard; to jump to the search box use + S +# (what the is depends on the OS and browser, but it is typically +# , /